Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/05/12 11:58:17 (12 years ago)
Author:
mkommend
Message:

#1081: Merged trunk changes and fixed compilation errors due to the merge.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationModel.cs

    r7268 r8742  
    3333  [StorableClass]
    3434  [Item("DiscriminantFunctionClassificationModel", "Represents a classification model that uses a discriminant function and classification thresholds.")]
    35   public abstract class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
     35  public class DiscriminantFunctionClassificationModel : NamedItem, IDiscriminantFunctionClassificationModel {
    3636    [Storable]
    3737    private IRegressionModel model;
     38    public IRegressionModel Model {
     39      get { return model; }
     40      private set { model = value; }
     41    }
    3842
    3943    [Storable]
     
    5155    }
    5256
     57    private IDiscriminantFunctionThresholdCalculator thresholdCalculator;
     58    [Storable]
     59    public IDiscriminantFunctionThresholdCalculator ThresholdCalculator {
     60      get { return thresholdCalculator; }
     61      private set { thresholdCalculator = value; }
     62    }
     63
    5364
    5465    [StorableConstructor]
     
    6172    }
    6273
    63     public DiscriminantFunctionClassificationModel(IRegressionModel model)
     74    public DiscriminantFunctionClassificationModel(IRegressionModel model, IDiscriminantFunctionThresholdCalculator thresholdCalculator)
    6475      : base() {
    6576      this.name = ItemName;
    6677      this.description = ItemDescription;
    6778      this.model = model;
    68       this.classValues = new double[] { 0.0 };
    69       this.thresholds = new double[] { double.NegativeInfinity };
     79      this.classValues = new double[0];
     80      this.thresholds = new double[0];
     81      this.thresholdCalculator = thresholdCalculator;
     82    }
     83
     84    [StorableHook(HookType.AfterDeserialization)]
     85    private void AfterDeserialization() {
     86      if (ThresholdCalculator == null) ThresholdCalculator = new AccuracyMaximizationThresholdCalculator();
     87    }
     88
     89    public override IDeepCloneable Clone(Cloner cloner) {
     90      return new DiscriminantFunctionClassificationModel(this, cloner);
    7091    }
    7192
     
    80101    }
    81102
     103    public virtual void RecalculateModelParameters(IClassificationProblemData problemData, IEnumerable<int> rows) {
     104      double[] classValues;
     105      double[] thresholds;
     106      var targetClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
     107      var estimatedTrainingValues = GetEstimatedValues(problemData.Dataset, rows);
     108      thresholdCalculator.Calculate(problemData, estimatedTrainingValues, targetClassValues, out classValues, out thresholds);
     109      SetThresholdsAndClassValues(thresholds, classValues);
     110    }
     111
     112
    82113    public IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows) {
    83114      return model.GetEstimatedValues(dataset, rows);
     
    85116
    86117    public IEnumerable<double> GetEstimatedClassValues(Dataset dataset, IEnumerable<int> rows) {
     118      if (!Thresholds.Any() && !ClassValues.Any()) throw new ArgumentException("No thresholds and class values were set for the current classification model.");
    87119      foreach (var x in GetEstimatedValues(dataset, rows)) {
    88120        int classIndex = 0;
     
    103135    #endregion
    104136
    105     public abstract IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData);
    106     public abstract IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData);
     137    public virtual IDiscriminantFunctionClassificationSolution CreateDiscriminantFunctionClassificationSolution(IClassificationProblemData problemData) {
     138      return new DiscriminantFunctionClassificationSolution(this, problemData);
     139    }
     140
     141    public virtual IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
     142      return CreateDiscriminantFunctionClassificationSolution(problemData);
     143    }
    107144  }
    108145}
Note: See TracChangeset for help on using the changeset viewer.