Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/25/11 16:54:15 (13 years ago)
Author:
mkommend
Message:

#1600: Adapted classification solutions to the same design as used by regression solutions.

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
Files:
2 added
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs

    r6574 r6589  
    2020#endregion
    2121
     22using System;
    2223using System.Collections.Generic;
    2324using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     27using HeuristicLab.Data;
    2628using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    27 using HeuristicLab.Data;
    28 using System;
    2929
    3030namespace HeuristicLab.Problems.DataAnalysis {
     
    8787    public override IDeepCloneable Clone(Cloner cloner) {
    8888      return new ClassificationEnsembleSolution(this, cloner);
     89    }
     90
     91    protected override void RecalculateResults() {
     92      CalculateResults();
    8993    }
    9094
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs

    r6411 r6589  
    2323using System.Linq;
    2424using HeuristicLab.Common;
    25 using HeuristicLab.Data;
    26 using HeuristicLab.Optimization;
    2725using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2826
     
    3230  /// </summary>
    3331  [StorableClass]
    34   public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution {
    35     private const string TrainingAccuracyResultName = "Accuracy (training)";
    36     private const string TestAccuracyResultName = "Accuracy (test)";
    37 
    38     public new IClassificationModel Model {
    39       get { return (IClassificationModel)base.Model; }
    40       protected set { base.Model = value; }
    41     }
    42 
    43     public new IClassificationProblemData ProblemData {
    44       get { return (IClassificationProblemData)base.ProblemData; }
    45       protected set { base.ProblemData = value; }
    46     }
    47 
    48     public double TrainingAccuracy {
    49       get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; }
    50       private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; }
    51     }
    52 
    53     public double TestAccuracy {
    54       get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; }
    55       private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; }
    56     }
    57 
     32  public abstract class ClassificationSolution : ClassificationSolutionBase {
    5833    [StorableConstructor]
    5934    protected ClassificationSolution(bool deserializing) : base(deserializing) { }
     
    6338    public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData)
    6439      : base(model, problemData) {
    65       Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));
    66       Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));
    67       CalculateResults();
    6840    }
    6941
    70     public override IDeepCloneable Clone(Cloner cloner) {
    71       return new ClassificationSolution(this, cloner);
     42    public override IEnumerable<double> EstimatedClassValues {
     43      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     44    }
     45    public override IEnumerable<double> EstimatedTrainingClassValues {
     46      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     47    }
     48    public override IEnumerable<double> EstimatedTestClassValues {
     49      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    7250    }
    7351
    74     protected override void RecalculateResults() {
    75       CalculateResults();
    76     }
    77 
    78     private void CalculateResults() {
    79       double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values
    80       IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    81       double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values
    82       IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    83 
    84       OnlineCalculatorError errorState;
    85       double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState);
    86       if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN;
    87       double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState);
    88       if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN;
    89 
    90       TrainingAccuracy = trainingAccuracy;
    91       TestAccuracy = testAccuracy;
    92     }
    93 
    94     public virtual IEnumerable<double> EstimatedClassValues {
    95       get {
    96         return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
    97       }
    98     }
    99 
    100     public virtual IEnumerable<double> EstimatedTrainingClassValues {
    101       get {
    102         return GetEstimatedClassValues(ProblemData.TrainingIndizes);
    103       }
    104     }
    105 
    106     public virtual IEnumerable<double> EstimatedTestClassValues {
    107       get {
    108         return GetEstimatedClassValues(ProblemData.TestIndizes);
    109       }
    110     }
    111 
    112     public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     52    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    11353      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    11454    }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs

    r6411 r6589  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Linq;
    2524using HeuristicLab.Common;
    2625using HeuristicLab.Core;
    27 using HeuristicLab.Data;
    28 using HeuristicLab.Optimization;
    2926using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3027
     
    3532  [StorableClass]
    3633  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
    37   public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    38     private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    39     private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    40     private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    41     private const string TestRSquaredResultName = "Pearson's R² (test)";
    42 
    43     public new IDiscriminantFunctionClassificationModel Model {
    44       get { return (IDiscriminantFunctionClassificationModel)base.Model; }
    45       protected set {
    46         if (value != null && value != Model) {
    47           if (Model != null) {
    48             Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    49           }
    50           value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    51           base.Model = value;
    52         }
    53       }
    54     }
    55 
    56     public double TrainingMeanSquaredError {
    57       get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
    58       private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    59     }
    60 
    61     public double TestMeanSquaredError {
    62       get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
    63       private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    64     }
    65 
    66     public double TrainingRSquared {
    67       get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
    68       private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    69     }
    70 
    71     public double TestRSquared {
    72       get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
    73       private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    74     }
     34  public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase {
    7535
    7636    [StorableConstructor]
     
    7838    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
    7939      : base(original, cloner) {
    80       RegisterEventHandler();
    8140    }
    82     public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
     41    protected DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    8342      : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    8443    }
    85     public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     44    protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
    8645      : base(model, problemData) {
    87       Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
    88       Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
    89       Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
    90       Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
    91       SetAccuracyMaximizingThresholds();
    92 
    93       //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present     
    94       base.RecalculateResults();
    95       CalculateResults();
    96       RegisterEventHandler();
    9746    }
    9847
    99     [StorableHook(HookType.AfterDeserialization)]
    100     private void AfterDeserialization() {
    101       RegisterEventHandler();
     48    public override IEnumerable<double> EstimatedClassValues {
     49      get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
     50    }
     51    public override IEnumerable<double> EstimatedTrainingClassValues {
     52      get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); }
     53    }
     54    public override IEnumerable<double> EstimatedTestClassValues {
     55      get { return GetEstimatedClassValues(ProblemData.TestIndizes); }
    10256    }
    10357
    104     protected override void OnModelChanged(EventArgs e) {
    105       DeregisterEventHandler();
    106       SetAccuracyMaximizingThresholds();
    107       RegisterEventHandler();
    108       base.OnModelChanged(e);
     58    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     59      return Model.GetEstimatedClassValues(ProblemData.Dataset, rows);
    10960    }
    11061
    111     protected override void RecalculateResults() {
    112       base.RecalculateResults();
    113       CalculateResults();
    114     }
    11562
    116     private void CalculateResults() {
    117       double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
    118       IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    119       double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
    120       IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
    121 
    122       OnlineCalculatorError errorState;
    123       double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    124       TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
    125       double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    126       TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;
    127 
    128       double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
    129       TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN;
    130       double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState);
    131       TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN;
    132     }
    133 
    134     private void RegisterEventHandler() {
    135       Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged);
    136     }
    137     private void DeregisterEventHandler() {
    138       Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged);
    139     }
    140     private void Model_ThresholdsChanged(object sender, EventArgs e) {
    141       OnModelThresholdsChanged(e);
    142     }
    143 
    144     public void SetAccuracyMaximizingThresholds() {
    145       double[] classValues;
    146       double[] thresholds;
    147       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    148       AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    149 
    150       Model.SetThresholdsAndClassValues(thresholds, classValues);
    151     }
    152 
    153     public void SetClassDistibutionCutPointThresholds() {
    154       double[] classValues;
    155       double[] thresholds;
    156       var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
    157       NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);
    158 
    159       Model.SetThresholdsAndClassValues(thresholds, classValues);
    160     }
    161 
    162     protected virtual void OnModelThresholdsChanged(EventArgs e) {
    163       RecalculateResults();
    164     }
    165 
    166     public IEnumerable<double> EstimatedValues {
     63    public override IEnumerable<double> EstimatedValues {
    16764      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    16865    }
    169 
    170     public IEnumerable<double> EstimatedTrainingValues {
     66    public override IEnumerable<double> EstimatedTrainingValues {
    17167      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    17268    }
    173 
    174     public IEnumerable<double> EstimatedTestValues {
     69    public override IEnumerable<double> EstimatedTestValues {
    17570      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    17671    }
    17772
    178     public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
     73    public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
    17974      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    18075    }
Note: See TracChangeset for help on using the changeset viewer.