Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Analyzers/RegressionSolutionAnalyzer.cs @ 4308

Last change on this file since 4308 was 4308, checked in by gkronber, 14 years ago

Fixed typos. #1142

File size: 11.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using HeuristicLab.Core;
24using HeuristicLab.Data;
25using HeuristicLab.Operators;
26using HeuristicLab.Optimization;
27using HeuristicLab.Parameters;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis.Evaluators;
30using HeuristicLab.Analysis;
31
32namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
33  [StorableClass]
34  public abstract class RegressionSolutionAnalyzer : SingleSuccessorOperator {
35    private const string ProblemDataParameterName = "ProblemData";
36    private const string QualityParameterName = "Quality";
37    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
38    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
39    private const string BestSolutionQualityParameterName = "BestSolutionQuality";
40    private const string GenerationsParameterName = "Generations";
41    private const string ResultsParameterName = "Results";
42    private const string BestSolutionResultName = "Best solution (on validation set)";
43    private const string BestSolutionTrainingRSquared = "Best solution R² (training)";
44    private const string BestSolutionTestRSquared = "Best solution R² (test)";
45    private const string BestSolutionTrainingMse = "Best solution mean squared error (training)";
46    private const string BestSolutionTestMse = "Best solution mean squared error (test)";
47    private const string BestSolutionTrainingRelativeError = "Best solution average relative error (training)";
48    private const string BestSolutionTestRelativeError = "Best solution average relative error (test)";
49    private const string BestSolutionGeneration = "Best solution generation";
50
51    #region parameter properties
52    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
53      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
54    }
55    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
56      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
57    }
58    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
59      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
60    }
61    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
62      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
63    }
64    public ILookupParameter<DoubleValue> BestSolutionQualityParameter {
65      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionQualityParameterName]; }
66    }
67    public ILookupParameter<ResultCollection> ResultsParameter {
68      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
69    }
70    public ILookupParameter<IntValue> GenerationsParameter {
71      get { return (ILookupParameter<IntValue>)Parameters[GenerationsParameterName]; }
72    }
73    #endregion
74    #region properties
75    public DoubleValue UpperEstimationLimit {
76      get { return UpperEstimationLimitParameter.ActualValue; }
77    }
78    public DoubleValue LowerEstimationLimit {
79      get { return LowerEstimationLimitParameter.ActualValue; }
80    }
81    public ItemArray<DoubleValue> Quality {
82      get { return QualityParameter.ActualValue; }
83    }
84    public ResultCollection Results {
85      get { return ResultsParameter.ActualValue; }
86    }
87    public DataAnalysisProblemData ProblemData {
88      get { return ProblemDataParameter.ActualValue; }
89    }
90    #endregion
91
92    public RegressionSolutionAnalyzer()
93      : base() {
94      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
95      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
96      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
97      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The qualities of the symbolic regression trees which should be analyzed."));
98      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionQualityParameterName, "The quality of the best regression solution."));
99      Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
100      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
101    }
102
103    [StorableHook(HookType.AfterDeserialization)]
104    private void Initialize() {
105      // backwards compatibility
106      if (!Parameters.ContainsKey(GenerationsParameterName)) {
107        Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
108      }
109    }
110
111    public override IOperation Apply() {
112      DoubleValue prevBestSolutionQuality = BestSolutionQualityParameter.ActualValue;
113      var bestSolution = UpdateBestSolution();
114      if (prevBestSolutionQuality == null || prevBestSolutionQuality.Value > BestSolutionQualityParameter.ActualValue.Value) {
115        RegressionSolutionAnalyzer.UpdateBestSolutionResults(bestSolution, ProblemData, Results, GenerationsParameter.ActualValue);
116      }
117
118      return base.Apply();
119    }
120
121    public static void UpdateBestSolutionResults(DataAnalysisSolution bestSolution, DataAnalysisProblemData problemData, ResultCollection results, IntValue CurrentGeneration) {
122      var solution = bestSolution;
123      #region update R2,MSE, Rel Error
124      IEnumerable<double> trainingValues = problemData.Dataset.GetEnumeratedVariableValues(
125        problemData.TargetVariable.Value,
126        problemData.TrainingSamplesStart.Value,
127        problemData.TrainingSamplesEnd.Value);
128      IEnumerable<double> testValues = problemData.Dataset.GetEnumeratedVariableValues(
129        problemData.TargetVariable.Value,
130        problemData.TestSamplesStart.Value,
131        problemData.TestSamplesEnd.Value);
132      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
133      OnlineMeanAbsolutePercentageErrorEvaluator relErrorEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
134      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();
135      #region training
136      var originalEnumerator = trainingValues.GetEnumerator();
137      var estimatedEnumerator = solution.EstimatedTrainingValues.GetEnumerator();
138      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
139        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
140        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
141        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
142      }
143      double trainingR2 = r2Evaluator.RSquared;
144      double trainingMse = mseEvaluator.MeanSquaredError;
145      double trainingRelError = relErrorEvaluator.MeanAbsolutePercentageError;
146      #endregion
147      mseEvaluator.Reset();
148      relErrorEvaluator.Reset();
149      r2Evaluator.Reset();
150      #region test
151      originalEnumerator = testValues.GetEnumerator();
152      estimatedEnumerator = solution.EstimatedTestValues.GetEnumerator();
153      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
154        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
155        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
156        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
157      }
158      double testR2 = r2Evaluator.RSquared;
159      double testMse = mseEvaluator.MeanSquaredError;
160      double testRelError = relErrorEvaluator.MeanAbsolutePercentageError;
161      #endregion
162      if (results.ContainsKey(BestSolutionResultName)) {
163        results[BestSolutionResultName].Value = solution;
164        results[BestSolutionTrainingRSquared].Value = new DoubleValue(trainingR2);
165        results[BestSolutionTestRSquared].Value = new DoubleValue(testR2);
166        results[BestSolutionTrainingMse].Value = new DoubleValue(trainingMse);
167        results[BestSolutionTestMse].Value = new DoubleValue(testMse);
168        results[BestSolutionTrainingRelativeError].Value = new DoubleValue(trainingRelError);
169        results[BestSolutionTestRelativeError].Value = new DoubleValue(testRelError);
170        if (CurrentGeneration != null) // this check is needed because linear regression solutions do not have a generations parameter
171          results[BestSolutionGeneration].Value = new IntValue(CurrentGeneration.Value);
172        var solutionQualityTable = (DataTable)results["Best solution quality table"].Value;
173        solutionQualityTable.Rows["Training R²"].Values.Add(trainingR2);
174        solutionQualityTable.Rows["Training MSE"].Values.Add(trainingMse);
175        solutionQualityTable.Rows["Test R²"].Values.Add(testR2);
176        solutionQualityTable.Rows["Test MSE"].Values.Add(testMse);
177      } else {
178        results.Add(new Result(BestSolutionResultName, solution));
179        results.Add(new Result(BestSolutionTrainingRSquared, new DoubleValue(trainingR2)));
180        results.Add(new Result(BestSolutionTestRSquared, new DoubleValue(testR2)));
181        results.Add(new Result(BestSolutionTrainingMse, new DoubleValue(trainingMse)));
182        results.Add(new Result(BestSolutionTestMse, new DoubleValue(testMse)));
183        results.Add(new Result(BestSolutionTrainingRelativeError, new DoubleValue(trainingRelError)));
184        results.Add(new Result(BestSolutionTestRelativeError, new DoubleValue(testRelError)));
185        if (CurrentGeneration != null)
186          results.Add(new Result(BestSolutionGeneration, new IntValue(CurrentGeneration.Value)));
187        var solutionQualityTable = new DataTable("Best solution quality table");
188        solutionQualityTable.Rows.Add(new DataRow("Training R²"));
189        solutionQualityTable.Rows.Add(new DataRow("Training MSE"));
190        solutionQualityTable.Rows.Add(new DataRow("Test R²"));
191        solutionQualityTable.Rows.Add(new DataRow("Test MSE"));
192        solutionQualityTable.Rows["Training R²"].Values.Add(trainingR2);
193        solutionQualityTable.Rows["Training MSE"].Values.Add(trainingMse);
194        solutionQualityTable.Rows["Test R²"].Values.Add(testR2);
195        solutionQualityTable.Rows["Test MSE"].Values.Add(testMse);
196        results.Add(new Result("Best solution quality table", solutionQualityTable));
197      }
198      #endregion
199    }
200
201    protected abstract DataAnalysisSolution UpdateBestSolution();
202  }
203}
Note: See TracBrowser for help on using the repository browser.