Changeset 6618 for branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
- Timestamp: 08/01/11 17:48:53 (13 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
branches/GP.Grammar.Editor/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r6415 r6618 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 27 using HeuristicLab.Data;28 using HeuristicLab.Optimization;29 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 27 … … 35 32 [StorableClass] 36 33 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 34 public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase { 35 protected readonly Dictionary<int, double> valueEvaluationCache; 36 protected readonly Dictionary<int, double> classValueEvaluationCache; 42 37 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 protected set { 46 if (value != null && value != Model) { 47 if (Model != null) { 48 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 49 } 50 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 51 base.Model = value; 52 } 53 } 38 [StorableConstructor] 39 protected DiscriminantFunctionClassificationSolution(bool deserializing) 40 : base(deserializing) { 41 valueEvaluationCache = new Dictionary<int, double>(); 42 classValueEvaluationCache = new Dictionary<int, double>(); 43 } 44 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution 
original, Cloner cloner) 45 : base(original, cloner) { 46 valueEvaluationCache = new Dictionary<int, double>(original.valueEvaluationCache); 47 classValueEvaluationCache = new Dictionary<int, double>(original.classValueEvaluationCache); 48 } 49 protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 50 : base(model, problemData) { 51 valueEvaluationCache = new Dictionary<int, double>(); 52 classValueEvaluationCache = new Dictionary<int, double>(); 53 54 SetAccuracyMaximizingThresholds(); 54 55 } 55 56 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 57 public override IEnumerable<double> EstimatedClassValues { 58 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 59 } 60 public override IEnumerable<double> EstimatedTrainingClassValues { 61 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 62 } 63 public override IEnumerable<double> EstimatedTestClassValues { 64 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 59 65 } 60 66 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 67 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 68 var rowsToEvaluate = rows.Except(classValueEvaluationCache.Keys); 69 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 70 var valuesEnumerator = Model.GetEstimatedClassValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 71 72 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 73 classValueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 74 } 75 76 return rows.Select(row 
=> classValueEvaluationCache[row]); 64 77 } 65 78 66 public double TrainingRSquared {67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }69 }70 79 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 75 76 [StorableConstructor] 77 protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { } 78 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 79 : base(original, cloner) { 80 RegisterEventHandler(); 81 } 82 public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 83 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 84 } 85 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 86 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 91 SetAccuracyMaximizingThresholds(); 92 93 //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present 94 base.RecalculateResults(); 95 CalculateResults(); 96 
RegisterEventHandler(); 97 } 98 99 [StorableHook(HookType.AfterDeserialization)] 100 private void AfterDeserialization() { 101 RegisterEventHandler(); 102 } 103 104 protected override void OnModelChanged(EventArgs e) { 105 DeregisterEventHandler(); 106 SetAccuracyMaximizingThresholds(); 107 RegisterEventHandler(); 108 base.OnModelChanged(e); 109 } 110 111 protected override void RecalculateResults() { 112 base.RecalculateResults(); 113 CalculateResults(); 114 } 115 116 private void CalculateResults() { 117 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 118 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 119 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 120 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 121 122 OnlineCalculatorError errorState; 123 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 124 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 125 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 126 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 127 128 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 129 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 130 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 131 TestRSquared = errorState == OnlineCalculatorError.None ? 
testR2 : double.NaN; 132 } 133 134 private void RegisterEventHandler() { 135 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 136 } 137 private void DeregisterEventHandler() { 138 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 139 } 140 private void Model_ThresholdsChanged(object sender, EventArgs e) { 141 OnModelThresholdsChanged(e); 142 } 143 144 public void SetAccuracyMaximizingThresholds() { 145 double[] classValues; 146 double[] thresholds; 147 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 148 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 149 150 Model.SetThresholdsAndClassValues(thresholds, classValues); 151 } 152 153 public void SetClassDistibutionCutPointThresholds() { 154 double[] classValues; 155 double[] thresholds; 156 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 157 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 158 159 Model.SetThresholdsAndClassValues(thresholds, classValues); 160 } 161 162 protected virtual void OnModelThresholdsChanged(EventArgs e) { 163 RecalculateResults(); 164 } 165 166 public IEnumerable<double> EstimatedValues { 80 public override IEnumerable<double> EstimatedValues { 167 81 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 168 82 } 169 170 public IEnumerable<double> EstimatedTrainingValues { 83 public override IEnumerable<double> EstimatedTrainingValues { 171 84 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 172 85 } 173 174 public IEnumerable<double> EstimatedTestValues { 86 public override IEnumerable<double> EstimatedTestValues { 175 87 get { return 
GetEstimatedValues(ProblemData.TestIndizes); } 176 88 } 177 89 178 public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 179 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 90 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 91 var rowsToEvaluate = rows.Except(valueEvaluationCache.Keys); 92 var rowsEnumerator = rowsToEvaluate.GetEnumerator(); 93 var valuesEnumerator = Model.GetEstimatedValues(ProblemData.Dataset, rowsToEvaluate).GetEnumerator(); 94 95 while (rowsEnumerator.MoveNext() & valuesEnumerator.MoveNext()) { 96 valueEvaluationCache.Add(rowsEnumerator.Current, valuesEnumerator.Current); 97 } 98 99 return rows.Select(row => valueEvaluationCache[row]); 100 } 101 102 protected override void OnModelChanged() { 103 valueEvaluationCache.Clear(); 104 classValueEvaluationCache.Clear(); 105 base.OnModelChanged(); 106 } 107 protected override void OnModelThresholdsChanged(System.EventArgs e) { 108 classValueEvaluationCache.Clear(); 109 base.OnModelThresholdsChanged(e); 110 } 111 protected override void OnProblemDataChanged() { 112 valueEvaluationCache.Clear(); 113 classValueEvaluationCache.Clear(); 114 base.OnProblemDataChanged(); 180 115 } 181 116 }
Note: See TracChangeset for help on using the changeset viewer.