Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Analyzers/RegressionSolutionAnalyzer.cs @ 4468

Last change on this file since 4468 was 4468, checked in by mkommend, 14 years ago

Preparation for cross validation - removed the test samples from the training samples and added ValidationPercentage parameter (ticket #1199).

File size: 10.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using HeuristicLab.Core;
24using HeuristicLab.Data;
25using HeuristicLab.Operators;
26using HeuristicLab.Optimization;
27using HeuristicLab.Parameters;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis.Evaluators;
30
31namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
32  [StorableClass]
33  public abstract class RegressionSolutionAnalyzer : SingleSuccessorOperator {
34    private const string ProblemDataParameterName = "ProblemData";
35    private const string QualityParameterName = "Quality";
36    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
37    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
38    private const string BestSolutionQualityParameterName = "BestSolutionQuality";
39    private const string GenerationsParameterName = "Generations";
40    private const string ResultsParameterName = "Results";
41    private const string BestSolutionResultName = "Best solution (on validiation set)";
42    private const string BestSolutionTrainingRSquared = "Best solution R² (training)";
43    private const string BestSolutionTestRSquared = "Best solution R² (test)";
44    private const string BestSolutionTrainingMse = "Best solution mean squared error (training)";
45    private const string BestSolutionTestMse = "Best solution mean squared error (test)";
46    private const string BestSolutionTrainingRelativeError = "Best solution average relative error (training)";
47    private const string BestSolutionTestRelativeError = "Best solution average relative error (test)";
48    private const string BestSolutionGeneration = "Best solution generation";
49
50    #region parameter properties
51    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
52      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
53    }
54    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
55      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
56    }
57    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
58      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
59    }
60    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
61      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
62    }
63    public ILookupParameter<DoubleValue> BestSolutionQualityParameter {
64      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionQualityParameterName]; }
65    }
66    public ILookupParameter<ResultCollection> ResultsParameter {
67      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
68    }
69    public ILookupParameter<IntValue> GenerationsParameter {
70      get { return (ILookupParameter<IntValue>)Parameters[GenerationsParameterName]; }
71    }
72    #endregion
73    #region properties
74    public DoubleValue UpperEstimationLimit {
75      get { return UpperEstimationLimitParameter.ActualValue; }
76    }
77    public DoubleValue LowerEstimationLimit {
78      get { return LowerEstimationLimitParameter.ActualValue; }
79    }
80    public ItemArray<DoubleValue> Quality {
81      get { return QualityParameter.ActualValue; }
82    }
83    public ResultCollection Results {
84      get { return ResultsParameter.ActualValue; }
85    }
86    public DataAnalysisProblemData ProblemData {
87      get { return ProblemDataParameter.ActualValue; }
88    }
89    #endregion
90
91    public RegressionSolutionAnalyzer()
92      : base() {
93      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
94      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
95      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
96      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The qualities of the symbolic regression trees which should be analyzed."));
97      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionQualityParameterName, "The quality of the best regression solution."));
98      Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
99      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
100    }
101
102    [StorableHook(HookType.AfterDeserialization)]
103    private void Initialize() {
104      // backwards compatibility
105      if (!Parameters.ContainsKey(GenerationsParameterName)) {
106        Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
107      }
108    }
109
110    public override IOperation Apply() {
111      DoubleValue prevBestSolutionQuality = BestSolutionQualityParameter.ActualValue;
112      var bestSolution = UpdateBestSolution();
113      if (prevBestSolutionQuality == null || prevBestSolutionQuality.Value > BestSolutionQualityParameter.ActualValue.Value) {
114        RegressionSolutionAnalyzer.UpdateBestSolutionResults(bestSolution, ProblemData, Results, GenerationsParameter.ActualValue);
115      }
116
117      return base.Apply();
118    }
119
120    public static void UpdateBestSolutionResults(DataAnalysisSolution bestSolution, DataAnalysisProblemData problemData, ResultCollection results, IntValue CurrentGeneration) {
121      var solution = bestSolution;
122      #region update R2,MSE, Rel Error
123      IEnumerable<double> trainingValues = problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable.Value, problemData.TrainingIndizes);
124      IEnumerable<double> testValues = problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable.Value, problemData.TestIndizes);
125      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
126      OnlineMeanAbsolutePercentageErrorEvaluator relErrorEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
127      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();
128
129      #region training
130      var originalEnumerator = trainingValues.GetEnumerator();
131      var estimatedEnumerator = solution.EstimatedTrainingValues.GetEnumerator();
132      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
133        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
134        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
135        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
136      }
137      double trainingR2 = r2Evaluator.RSquared;
138      double trainingMse = mseEvaluator.MeanSquaredError;
139      double trainingRelError = relErrorEvaluator.MeanAbsolutePercentageError;
140      #endregion
141
142      mseEvaluator.Reset();
143      relErrorEvaluator.Reset();
144      r2Evaluator.Reset();
145
146      #region test
147      originalEnumerator = testValues.GetEnumerator();
148      estimatedEnumerator = solution.EstimatedTestValues.GetEnumerator();
149      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
150        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
151        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
152        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
153      }
154      double testR2 = r2Evaluator.RSquared;
155      double testMse = mseEvaluator.MeanSquaredError;
156      double testRelError = relErrorEvaluator.MeanAbsolutePercentageError;
157      #endregion
158
159      if (results.ContainsKey(BestSolutionResultName)) {
160        results[BestSolutionResultName].Value = solution;
161        results[BestSolutionTrainingRSquared].Value = new DoubleValue(trainingR2);
162        results[BestSolutionTestRSquared].Value = new DoubleValue(testR2);
163        results[BestSolutionTrainingMse].Value = new DoubleValue(trainingMse);
164        results[BestSolutionTestMse].Value = new DoubleValue(testMse);
165        results[BestSolutionTrainingRelativeError].Value = new DoubleValue(trainingRelError);
166        results[BestSolutionTestRelativeError].Value = new DoubleValue(testRelError);
167        if (CurrentGeneration != null) // this check is needed because linear regression solutions do not have a generations parameter
168          results[BestSolutionGeneration].Value = new IntValue(CurrentGeneration.Value);
169      } else {
170        results.Add(new Result(BestSolutionResultName, solution));
171        results.Add(new Result(BestSolutionTrainingRSquared, new DoubleValue(trainingR2)));
172        results.Add(new Result(BestSolutionTestRSquared, new DoubleValue(testR2)));
173        results.Add(new Result(BestSolutionTrainingMse, new DoubleValue(trainingMse)));
174        results.Add(new Result(BestSolutionTestMse, new DoubleValue(testMse)));
175        results.Add(new Result(BestSolutionTrainingRelativeError, new DoubleValue(trainingRelError)));
176        results.Add(new Result(BestSolutionTestRelativeError, new DoubleValue(testRelError)));
177        if (CurrentGeneration != null)
178          results.Add(new Result(BestSolutionGeneration, new IntValue(CurrentGeneration.Value)));
179      }
180      #endregion
181    }
182
183    protected abstract DataAnalysisSolution UpdateBestSolution();
184  }
185}
Note: See TracBrowser for help on using the repository browser.