Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/25/11 16:54:15 (13 years ago)
Author:
mkommend
Message:

#1600: Adapted classification solutions to the same design as used by regression solutions.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r6411 r6589  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
    27 using HeuristicLab.Data;
    28 using HeuristicLab.Optimization;
    2926using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3027
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
    42 
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
    54     }
    55 
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    59     }
    60 
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    64     }
    65 
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    70 
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
    7535
    7636    [StorableConstructor]
     
    7838    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    7939      : base(original, cloner) {
    80       RegisterEventHandler();
    8140    }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
     41    protected DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    8342      : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    8443    }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     44    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    8645      : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       SetAccuracyMaximizingThresholds();
    92 
    93       //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present     
    94       base.RecalculateResults();
    95       CalculateResults();
    96       RegisterEventHandler();
    9746    }
    9847
    99     [StorableHook(HookType.AfterDeserialization)]
    100     private void AfterDeserialization() {
    101       RegisterEventHandler();
     48    public override IEnumerable<double> EstimatedClassValues {
     49      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     50    }
     51    public override IEnumerable<double> EstimatedTrainingClassValues {
     52      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     53    }
     54    public override IEnumerable<double> EstimatedTestClassValues {
     55      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    10256    }
    10357
    104     protected override void OnModelChanged(EventArgs e) {
    105       DeregisterEventHandler();
    106       SetAccuracyMaximizingThresholds();
    107       RegisterEventHandler();
    108       base.OnModelChanged(e);
     58    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     59      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    10960    }
    11061
    111     protected override void RecalculateResults() {
    112       base.RecalculateResults();
    113       CalculateResults();
    114     }
    11562
    116     private void CalculateResults() {
    117       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    118       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    119       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    120       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    121 
    122       OnlineCalculatorError errorState;
    123       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    124       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    125       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    126       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    127 
    128       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    129       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    130       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    131       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    132     }
    133 
    134     private void RegisterEventHandler() {
    135       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    136     }
    137     private void DeregisterEventHandler() {
    138       Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    139     }
    140     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    141       OnModelThresholdsChanged(e);
    142     }
    143 
    144     public void SetAccuracyMaximizingThresholds() {
    145       double[] classValues;
    146       double[] thresholds;
    147       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    148       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    149 
    150       Model.SetThresholdsAndClassValues(thresholds, classValues);
    151     }
    152 
    153     public void SetClassDistibutionCutPointThresholds() {
    154       double[] classValues;
    155       double[] thresholds;
    156       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    157       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    158 
    159       Model.SetThresholdsAndClassValues(thresholds, classValues);
    160     }
    161 
    162     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    163       RecalculateResults();
    164     }
    165 
    166     public IEnumerable<double> EstimatedValues {
     63    public override IEnumerable<double> EstimatedValues {
    16764      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16865    }
    169 
    170     public IEnumerable<double> EstimatedTrainingValues {
     66    public override IEnumerable<double> EstimatedTrainingValues {
    17167      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    17268    }
    173 
    174     public IEnumerable<double> EstimatedTestValues {
     69    public override IEnumerable<double> EstimatedTestValues {
    17570      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17671    }
    17772
    178     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     73    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    17974      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    18075    }
Note: See TracChangeset for help on using the changeset viewer.