Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Regression/RegressionSolution.cs @ 5894

Last change on this file since 5894 was 5894, checked in by gkronber, 13 years ago

#1453: Added an ErrorState property to online evaluators to indicate if the result value is valid or if there has been an error in the calculation. Adapted all classes that use one of the online evaluators to check this property.

File size: 7.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Data;
27using HeuristicLab.Optimization;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29
30namespace HeuristicLab.Problems.DataAnalysis {
31  /// <summary>
32  /// Abstract base class for regression data analysis solutions
33  /// </summary>
34  [StorableClass]
35  public abstract class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
36    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
37    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
38    private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
39    private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
40    private const string TrainingRelativeErrorResultName = "Average relative error (training)";
41    private const string TestRelativeErrorResultName = "Average relative error (test)";
42
43    public new IRegressionModel Model {
44      get { return (IRegressionModel)base.Model; }
45      protected set { base.Model = value; }
46    }
47
48    public new IRegressionProblemData ProblemData {
49      get { return (IRegressionProblemData)base.ProblemData; }
50      protected set { base.ProblemData = value; }
51    }
52
53    public double TrainingMeanSquaredError {
54      get { return ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value; }
55      private set { ((DoubleValue)this[TrainingMeanSquaredErrorResultName].Value).Value = value; }
56    }
57
58    public double TestMeanSquaredError {
59      get { return ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value; }
60      private set { ((DoubleValue)this[TestMeanSquaredErrorResultName].Value).Value = value; }
61    }
62
63    public double TrainingRSquared {
64      get { return ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value; }
65      private set { ((DoubleValue)this[TrainingSquaredCorrelationResultName].Value).Value = value; }
66    }
67
68    public double TestRSquared {
69      get { return ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value; }
70      private set { ((DoubleValue)this[TestSquaredCorrelationResultName].Value).Value = value; }
71    }
72
73    public double TrainingRelativeError {
74      get { return ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value; }
75      private set { ((DoubleValue)this[TrainingRelativeErrorResultName].Value).Value = value; }
76    }
77
78    public double TestRelativeError {
79      get { return ((DoubleValue)this[TestRelativeErrorResultName].Value).Value; }
80      private set { ((DoubleValue)this[TestRelativeErrorResultName].Value).Value = value; }
81    }
82
83
84    [StorableConstructor]
85    protected RegressionSolution(bool deserializing) : base(deserializing) { }
86    protected RegressionSolution(RegressionSolution original, Cloner cloner)
87      : base(original, cloner) {
88    }
89    public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
90      : base(model, problemData) {
91      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue()));
92      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue()));
93      Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue()));
94      Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue()));
95      Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue()));
96      Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue()));
97
98      RecalculateResults();
99    }
100
101    protected override void OnProblemDataChanged(EventArgs e) {
102      base.OnProblemDataChanged(e);
103      RecalculateResults();
104    }
105    protected override void OnModelChanged(EventArgs e) {
106      base.OnModelChanged(e);
107      RecalculateResults();
108    }
109
110    protected void RecalculateResults() {
111      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray(); // cache values
112      IEnumerable<double> originalTrainingValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
113      double[] estimatedTestValues = EstimatedTestValues.ToArray(); // cache values
114      IEnumerable<double> originalTestValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
115
116      OnlineEvaluatorError errorState;
117      double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
118      TrainingMeanSquaredError = errorState == OnlineEvaluatorError.None ? trainingMSE : double.NaN;
119      double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues, out errorState);
120      TestMeanSquaredError = errorState == OnlineEvaluatorError.None ? testMSE : double.NaN;
121
122      double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
123      TrainingRSquared = errorState == OnlineEvaluatorError.None ? trainingR2 : double.NaN;
124      double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues, out errorState);
125      TestRSquared = errorState == OnlineEvaluatorError.None ? testR2 : double.NaN;
126
127      double trainingRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues, out errorState);
128      TrainingRelativeError = errorState == OnlineEvaluatorError.None ? trainingRelError : double.NaN;
129      double testRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTestValues, originalTestValues, out errorState);
130      TestRelativeError = errorState == OnlineEvaluatorError.None ? testRelError : double.NaN;
131    }
132
133    public virtual IEnumerable<double> EstimatedValues {
134      get {
135        return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
136      }
137    }
138
139    public virtual IEnumerable<double> EstimatedTrainingValues {
140      get {
141        return GetEstimatedValues(ProblemData.TrainingIndizes);
142      }
143    }
144
145    public virtual IEnumerable<double> EstimatedTestValues {
146      get {
147        return GetEstimatedValues(ProblemData.TestIndizes);
148      }
149    }
150
151    public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
152      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
153    }
154  }
155}
Note: See TracBrowser for help on using the repository browser.