source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolutionBase.cs @ 7234

Last change on this file since 7234 was 7234, checked in by gkronber, 11 years ago

#1685: changed the simplification view for symbolic classification solutions to use the Gini index to determine the impact of a node, because it describes the degree of separation of the classes and we do not have to search for the optimal threshold value each time we calculate the impact of one node. Also fixed a problem with the Gini index result of classification solutions using a discriminating function: for these solutions the Gini index was calculated twice (once for the class values and once for the output values of the discriminating function).

File size: 8.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a classification solution that uses a discriminant function and classification thresholds.
  /// </summary>
  /// <remarks>
  /// Besides the results maintained by the base class, this solution reports regression-style measures
  /// (mean squared error, squared Pearson correlation and normalized Gini coefficient) that compare the
  /// estimated values with the target variable. By default all results are recalculated whenever the
  /// model raises its ThresholdsChanged event.
  /// </remarks>
  [StorableClass]
  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
  public abstract class DiscriminantFunctionClassificationSolutionBase : ClassificationSolutionBase, IDiscriminantFunctionClassificationSolution {
    // Names under which the regression-style results are added to (and later looked up in) this solution.
    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    private const string TestRSquaredResultName = "Pearson's R² (test)";

    /// <summary>
    /// Gets the discriminant function model of this solution.
    /// </summary>
    /// <remarks>
    /// Assigning <c>null</c> or the model that is already set is ignored. Otherwise the ThresholdsChanged
    /// subscription is moved from the previous model to the new one before the base model is replaced.
    /// </remarks>
    public new IDiscriminantFunctionClassificationModel Model {
      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
      protected set {
        if (value == null || value == Model) return;

        IDiscriminantFunctionClassificationModel previous = Model;
        if (previous != null) previous.ThresholdsChanged -= Model_ThresholdsChanged;
        value.ThresholdsChanged += Model_ThresholdsChanged;
        base.Model = value;
      }
    }

    #region Results
    /// <summary>Mean squared error between the training targets and <see cref="EstimatedTrainingValues"/>.</summary>
    public double TrainingMeanSquaredError {
      get { return DoubleResult(TrainingMeanSquaredErrorResultName).Value; }
      private set { DoubleResult(TrainingMeanSquaredErrorResultName).Value = value; }
    }
    /// <summary>Mean squared error between the test targets and <see cref="EstimatedTestValues"/>.</summary>
    public double TestMeanSquaredError {
      get { return DoubleResult(TestMeanSquaredErrorResultName).Value; }
      private set { DoubleResult(TestMeanSquaredErrorResultName).Value = value; }
    }
    /// <summary>Squared Pearson correlation between the training targets and <see cref="EstimatedTrainingValues"/>.</summary>
    public double TrainingRSquared {
      get { return DoubleResult(TrainingRSquaredResultName).Value; }
      private set { DoubleResult(TrainingRSquaredResultName).Value = value; }
    }
    /// <summary>Squared Pearson correlation between the test targets and <see cref="EstimatedTestValues"/>.</summary>
    public double TestRSquared {
      get { return DoubleResult(TestRSquaredResultName).Value; }
      private set { DoubleResult(TestRSquaredResultName).Value = value; }
    }

    /// <summary>Returns the <see cref="DoubleValue"/> stored in the result named <paramref name="resultName"/>.</summary>
    private DoubleValue DoubleResult(string resultName) {
      return (DoubleValue)this[resultName].Value;
    }
    #endregion

    /// <summary>Persistence constructor; the event subscription is restored in <see cref="AfterDeserialization"/>.</summary>
    [StorableConstructor]
    protected DiscriminantFunctionClassificationSolutionBase(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; subscribes to the ThresholdsChanged event of the (cloned) model.</summary>
    protected DiscriminantFunctionClassificationSolutionBase(DiscriminantFunctionClassificationSolutionBase original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandler();
    }

    /// <summary>
    /// Creates a solution for <paramref name="model"/> and <paramref name="problemData"/>, adds the four
    /// regression-style results (initialised with empty <see cref="DoubleValue"/>s) and subscribes to the
    /// model's ThresholdsChanged event.
    /// </summary>
    protected DiscriminantFunctionClassificationSolutionBase(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
      : base(model, problemData) {
      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
      Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
      Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));

      RegisterEventHandler();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      RegisterEventHandler();
    }

    /// <summary>
    /// Resets the thresholds of the new model to accuracy-maximizing values, then defers to the base class.
    /// </summary>
    protected override void OnModelChanged() {
      // The handler is detached while the thresholds are reset — presumably so that a ThresholdsChanged
      // notification raised by SetThresholdsAndClassValues does not recompute all results before
      // base.OnModelChanged() runs.
      DeregisterEventHandler();
      SetAccuracyMaximizingThresholds();
      RegisterEventHandler();
      base.OnModelChanged();
    }

    /// <summary>
    /// Recomputes mean squared error, squared Pearson correlation and normalized Gini coefficient between the
    /// target variable and the estimated values, separately for the training and the test partition.
    /// A measure whose calculator reports an error is stored as <see cref="double.NaN"/>.
    /// </summary>
    protected void CalculateRegressionResults() {
      // Materialize each sequence once; every array is read by several calculators below.
      double[] trainingEstimates = EstimatedTrainingValues.ToArray();
      double[] trainingTargets = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).ToArray();
      double[] testEstimates = EstimatedTestValues.ToArray();
      double[] testTargets = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndizes).ToArray();

      OnlineCalculatorError error;

      double trainingMse = OnlineMeanSquaredErrorCalculator.Calculate(trainingTargets, trainingEstimates, out error);
      TrainingMeanSquaredError = ValueOrNaN(trainingMse, error);
      double testMse = OnlineMeanSquaredErrorCalculator.Calculate(testTargets, testEstimates, out error);
      TestMeanSquaredError = ValueOrNaN(testMse, error);

      double trainingR2 = OnlinePearsonsRSquaredCalculator.Calculate(trainingTargets, trainingEstimates, out error);
      TrainingRSquared = ValueOrNaN(trainingR2, error);
      double testR2 = OnlinePearsonsRSquaredCalculator.Calculate(testTargets, testEstimates, out error);
      TestRSquared = ValueOrNaN(testR2, error);

      // Both Gini coefficients are computed before either result property is written.
      double trainingGini = NormalizedGiniCalculator.Calculate(trainingTargets, trainingEstimates, out error);
      trainingGini = ValueOrNaN(trainingGini, error);
      double testGini = NormalizedGiniCalculator.Calculate(testTargets, testEstimates, out error);
      testGini = ValueOrNaN(testGini, error);

      TrainingNormalizedGiniCoefficient = trainingGini;
      TestNormalizedGiniCoefficient = testGini;
    }

    /// <summary>
    /// Returns <paramref name="value"/> when <paramref name="error"/> is <see cref="OnlineCalculatorError.None"/>,
    /// otherwise <see cref="double.NaN"/>.
    /// </summary>
    private static double ValueOrNaN(double value, OnlineCalculatorError error) {
      return error == OnlineCalculatorError.None ? value : double.NaN;
    }

    private void RegisterEventHandler() {
      Model.ThresholdsChanged += Model_ThresholdsChanged;
    }
    private void DeregisterEventHandler() {
      Model.ThresholdsChanged -= Model_ThresholdsChanged;
    }
    private void Model_ThresholdsChanged(object sender, EventArgs e) {
      OnModelThresholdsChanged(e);
    }

    /// <summary>
    /// Computes thresholds and class values with <c>AccuracyMaximizationThresholdCalculator</c> from the estimated
    /// training values and the training target values, and assigns them to the model.
    /// </summary>
    public void SetAccuracyMaximizingThresholds() {
      var trainingTargets = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
      double[] classValues, thresholds;
      AccuracyMaximizationThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, trainingTargets, out classValues, out thresholds);
      Model.SetThresholdsAndClassValues(thresholds, classValues);
    }

    /// <summary>
    /// Computes thresholds and class values with <c>NormalDistributionCutPointsThresholdCalculator</c> from the
    /// estimated training values and the training target values, and assigns them to the model.
    /// </summary>
    /// <remarks>NOTE: "Distibution" is misspelled in the method name; renaming it would break existing callers.</remarks>
    public void SetClassDistibutionCutPointThresholds() {
      var trainingTargets = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
      double[] classValues, thresholds;
      NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(ProblemData, EstimatedTrainingValues, trainingTargets, out classValues, out thresholds);
      Model.SetThresholdsAndClassValues(thresholds, classValues);
    }

    /// <summary>
    /// Invoked when the model raises ThresholdsChanged; recalculates the base-class results
    /// (<c>CalculateResults</c>) and then the regression-style results.
    /// </summary>
    /// <param name="e">Event arguments forwarded from the model's ThresholdsChanged event.</param>
    protected virtual void OnModelThresholdsChanged(EventArgs e) {
      CalculateResults();
      CalculateRegressionResults();
    }

    /// <summary>Estimated values of the model (presumably for all dataset rows — TODO confirm with implementations).</summary>
    public abstract IEnumerable<double> EstimatedValues { get; }
    /// <summary>
    /// Estimated values for the training rows. Must be aligned with ProblemData.TrainingIndizes, because
    /// <see cref="CalculateRegressionResults"/> pairs them positionally with the training target values.
    /// </summary>
    public abstract IEnumerable<double> EstimatedTrainingValues { get; }
    /// <summary>
    /// Estimated values for the test rows. Must be aligned with ProblemData.TestIndizes, because
    /// <see cref="CalculateRegressionResults"/> pairs them positionally with the test target values.
    /// </summary>
    public abstract IEnumerable<double> EstimatedTestValues { get; }

    /// <summary>Returns the estimated values for the given <paramref name="rows"/>.</summary>
    public abstract IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows);
  }
}
Note: See TracBrowser for help on using the repository browser.