Changeset 6589 for trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Timestamp:
- 07/25/11 16:54:15 (13 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification
- Files:
-
- 2 added
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationEnsembleSolution.cs
r6574 r6589 20 20 #endregion 21 21 22 using System; 22 23 using System.Collections.Generic; 23 24 using System.Linq; 24 25 using HeuristicLab.Common; 25 26 using HeuristicLab.Core; 27 using HeuristicLab.Data; 26 28 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 27 using HeuristicLab.Data;28 using System;29 29 30 30 namespace HeuristicLab.Problems.DataAnalysis { … … 87 87 public override IDeepCloneable Clone(Cloner cloner) { 88 88 return new ClassificationEnsembleSolution(this, cloner); 89 } 90 91 protected override void RecalculateResults() { 92 CalculateResults(); 89 93 } 90 94 -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ClassificationSolution.cs
r6411 r6589 23 23 using System.Linq; 24 24 using HeuristicLab.Common; 25 using HeuristicLab.Data;26 using HeuristicLab.Optimization;27 25 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 26 … … 32 30 /// </summary> 33 31 [StorableClass] 34 public class ClassificationSolution : DataAnalysisSolution, IClassificationSolution { 35 private const string TrainingAccuracyResultName = "Accuracy (training)"; 36 private const string TestAccuracyResultName = "Accuracy (test)"; 37 38 public new IClassificationModel Model { 39 get { return (IClassificationModel)base.Model; } 40 protected set { base.Model = value; } 41 } 42 43 public new IClassificationProblemData ProblemData { 44 get { return (IClassificationProblemData)base.ProblemData; } 45 protected set { base.ProblemData = value; } 46 } 47 48 public double TrainingAccuracy { 49 get { return ((DoubleValue)this[TrainingAccuracyResultName].Value).Value; } 50 private set { ((DoubleValue)this[TrainingAccuracyResultName].Value).Value = value; } 51 } 52 53 public double TestAccuracy { 54 get { return ((DoubleValue)this[TestAccuracyResultName].Value).Value; } 55 private set { ((DoubleValue)this[TestAccuracyResultName].Value).Value = value; } 56 } 57 32 public abstract class ClassificationSolution : ClassificationSolutionBase { 58 33 [StorableConstructor] 59 34 protected ClassificationSolution(bool deserializing) : base(deserializing) { } … … 63 38 public ClassificationSolution(IClassificationModel model, IClassificationProblemData problemData) 64 39 : base(model, problemData) { 65 Add(new Result(TrainingAccuracyResultName, "Accuracy of the model on the training partition (percentage of correctly classified instances).", new PercentValue()));66 Add(new Result(TestAccuracyResultName, "Accuracy of the model on the test partition (percentage of correctly classified instances).", new PercentValue()));67 CalculateResults();68 40 } 69 41 70 public override IDeepCloneable Clone(Cloner cloner) { 71 return new ClassificationSolution(this, cloner); 42 public override IEnumerable<double> EstimatedClassValues { 43 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 44 } 45 public override IEnumerable<double> EstimatedTrainingClassValues { 46 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 47 } 48 public override IEnumerable<double> EstimatedTestClassValues { 49 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 72 50 } 73 51 74 protected override void RecalculateResults() { 75 CalculateResults(); 76 } 77 78 private void CalculateResults() { 79 double[] estimatedTrainingClassValues = EstimatedTrainingClassValues.ToArray(); // cache values 80 IEnumerable<double> originalTrainingClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 81 double[] estimatedTestClassValues = EstimatedTestClassValues.ToArray(); // cache values 82 IEnumerable<double> originalTestClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 83 84 OnlineCalculatorError errorState; 85 double trainingAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTrainingClassValues, originalTrainingClassValues, out errorState); 86 if (errorState != OnlineCalculatorError.None) trainingAccuracy = double.NaN; 87 double testAccuracy = OnlineAccuracyCalculator.Calculate(estimatedTestClassValues, originalTestClassValues, out errorState); 88 if (errorState != OnlineCalculatorError.None) testAccuracy = double.NaN; 89 90 TrainingAccuracy = trainingAccuracy; 91 TestAccuracy = testAccuracy; 92 } 93 94 public virtual IEnumerable<double> EstimatedClassValues { 95 get { 96 return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); 97 } 98 } 99 100 public virtual IEnumerable<double> EstimatedTrainingClassValues { 101 get { 102 return GetEstimatedClassValues(ProblemData.TrainingIndizes); 103 } 104 } 105 106 public virtual IEnumerable<double> EstimatedTestClassValues { 107 get { 108 return GetEstimatedClassValues(ProblemData.TestIndizes); 109 } 110 } 111 112 public virtual IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 52 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 113 53 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 114 54 } -
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r6411 r6589 20 20 #endregion 21 21 22 using System;23 22 using System.Collections.Generic; 24 23 using System.Linq; 25 24 using HeuristicLab.Common; 26 25 using HeuristicLab.Core; 27 using HeuristicLab.Data;28 using HeuristicLab.Optimization;29 26 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 30 27 … … 35 32 [StorableClass] 36 33 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 42 43 public new IDiscriminantFunctionClassificationModel Model { 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } 45 protected set { 46 if (value != null && value != Model) { 47 if (Model != null) { 48 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 49 } 50 value.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 51 base.Model = value; 52 } 53 } 54 } 55 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 59 } 60 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 64 } 65 66 public double TrainingRSquared { 67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; } 68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; } 69 } 70 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 34 public abstract class DiscriminantFunctionClassificationSolution : DiscriminantFunctionClassificationSolutionBase { 75 35 76 36 [StorableConstructor] … … 78 38 protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner) 79 39 : base(original, cloner) { 80 RegisterEventHandler();81 40 } 82 p ublicDiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)41 protected DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData) 83 42 : this(new DiscriminantFunctionClassificationModel(model), problemData) { 84 43 } 85 p ublicDiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)44 protected DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 86 45 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));91 SetAccuracyMaximizingThresholds();92 93 //mkommend: important to recalculate accuracy because during the calculation before no thresholds were present94 base.RecalculateResults();95 CalculateResults();96 RegisterEventHandler();97 46 } 98 47 99 [StorableHook(HookType.AfterDeserialization)] 100 private void AfterDeserialization() { 101 RegisterEventHandler(); 48 public override IEnumerable<double> EstimatedClassValues { 49 get { return GetEstimatedClassValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 50 } 51 public override IEnumerable<double> EstimatedTrainingClassValues { 52 get { return GetEstimatedClassValues(ProblemData.TrainingIndizes); } 53 } 54 public override IEnumerable<double> EstimatedTestClassValues { 55 get { return GetEstimatedClassValues(ProblemData.TestIndizes); } 102 56 } 103 57 104 protected override void OnModelChanged(EventArgs e) { 105 DeregisterEventHandler(); 106 SetAccuracyMaximizingThresholds(); 107 RegisterEventHandler(); 108 base.OnModelChanged(e); 58 public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) { 59 return Model.GetEstimatedClassValues(ProblemData.Dataset, rows); 109 60 } 110 61 111 protected override void RecalculateResults() {112 base.RecalculateResults();113 CalculateResults();114 }115 62 116 private void CalculateResults() { 117 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 118 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 119 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 120 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 121 122 OnlineCalculatorError errorState; 123 double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 124 TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN; 125 double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 126 TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN; 127 128 double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState); 129 TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR2 : double.NaN; 130 double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(estimatedTestValues, originalTestValues, out errorState); 131 TestRSquared = errorState == OnlineCalculatorError.None ? testR2 : double.NaN; 132 } 133 134 private void RegisterEventHandler() { 135 Model.ThresholdsChanged += new EventHandler(Model_ThresholdsChanged); 136 } 137 private void DeregisterEventHandler() { 138 Model.ThresholdsChanged -= new EventHandler(Model_ThresholdsChanged); 139 } 140 private void Model_ThresholdsChanged(object sender, EventArgs e) { 141 OnModelThresholdsChanged(e); 142 } 143 144 public void SetAccuracyMaximizingThresholds() { 145 double[] classValues; 146 double[] thresholds; 147 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 148 AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 149 150 Model.SetThresholdsAndClassValues(thresholds, classValues); 151 } 152 153 public void SetClassDistibutionCutPointThresholds() { 154 double[] classValues; 155 double[] thresholds; 156 var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 157 NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds); 158 159 Model.SetThresholdsAndClassValues(thresholds, classValues); 160 } 161 162 protected virtual void OnModelThresholdsChanged(EventArgs e) { 163 RecalculateResults(); 164 } 165 166 public IEnumerable<double> EstimatedValues { 63 public override IEnumerable<double> EstimatedValues { 167 64 get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); } 168 65 } 169 170 public IEnumerable<double> EstimatedTrainingValues { 66 public override IEnumerable<double> EstimatedTrainingValues { 171 67 get { return GetEstimatedValues(ProblemData.TrainingIndizes); } 172 68 } 173 174 public IEnumerable<double> EstimatedTestValues { 69 public override IEnumerable<double> EstimatedTestValues { 175 70 get { return GetEstimatedValues(ProblemData.TestIndizes); } 176 71 } 177 72 178 public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {73 public override IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) { 179 74 return Model.GetEstimatedValues(ProblemData.Dataset, rows); 180 75 }
Note: See TracChangeset
for help on using the changeset viewer.