Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
09/10/12 17:13:46 (12 years ago)
Author:
gkronber
Message:

#1902 implemented LS Gaussian Process classification

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r7259 r8623  
    5151    }
    5252
     53    private IDiscriminantFunctionThresholdCalculator thresholdCalculator;
     54    [Storable]
     55    public IDiscriminantFunctionThresholdCalculator ThresholdCalculator {
     56      get { return thresholdCalculator; }
     57      private set { thresholdCalculator = value; }
     58    }
     59
    5360
    5461    [StorableConstructor]
     
    6168    }
    6269
    63     public DiscriminantFunctionClassificationModel(IRegressionModel model)
     70    public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
    6471      : base() {
    6572      this.name = ItemName;
    6673      this.description = ItemDescription;
    6774      this.model = model;
    68       this.classValues = new double[] { 0.0 };
    69       this.thresholds = new double[] { double.NegativeInfinity };
     75      this.classValues = new double[0];
     76      this.thresholds = new double[0];
     77      this.thresholdCalculator = thresholdCalculator;
     78    }
     79
     80    [StorableHook(HookType.AfterDeserialization)]
     81    private void AfterDeserialization() {
     82      if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator();
    7083    }
    7184
     
    8093    }
    8194
     95    public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
     96      double[] classValues;
     97      double[] thresholds;
     98      var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     99      var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows);
     100      thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     101      SetThresholdsAndClassValues(thresholds, classValues);
     102    }
     103
     104
    82105    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    83106      return model.GetEstimatedValues(dataset, rows);
     
    85108
    86109    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     110      if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model.");
    87111      foreach (var x in GetEstimatedValues(dataset, rows)) {
    88112        int classIndex = 0;
Note: See TracChangeset for help on using the changeset viewer.