Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3026_IntegrationIntoSymSpace/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/DiscriminantFunctionClassificationSolutionBase.cs @ 18027

Last change on this file since 18027 was 18027, checked in by dpiringe, 3 years ago

#3026

  • merged trunk into branch
File size: 10.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Optimization;
29using HEAL.Attic;
30
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a classification solution that uses a discriminant function and classification thresholds.
  /// </summary>
  /// <remarks>
  /// Besides the results maintained by <see cref="ClassificationSolutionBase"/>, this class maintains
  /// regression-style quality results (mean squared error, squared Pearson's R and normalized Gini coefficient)
  /// computed between the target values and the estimated values on the training and test partitions.
  /// </remarks>
  [StorableType("3668EBE0-128C-4BC4-902C-161670F98FAD")]
  [Item("DiscriminantFunctionClassificationSolution", "Represents a classification solution that uses a discriminant function and classification thresholds.")]
  public abstract class DiscriminantFunctionClassificationSolutionBase : ClassificationSolutionBase, IDiscriminantFunctionClassificationSolution {
    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    private const string TrainingRSquaredResultName = "Pearson's R² (training)";
    private const string TestRSquaredResultName = "Pearson's R² (test)";
    private const string TrainingNormalizedGiniCoefficientResultName = "Norm. Gini coeff. (training, discriminant values)";
    private const string TestNormalizedGiniCoefficientResultName = "Norm. Gini coeff. (test, discriminant values)";

    // The Gini result descriptions are needed both when a solution is created and when an older,
    // deserialized solution is upgraded, so they are kept in one place.
    private const string TrainingNormalizedGiniCoefficientResultDescription = "Normalized Gini coefficient of the discriminant values produced by the model on the training partition.";
    private const string TestNormalizedGiniCoefficientResultDescription = "Normalized Gini coefficient of the discriminant values produced by the model on the test partition.";

    /// <summary>
    /// Gets the discriminant function model of this solution.
    /// </summary>
    /// <remarks>
    /// Setting a new, non-null model that differs from the current one moves the
    /// <c>ThresholdsChanged</c> subscription from the old model to the new one before the model is stored.
    /// Setting <c>null</c> or the current model is ignored.
    /// </remarks>
    public new IDiscriminantFunctionClassificationModel Model {
      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
      protected set {
        if (value == null || value == Model) return;

        // Stop listening to the model that is being replaced so its threshold changes no longer reach this solution.
        if (Model != null) DeregisterEventHandler();
        value.ThresholdsChanged += Model_ThresholdsChanged;
        base.Model = value;
      }
    }

    #region Results
    /// <summary>Mean squared error between the target values and the estimated values on the training partition (NaN if it could not be computed).</summary>
    public double TrainingMeanSquaredError {
      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
      private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
    }
    /// <summary>Mean squared error between the target values and the estimated values on the test partition (NaN if it could not be computed).</summary>
    public double TestMeanSquaredError {
      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
      private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
    }
    /// <summary>Squared Pearson's correlation between the target values and the estimated values on the training partition (NaN if it could not be computed).</summary>
    public double TrainingRSquared {
      get { return ((DoubleValue)this[TrainingRSquaredResultName].Value).Value; }
      private set { ((DoubleValue)this[TrainingRSquaredResultName].Value).Value = value; }
    }
    /// <summary>Squared Pearson's correlation between the target values and the estimated values on the test partition (NaN if it could not be computed).</summary>
    public double TestRSquared {
      get { return ((DoubleValue)this[TestRSquaredResultName].Value).Value; }
      private set { ((DoubleValue)this[TestRSquaredResultName].Value).Value = value; }
    }
    /// <summary>Normalized Gini coefficient of the estimated values with respect to the target values on the training partition (NaN if it could not be computed).</summary>
    public double TrainingNormalizedGiniCoefficientForDiscriminantValues {
      get { return ((DoubleValue)this[TrainingNormalizedGiniCoefficientResultName].Value).Value; }
      protected set { ((DoubleValue)this[TrainingNormalizedGiniCoefficientResultName].Value).Value = value; }
    }
    /// <summary>Normalized Gini coefficient of the estimated values with respect to the target values on the test partition (NaN if it could not be computed).</summary>
    public double TestNormalizedGiniCoefficientForDiscriminantValues {
      get { return ((DoubleValue)this[TestNormalizedGiniCoefficientResultName].Value).Value; }
      protected set { ((DoubleValue)this[TestNormalizedGiniCoefficientResultName].Value).Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected DiscriminantFunctionClassificationSolutionBase(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; subscribes to threshold changes of the (cloned) model.</summary>
    protected DiscriminantFunctionClassificationSolutionBase(DiscriminantFunctionClassificationSolutionBase original, Cloner cloner)
      : base(original, cloner) {
      RegisterEventHandler();
    }

    /// <summary>
    /// Creates a solution for <paramref name="model"/> on <paramref name="problemData"/>, adds the
    /// regression-style result entries (initialized to default values) and subscribes to threshold changes of the model.
    /// </summary>
    protected DiscriminantFunctionClassificationSolutionBase(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
      : base(model, problemData) {
      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
      Add(new Result(TrainingRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
      Add(new Result(TestRSquaredResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
      Add(new Result(TrainingNormalizedGiniCoefficientResultName, TrainingNormalizedGiniCoefficientResultDescription, new DoubleValue()));
      Add(new Result(TestNormalizedGiniCoefficientResultName, TestNormalizedGiniCoefficientResultDescription, new DoubleValue()));
      RegisterEventHandler();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region backwards compatibility
      // Solutions persisted before the normalized Gini results existed lack these entries; add and fill them.
      // Each entry is checked on its own so a partially populated result collection is repaired
      // instead of re-adding a result that is already present.
      if (!ContainsKey(TrainingNormalizedGiniCoefficientResultName)) {
        Add(new Result(TrainingNormalizedGiniCoefficientResultName, TrainingNormalizedGiniCoefficientResultDescription, new DoubleValue()));
        TrainingNormalizedGiniCoefficientForDiscriminantValues =
          CalculateNormalizedGini(GetTargetValues(ProblemData.TrainingIndices), EstimatedTrainingValues.ToArray());
      }
      if (!ContainsKey(TestNormalizedGiniCoefficientResultName)) {
        Add(new Result(TestNormalizedGiniCoefficientResultName, TestNormalizedGiniCoefficientResultDescription, new DoubleValue()));
        TestNormalizedGiniCoefficientForDiscriminantValues =
          CalculateNormalizedGini(GetTargetValues(ProblemData.TestIndices), EstimatedTestValues.ToArray());
      }
      #endregion
      RegisterEventHandler();
    }

    /// <summary>
    /// Recomputes mean squared error, squared Pearson's R and the normalized Gini coefficient between the
    /// target values and the estimated values on the training and test partitions.
    /// Any metric whose calculator reports an error is stored as <see cref="double.NaN"/>.
    /// </summary>
    protected void CalculateRegressionResults() {
      // Materialize once: each array is consumed by several calculators below.
      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray();
      double[] originalTrainingValues = GetTargetValues(ProblemData.TrainingIndices);
      double[] estimatedTestValues = EstimatedTestValues.ToArray();
      double[] originalTestValues = GetTargetValues(ProblemData.TestIndices);

      OnlineCalculatorError errorState;
      double trainingMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalTrainingValues, estimatedTrainingValues, out errorState);
      TrainingMeanSquaredError = errorState == OnlineCalculatorError.None ? trainingMSE : double.NaN;
      double testMSE = OnlineMeanSquaredErrorCalculator.Calculate(originalTestValues, estimatedTestValues, out errorState);
      TestMeanSquaredError = errorState == OnlineCalculatorError.None ? testMSE : double.NaN;

      double trainingR = OnlinePearsonsRCalculator.Calculate(originalTrainingValues, estimatedTrainingValues, out errorState);
      TrainingRSquared = errorState == OnlineCalculatorError.None ? trainingR * trainingR : double.NaN;
      double testR = OnlinePearsonsRCalculator.Calculate(originalTestValues, estimatedTestValues, out errorState);
      TestRSquared = errorState == OnlineCalculatorError.None ? testR * testR : double.NaN;

      TrainingNormalizedGiniCoefficientForDiscriminantValues = CalculateNormalizedGini(originalTrainingValues, estimatedTrainingValues);
      TestNormalizedGiniCoefficientForDiscriminantValues = CalculateNormalizedGini(originalTestValues, estimatedTestValues);
    }

    /// <summary>Reads the target variable values of the given rows from the problem's dataset into an array.</summary>
    private double[] GetTargetValues(IEnumerable<int> rows) {
      return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, rows).ToArray();
    }

    /// <summary>
    /// Computes the normalized Gini coefficient of <paramref name="estimatedValues"/> with respect to
    /// <paramref name="originalValues"/>, or <see cref="double.NaN"/> if the calculator reports an error.
    /// </summary>
    private static double CalculateNormalizedGini(double[] originalValues, double[] estimatedValues) {
      double normalizedGini = NormalizedGiniCalculator.Calculate(originalValues, estimatedValues, out var errorState);
      return errorState == OnlineCalculatorError.None ? normalizedGini : double.NaN;
    }

    /// <summary>Subscribes to <c>ThresholdsChanged</c> of the current model.</summary>
    private void RegisterEventHandler() {
      Model.ThresholdsChanged += Model_ThresholdsChanged;
    }
    /// <summary>Unsubscribes from <c>ThresholdsChanged</c> of the current model.</summary>
    private void DeregisterEventHandler() {
      Model.ThresholdsChanged -= Model_ThresholdsChanged;
    }
    private void Model_ThresholdsChanged(object sender, EventArgs e) {
      OnModelThresholdsChanged(e);
    }

    /// <summary>
    /// Called when the thresholds of the model change; by default forwards to <c>OnModelChanged</c>
    /// (defined in the base class — presumably triggers result recalculation; verify there).
    /// </summary>
    protected virtual void OnModelThresholdsChanged(EventArgs e) {
      OnModelChanged();
    }

    /// <summary>Estimated (discriminant) values for all rows of the dataset.</summary>
    public abstract IEnumerable<double> EstimatedValues { get; }
    /// <summary>Estimated (discriminant) values for the training rows.</summary>
    public abstract IEnumerable<double> EstimatedTrainingValues { get; }
    /// <summary>Estimated (discriminant) values for the test rows.</summary>
    public abstract IEnumerable<double> EstimatedTestValues { get; }

    /// <summary>Estimated (discriminant) values for the given <paramref name="rows"/>.</summary>
    public abstract IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows);

    /// <summary>Recomputes the base classification results followed by the regression-style results of this class.</summary>
    protected override void RecalculateResults() {
      base.RecalculateResults();
      CalculateRegressionResults();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.