Changeset 5885
- Timestamp:
- 03/30/11 13:57:33 (14 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs
r5849 r5885 26 26 using HeuristicLab.Core; 27 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 28 using HeuristicLab.Data; 29 using HeuristicLab.Optimization; 28 30 29 31 namespace HeuristicLab.Problems.DataAnalysis { … … 34 36 [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")] 35 37 public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution { 38 private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)"; 39 private const string TestMeanSquaredErrorResultName = "Mean squared error (test)"; 40 private const string TrainingRSquaredResultName = "Pearson's R² (training)"; 41 private const string TestRSquaredResultName = "Pearson's R² (test)"; 42 36 43 public new IDiscriminantFunctionClassificationModel Model { 37 44 get { return (IDiscriminantFunctionClassificationModel)base.Model; } … … 47 54 } 48 55 56 public double TrainingMeanSquaredError { 57 get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; } 58 private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; } 59 } 60 61 public double TestMeanSquaredError { 62 get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; } 63 private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; } 64 } 65 66 public double TrainingRSquared { 67 get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; } 68 private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; } 69 } 70 71 public double TestRSquared { 72 get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; } 73 private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; } 74 } 75 49 76 [StorableConstructor] 50 77 protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { } … … 58 85 public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData) 59 86 : base(model, problemData) { 87 Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue())); 88 Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue())); 89 Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue())); 90 Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue())); 60 91 RegisterEventHandler(); 61 92 SetAccuracyMaximizingThresholds(); 93 RecalculateResults(); 62 94 } 63 95 … … 65 97 private void AfterDeserialization() { 66 98 RegisterEventHandler(); 99 } 100 101 protected void RecalculateResults() { 102 double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values 103 IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes); 104 double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values 105 IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes); 106 107 double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues); 108 double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues); 109 double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues); 110 double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues); 111 112 TrainingMeanSquaredError = trainingMSE; 113 TestMeanSquaredError = testMSE; 114 TrainingRSquared = trainingR2; 115 TestRSquared = testR2; 67 116 } 68 117 … … 95 144 base.OnModelChanged(e); 96 145 SetAccuracyMaximizingThresholds(); 146 RecalculateResults(); 97 147 } 98 148 … … 100 150 base.OnProblemDataChanged(e); 101 151 SetAccuracyMaximizingThresholds(); 152 RecalculateResults(); 102 153 } 103 154 protected virtual void OnModelThresholdsChanged(EventArgs e) { 104 155 base.OnModelChanged(e); 156 RecalculateResults(); 105 157 } 106 158
Note: See TracChangeset
for help on using the changeset viewer.