Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/RegressionSolution.cs @ 5681

Last change on this file since 5681 was 5649, checked in by gkronber, 14 years ago

#1418 Implemented classes for classification based on a discriminant function and thresholds and implemented interfaces and base classes for clustering.

File size: 5.8 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Operators;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Optimization;
31using System;
32
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Abstract base class for regression data analysis solutions.
  /// </summary>
  /// <remarks>
  /// The public constructor evaluates the model on the training and test partitions of the
  /// problem data and adds six results: mean squared error, squared Pearson correlation and
  /// average relative error, each for the training and the test partition.
  /// </remarks>
  [StorableClass]
  public abstract class RegressionSolution : DataAnalysisSolution, IRegressionSolution {
    // Names under which the quality results are added to this solution.
    private const string TrainingMeanSquaredErrorResultName = "Mean squared error (training)";
    private const string TestMeanSquaredErrorResultName = "Mean squared error (test)";
    private const string TrainingSquaredCorrelationResultName = "Pearson's R² (training)";
    private const string TestSquaredCorrelationResultName = "Pearson's R² (test)";
    private const string TrainingRelativeErrorResultName = "Average relative error (training)";
    private const string TestRelativeErrorResultName = "Average relative error (test)";

    /// <summary>
    /// Constructor used by the persistence framework; forwards to the base class only.
    /// </summary>
    /// <param name="deserializing">Passed through unchanged to the base constructor.</param>
    [StorableConstructor]
    protected RegressionSolution(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor; forwards to the base class only.
    /// </summary>
    /// <param name="original">The solution to copy.</param>
    /// <param name="cloner">The cloner passed through to the base constructor.</param>
    protected RegressionSolution(RegressionSolution original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates a regression solution for the given model and problem data and adds the
    /// training/test quality results (MSE, squared Pearson correlation, average relative error).
    /// </summary>
    /// <param name="model">The regression model; passed to the base constructor.</param>
    /// <param name="problemData">The regression problem data; passed to the base constructor.</param>
    public RegressionSolution(IRegressionModel model, IRegressionProblemData problemData)
      : base(model, problemData) {
      // NOTE: EstimatedTrainingValues and EstimatedTestValues are virtual, so an override in a
      // derived class is invoked here before that class's own constructor body has run.

      // Every sequence below is consumed by three evaluators. Materialize each one exactly once
      // so that neither the model estimation nor the dataset access is repeated per evaluator.
      double[] estimatedTrainingValues = EstimatedTrainingValues.ToArray();
      double[] originalTrainingValues = ProblemData.Dataset
        .GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes)
        .ToArray();
      double[] estimatedTestValues = EstimatedTestValues.ToArray();
      double[] originalTestValues = ProblemData.Dataset
        .GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TestIndizes)
        .ToArray();

      double trainingMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
      double testMSE = OnlineMeanSquaredErrorEvaluator.Calculate(estimatedTestValues, originalTestValues);
      double trainingR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
      double testR2 = OnlinePearsonsRSquaredEvaluator.Calculate(estimatedTestValues, originalTestValues);
      double trainingRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTrainingValues, originalTrainingValues);
      double testRelError = OnlineMeanAbsolutePercentageErrorEvaluator.Calculate(estimatedTestValues, originalTestValues);

      Add(new Result(TrainingMeanSquaredErrorResultName, "Mean of squared errors of the model on the training partition", new DoubleValue(trainingMSE)));
      Add(new Result(TestMeanSquaredErrorResultName, "Mean of squared errors of the model on the test partition", new DoubleValue(testMSE)));
      Add(new Result(TrainingSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the training partition", new DoubleValue(trainingR2)));
      Add(new Result(TestSquaredCorrelationResultName, "Squared Pearson's correlation coefficient of the model output and the actual values on the test partition", new DoubleValue(testR2)));
      Add(new Result(TrainingRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the training partition", new PercentValue(trainingRelError)));
      Add(new Result(TestRelativeErrorResultName, "Average of the relative errors of the model output and the actual values on the test partition", new PercentValue(testRelError)));
    }

    /// <summary>
    /// Calls the base implementation and then throws, because recalculating the results
    /// after a problem data change is not implemented yet.
    /// </summary>
    /// <param name="e">Event arguments forwarded to the base implementation.</param>
    /// <exception cref="NotImplementedException">Always thrown after the base call.</exception>
    protected override void OnProblemDataChanged(EventArgs e) {
      base.OnProblemDataChanged(e);
      throw new NotImplementedException(); // need to recalculate results
    }

    /// <summary>
    /// Calls the base implementation and then throws, because recalculating the results
    /// after a model change is not implemented yet.
    /// </summary>
    /// <param name="e">Event arguments forwarded to the base implementation.</param>
    /// <exception cref="NotImplementedException">Always thrown after the base call.</exception>
    protected override void OnModelChanged(EventArgs e) {
      base.OnModelChanged(e);
      throw new NotImplementedException(); // need to recalculate results
    }

    #region IRegressionSolution Members

    /// <summary>
    /// The model of the base class, cast to <see cref="IRegressionModel"/>.
    /// </summary>
    public new IRegressionModel Model {
      get { return (IRegressionModel)base.Model; }
    }

    /// <summary>
    /// The problem data of the base class, cast to <see cref="IRegressionProblemData"/>.
    /// </summary>
    public new IRegressionProblemData ProblemData {
      get { return (IRegressionProblemData)base.ProblemData; }
    }

    /// <summary>
    /// Estimated values for the row indices 0 to <c>ProblemData.Dataset.Rows - 1</c>.
    /// </summary>
    public virtual IEnumerable<double> EstimatedValues {
      get {
        return GetEstimatedValues(Enumerable.Range(0, ProblemData.Dataset.Rows));
      }
    }

    /// <summary>
    /// Estimated values for the rows listed in <c>ProblemData.TrainingIndizes</c>.
    /// </summary>
    public virtual IEnumerable<double> EstimatedTrainingValues {
      get {
        return GetEstimatedValues(ProblemData.TrainingIndizes);
      }
    }

    /// <summary>
    /// Estimated values for the rows listed in <c>ProblemData.TestIndizes</c>.
    /// </summary>
    public virtual IEnumerable<double> EstimatedTestValues {
      get {
        return GetEstimatedValues(ProblemData.TestIndizes);
      }
    }

    /// <summary>
    /// Asks the model for its estimated values on the given rows of the problem data's dataset.
    /// </summary>
    /// <param name="rows">Row indices into <c>ProblemData.Dataset</c>.</param>
    /// <returns>Whatever <c>Model.GetEstimatedValues</c> returns for the dataset and rows.</returns>
    public virtual IEnumerable<double> GetEstimatedValues(IEnumerable<int> rows) {
      return Model.GetEstimatedValues(ProblemData.Dataset, rows);
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.