Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolution.cs @ 5894

Last change on this file since 5894 was 5894, checked in by gkronber, 13 years ago

#1453: Added an ErrorState property to online evaluators to indicate if the result value is valid or if there has been an error in the calculation. Adapted all classes that use one of the online evaluators to check this property.

File size: 8.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a classification solution that uses a discriminant function and classification thresholds.
  /// </summary>
  /// <remarks>
  /// In addition to the results of <see cref="ClassificationSolution"/>, this solution reports the mean squared
  /// error and the squared Pearson correlation between the model's estimated values and the target values on the
  /// training and test partitions. These results are recalculated when the model, the problem data, or the model
  /// thresholds change.
  /// </remarks>
  [StorableClass]
  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
  public class DiscriminantFunctionClassificationSolution : ClassificationSolution, IDiscriminantFunctionClassificationSolution {
    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    private const string TestRSquaredResultName = "Pearson's R² (test)";

    /// <summary>
    /// Gets the discriminant function classification model of this solution.
    /// </summary>
    /// <remarks>
    /// Assigning a new non-null model moves the <c>ThresholdsChanged</c> subscription from the current model to the
    /// new one and then stores it in the base class. Assigning <c>null</c> or the current model is ignored.
    /// </remarks>
    public new IDiscriminantFunctionClassificationModel Model {
      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
      protected set {
        if (value == null || value == Model) return;

        if (Model != null) {
          Model.ThresholdsChanged -= Model_ThresholdsChanged;
        }
        value.ThresholdsChanged += Model_ThresholdsChanged;
        base.Model = value;
      }
    }

    /// <summary>
    /// Gets the mean squared error on the training partition, or <see cref="double.NaN"/> if the evaluator reported an error.
    /// </summary>
    public double TrainingMeanSquaredError {
      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
      private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    }

    /// <summary>
    /// Gets the mean squared error on the test partition, or <see cref="double.NaN"/> if the evaluator reported an error.
    /// </summary>
    public double TestMeanSquaredError {
      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
      private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    }

    /// <summary>
    /// Gets the squared Pearson correlation on the training partition, or <see cref="double.NaN"/> if the evaluator reported an error.
    /// </summary>
    public double TrainingRSquared {
      get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
      private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    }

    /// <summary>
    /// Gets the squared Pearson correlation on the test partition, or <see cref="double.NaN"/> if the evaluator reported an error.
    /// </summary>
    public double TestRSquared {
      get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
      private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    }

    [StorableConstructor]
    protected DiscriminantFunctionClassificationSolution(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor; re-subscribes to the threshold changes of the cloned model.
    /// </summary>
    protected DiscriminantFunctionClassificationSolution(DiscriminantFunctionClassificationSolution original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandler();
    }

    /// <summary>
    /// Creates a solution by wrapping a regression model in a <see cref="DiscriminantFunctionClassificationModel"/>.
    /// </summary>
    /// <param name="model">The regression model whose output serves as the discriminant function.</param>
    /// <param name="problemData">The classification problem data.</param>
    public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
      : this(new DiscriminantFunctionClassificationModel(model), problemData) {
    }

    /// <summary>
    /// Creates a solution, adds the MSE and R² result entries, sets accuracy-maximizing thresholds,
    /// and computes the initial results.
    /// </summary>
    /// <param name="model">The discriminant function classification model.</param>
    /// <param name="problemData">The classification problem data.</param>
    public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
      : base(model, problemData) {
      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
      Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
      Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
      RegisterEventHandler();
      SetAccuracyMaximizingThresholds();
      RecalculateResults();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not persisted, so restore them after loading.
      RegisterEventHandler();
    }

    /// <summary>
    /// Recomputes the training/test mean squared error and squared Pearson correlation between the
    /// model's estimated values and the target variable. A result is set to <see cref="double.NaN"/>
    /// when its online evaluator reports an error.
    /// </summary>
    protected new void RecalculateResults() {
      // Materialize every sequence once: each is consumed by two evaluators below, and the
      // underlying enumerations would otherwise be recomputed on every pass.
      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray();
      double[] originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).ToArray();
      double[] estimatedTestValues = EstimatedTestValues.ToArray();
      double[] originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes).ToArray();

      OnlineEvaluatorError errorState;
      double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
      TrainingMeanSquaredError = ValueOrNaN(trainingMSE, errorState);
      double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues, out errorState);
      TestMeanSquaredError = ValueOrNaN(testMSE, errorState);

      double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
      TrainingRSquared = ValueOrNaN(trainingR2, errorState);
      double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues, out errorState);
      TestRSquared = ValueOrNaN(testR2, errorState);
    }

    /// <summary>
    /// Returns <paramref name="value"/> if <paramref name="errorState"/> is <c>None</c>, otherwise <see cref="double.NaN"/>.
    /// </summary>
    private static double ValueOrNaN(double value, OnlineEvaluatorError errorState) {
      return errorState == OnlineEvaluatorError.None ? value : double.NaN;
    }

    /// <summary>
    /// Subscribes to <c>ThresholdsChanged</c> of the current model, if there is one.
    /// </summary>
    private void RegisterEventHandler() {
      if (Model != null) {
        Model.ThresholdsChanged += Model_ThresholdsChanged;
      }
    }

    private void Model_ThresholdsChanged(object sender, EventArgs e) {
      OnModelThresholdsChanged(e);
    }

    /// <summary>
    /// Computes thresholds and class values with <c>AccuracyMaximizationThresholdCalculator</c> from the estimated
    /// training values and the training targets, and stores them in the model.
    /// </summary>
    public void SetAccuracyMaximizingThresholds() {
      double[] classValues;
      double[] thresholds;
      var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
      AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);

      Model.SetThresholdsAndClassValues(thresholds, classValues);
    }

    /// <summary>
    /// Computes thresholds and class values with <c>NormalDistributionCutPointsThresholdCalculator</c> from the
    /// estimated training values and the training targets, and stores them in the model.
    /// </summary>
    /// <remarks>
    /// The method name is misspelled ("Distibution"); it is kept as-is so existing callers keep compiling.
    /// </remarks>
    public void SetClassDistibutionCutPointThresholds() {
      double[] classValues;
      double[] thresholds;
      var targetClassValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
      NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, targetClassValues, out classValues, out thresholds);

      Model.SetThresholdsAndClassValues(thresholds, classValues);
    }

    /// <summary>
    /// After the base handling, resets the thresholds to accuracy-maximizing ones and recomputes the results.
    /// </summary>
    protected override void OnModelChanged(EventArgs e) {
      base.OnModelChanged(e);
      SetAccuracyMaximizingThresholds();
      RecalculateResults();
    }

    /// <summary>
    /// After the base handling, resets the thresholds to accuracy-maximizing ones and recomputes the results.
    /// </summary>
    protected override void OnProblemDataChanged(EventArgs e) {
      base.OnProblemDataChanged(e);
      SetAccuracyMaximizingThresholds();
      RecalculateResults();
    }

    /// <summary>
    /// Called when the model's thresholds change; runs the base model-changed handling and recomputes the results.
    /// </summary>
    protected virtual void OnModelThresholdsChanged(EventArgs e) {
      // Deliberately calls base.OnModelChanged rather than this.OnModelChanged: this class's override
      // calls SetAccuracyMaximizingThresholds, which would overwrite the thresholds that just changed
      // (and presumably raise ThresholdsChanged again).
      base.OnModelChanged(e);
      RecalculateResults();
    }

    /// <summary>
    /// Gets the model's estimated values for all rows of the dataset.
    /// </summary>
    public IEnumerable<double> EstimatedValues {
      get { return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows)); }
    }

    /// <summary>
    /// Gets the model's estimated values for the training rows.
    /// </summary>
    public IEnumerable<double> EstimatedTrainingValues {
      get { return GetEstimatedValues(ProblemData.TrainingIndizes); }
    }

    /// <summary>
    /// Gets the model's estimated values for the test rows.
    /// </summary>
    public IEnumerable<double> EstimatedTestValues {
      get { return GetEstimatedValues(ProblemData.TestIndizes); }
    }

    /// <summary>
    /// Returns the model's estimated values for the given rows of the problem data's dataset.
    /// </summary>
    /// <param name="rows">Row indices into <c>ProblemData.Dataset</c>.</param>
    public IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.