Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Analyzers/RegressionSolutionAnalyzer.cs @ 11299

Last change on this file since 11299 was 5275, checked in by gkronber, 14 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 10.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Operators;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis.Evaluators;
31
32namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
33  [StorableClass]
34  public abstract class RegressionSolutionAnalyzer : SingleSuccessorOperator {
35    private const string ProblemDataParameterName = "ProblemData";
36    private const string QualityParameterName = "Quality";
37    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
38    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
39    private const string BestSolutionQualityParameterName = "BestSolutionQuality";
40    private const string GenerationsParameterName = "Generations";
41    private const string ResultsParameterName = "Results";
42    private const string BestSolutionResultName = "Best solution (on validation set)";
43    private const string BestSolutionTrainingRSquared = "Best solution R² (training)";
44    private const string BestSolutionTestRSquared = "Best solution R² (test)";
45    private const string BestSolutionTrainingMse = "Best solution mean squared error (training)";
46    private const string BestSolutionTestMse = "Best solution mean squared error (test)";
47    private const string BestSolutionTrainingRelativeError = "Best solution average relative error (training)";
48    private const string BestSolutionTestRelativeError = "Best solution average relative error (test)";
49    private const string BestSolutionGeneration = "Best solution generation";
50
51    #region parameter properties
52    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
53      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
54    }
55    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
56      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
57    }
58    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
59      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
60    }
61    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
62      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
63    }
64    public ILookupParameter<DoubleValue> BestSolutionQualityParameter {
65      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionQualityParameterName]; }
66    }
67    public ILookupParameter<ResultCollection> ResultsParameter {
68      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
69    }
70    public ILookupParameter<IntValue> GenerationsParameter {
71      get { return (ILookupParameter<IntValue>)Parameters[GenerationsParameterName]; }
72    }
73    #endregion
74    #region properties
75    public DoubleValue UpperEstimationLimit {
76      get { return UpperEstimationLimitParameter.ActualValue; }
77    }
78    public DoubleValue LowerEstimationLimit {
79      get { return LowerEstimationLimitParameter.ActualValue; }
80    }
81    public ItemArray<DoubleValue> Quality {
82      get { return QualityParameter.ActualValue; }
83    }
84    public ResultCollection Results {
85      get { return ResultsParameter.ActualValue; }
86    }
87    public DataAnalysisProblemData ProblemData {
88      get { return ProblemDataParameter.ActualValue; }
89    }
90    #endregion
91
92
93    [StorableConstructor]
94    protected RegressionSolutionAnalyzer(bool deserializing) : base(deserializing) { }
95    protected RegressionSolutionAnalyzer(RegressionSolutionAnalyzer original, Cloner cloner) : base(original, cloner) { }
96    public RegressionSolutionAnalyzer()
97      : base() {
98      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
99      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
100      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
101      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The qualities of the symbolic regression trees which should be analyzed."));
102      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionQualityParameterName, "The quality of the best regression solution."));
103      Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
104      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
105    }
106
107    [StorableHook(HookType.AfterDeserialization)]
108    private void AfterDeserialization() {
109      // backwards compatibility
110      if (!Parameters.ContainsKey(GenerationsParameterName)) {
111        Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
112      }
113    }
114
115    public override IOperation Apply() {
116      DoubleValue prevBestSolutionQuality = BestSolutionQualityParameter.ActualValue;
117      var bestSolution = UpdateBestSolution();
118      if (prevBestSolutionQuality == null || prevBestSolutionQuality.Value > BestSolutionQualityParameter.ActualValue.Value) {
119        RegressionSolutionAnalyzer.UpdateBestSolutionResults(bestSolution, ProblemData, Results, GenerationsParameter.ActualValue);
120      }
121
122      return base.Apply();
123    }
124
125    public static void UpdateBestSolutionResults(DataAnalysisSolution solution, DataAnalysisProblemData problemData, ResultCollection results, IntValue generation) {
126      #region update R2,MSE, Rel Error
127      IEnumerable<double> trainingValues = problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable.Value, problemData.TrainingIndizes);
128      IEnumerable<double> testValues = problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable.Value, problemData.TestIndizes);
129      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
130      OnlineMeanAbsolutePercentageErrorEvaluator relErrorEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
131      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();
132
133      #region training
134      var originalEnumerator = trainingValues.GetEnumerator();
135      var estimatedEnumerator = solution.EstimatedTrainingValues.GetEnumerator();
136      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
137        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
138        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
139        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
140      }
141      double trainingR2 = r2Evaluator.RSquared;
142      double trainingMse = mseEvaluator.MeanSquaredError;
143      double trainingRelError = relErrorEvaluator.MeanAbsolutePercentageError;
144      #endregion
145
146      mseEvaluator.Reset();
147      relErrorEvaluator.Reset();
148      r2Evaluator.Reset();
149
150      #region test
151      originalEnumerator = testValues.GetEnumerator();
152      estimatedEnumerator = solution.EstimatedTestValues.GetEnumerator();
153      while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
154        mseEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
155        r2Evaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
156        relErrorEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
157      }
158      double testR2 = r2Evaluator.RSquared;
159      double testMse = mseEvaluator.MeanSquaredError;
160      double testRelError = relErrorEvaluator.MeanAbsolutePercentageError;
161      #endregion
162
163      if (results.ContainsKey(BestSolutionResultName)) {
164        results[BestSolutionResultName].Value = solution;
165        results[BestSolutionTrainingRSquared].Value = new DoubleValue(trainingR2);
166        results[BestSolutionTestRSquared].Value = new DoubleValue(testR2);
167        results[BestSolutionTrainingMse].Value = new DoubleValue(trainingMse);
168        results[BestSolutionTestMse].Value = new DoubleValue(testMse);
169        results[BestSolutionTrainingRelativeError].Value = new DoubleValue(trainingRelError);
170        results[BestSolutionTestRelativeError].Value = new DoubleValue(testRelError);
171        if (generation != null) // this check is needed because linear regression solutions do not have a generations parameter
172          results[BestSolutionGeneration].Value = new IntValue(generation.Value);
173      } else {
174        results.Add(new Result(BestSolutionResultName, solution));
175        results.Add(new Result(BestSolutionTrainingRSquared, new DoubleValue(trainingR2)));
176        results.Add(new Result(BestSolutionTestRSquared, new DoubleValue(testR2)));
177        results.Add(new Result(BestSolutionTrainingMse, new DoubleValue(trainingMse)));
178        results.Add(new Result(BestSolutionTestMse, new DoubleValue(testMse)));
179        results.Add(new Result(BestSolutionTrainingRelativeError, new DoubleValue(trainingRelError)));
180        results.Add(new Result(BestSolutionTestRelativeError, new DoubleValue(testRelError)));
181        if (generation != null)
182          results.Add(new Result(BestSolutionGeneration, new IntValue(generation.Value)));
183      }
184      #endregion
185    }
186
187    protected abstract DataAnalysisSolution UpdateBestSolution();
188  }
189}
Note: See TracBrowser for help on using the repository browser.