Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionModelQualityAnalyzer.cs @ 3996

Last change on this file since 3996 was 3996, checked in by gkronber, 14 years ago

Improved efficiency of analyzers and evaluators for regression problems. #1074

File size: 17.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Linq;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Operators;
27using HeuristicLab.Optimization;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using System.Collections.Generic;
34using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
35using HeuristicLab.Problems.DataAnalysis;
36using HeuristicLab.Analysis;
37using HeuristicLab.Problems.DataAnalysis.Evaluators;
38using HeuristicLab.Optimization.Operators;
39using System;
40
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// An operator for analyzing the quality of symbolic regression solutions encoded as symbolic expression trees.
  /// For every tree it calculates the mean squared error, Pearson's R² and the mean absolute percentage error
  /// ("mean relative error") on the training and on the test partition, and appends the minimum, average and
  /// maximum of each quality over all trees to data tables stored in the result collection.
  /// </summary>
  [Item("SymbolicRegressionModelQualityAnalyzer", "An operator for analyzing the quality of symbolic regression solutions symbolic expression tree encoding.")]
  [StorableClass]
  public sealed class SymbolicRegressionModelQualityAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ResultsParameterName = "Results";

    private const string TrainingMeanSquaredErrorQualityParameterName = "Mean squared error (training)";
    private const string MinTrainingMeanSquaredErrorQualityParameterName = "Min mean squared error (training)";
    private const string MaxTrainingMeanSquaredErrorQualityParameterName = "Max mean squared error (training)";
    private const string AverageTrainingMeanSquaredErrorQualityParameterName = "Average mean squared error (training)";
    private const string BestTrainingMeanSquaredErrorQualityParameterName = "Best mean squared error (training)";

    private const string TrainingAverageRelativeErrorQualityParameterName = "Average relative error (training)";
    private const string MinTrainingAverageRelativeErrorQualityParameterName = "Min average relative error (training)";
    private const string MaxTrainingAverageRelativeErrorQualityParameterName = "Max average relative error (training)";
    private const string AverageTrainingAverageRelativeErrorQualityParameterName = "Average average relative error (training)";
    private const string BestTrainingAverageRelativeErrorQualityParameterName = "Best average relative error (training)";

    private const string TrainingRSquaredQualityParameterName = "R² (training)";
    private const string MinTrainingRSquaredQualityParameterName = "Min R² (training)";
    private const string MaxTrainingRSquaredQualityParameterName = "Max R² (training)";
    private const string AverageTrainingRSquaredQualityParameterName = "Average R² (training)";
    private const string BestTrainingRSquaredQualityParameterName = "Best R² (training)";

    private const string TestMeanSquaredErrorQualityParameterName = "Mean squared error (test)";
    private const string MinTestMeanSquaredErrorQualityParameterName = "Min mean squared error (test)";
    private const string MaxTestMeanSquaredErrorQualityParameterName = "Max mean squared error (test)";
    private const string AverageTestMeanSquaredErrorQualityParameterName = "Average mean squared error (test)";
    private const string BestTestMeanSquaredErrorQualityParameterName = "Best mean squared error (test)";

    private const string TestAverageRelativeErrorQualityParameterName = "Average relative error (test)";
    private const string MinTestAverageRelativeErrorQualityParameterName = "Min average relative error (test)";
    private const string MaxTestAverageRelativeErrorQualityParameterName = "Max average relative error (test)";
    private const string AverageTestAverageRelativeErrorQualityParameterName = "Average average relative error (test)";
    private const string BestTestAverageRelativeErrorQualityParameterName = "Best average relative error (test)";

    private const string TestRSquaredQualityParameterName = "R² (test)";
    private const string MinTestRSquaredQualityParameterName = "Min R² (test)";
    private const string MaxTestRSquaredQualityParameterName = "Max R² (test)";
    private const string AverageTestRSquaredQualityParameterName = "Average R² (test)";
    private const string BestTestRSquaredQualityParameterName = "Best R² (test)";

    // Names of the data tables (result entries) that collect min/avg/max values per generation.
    private const string RSquaredValuesParameterName = "R²";
    private const string MeanSquaredErrorValuesParameterName = "Mean squared error";
    private const string RelativeErrorValuesParameterName = "Average relative error";

    private const string TrainingSamplesStartParameterName = "TrainingSamplesStart";
    private const string TrainingSamplesEndParameterName = "TrainingSamplesEnd";
    private const string TestSamplesStartParameterName = "TestSamplesStart";
    private const string TestSamplesEndParameterName = "TestSamplesEnd";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";

    #region parameter properties
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<IntValue> TrainingSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TrainingSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> TrainingSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TrainingSamplesEndParameterName]; }
    }
    public IValueLookupParameter<IntValue> TestSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TestSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> TestSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TestSamplesEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>The actual value of the upper estimation limit parameter.</summary>
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    /// <summary>The actual value of the lower estimation limit parameter.</summary>
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Creates the analyzer and registers all of its parameters.
    /// </summary>
    public SymbolicRegressionModelQualityAnalyzer()
      : base() {
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used to calculate the output values of the symbolic expression tree."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data containing the input variables for the symbolic regression problem."));
      Parameters.Add(new ValueLookupParameter<IntValue>(TrainingSamplesStartParameterName, "The first index of the training data set partition on which the model quality values should be calculated."));
      Parameters.Add(new ValueLookupParameter<IntValue>(TrainingSamplesEndParameterName, "The last index of the training data set partition on which the model quality values should be calculated."));
      Parameters.Add(new ValueLookupParameter<IntValue>(TestSamplesStartParameterName, "The first index of the test data set partition on which the model quality values should be calculated."));
      Parameters.Add(new ValueLookupParameter<IntValue>(TestSamplesEndParameterName, "The last index of the test data set partition on which the model quality values should be calculated."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataTable>(MeanSquaredErrorValuesParameterName, "The data table to collect mean squared error values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RSquaredValuesParameterName, "The data table to collect R² correlation coefficient values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RelativeErrorValuesParameterName, "The data table to collect relative error values."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
    }

    // Constructor used by the persistence framework.
    // NOTE(review): this chains to base() rather than a deserializing base constructor; confirm whether
    // SingleSuccessorOperator offers a (bool deserializing) overload that should be used here instead.
    [StorableConstructor]
    private SymbolicRegressionModelQualityAnalyzer(bool deserializing) : base() { }

    /// <summary>
    /// Reads the actual parameter values, runs <see cref="Analyze"/> and then continues with the successor.
    /// </summary>
    public override IOperation Apply() {
      Analyze(SymbolicExpressionTreeParameter.ActualValue, SymbolicExpressionTreeInterpreterParameter.ActualValue,
        UpperEstimationLimit.Value, LowerEstimationLimit.Value, ProblemDataParameter.ActualValue,
        TrainingSamplesStartParameter.ActualValue.Value, TrainingSamplesEndParameter.ActualValue.Value,
        TestSamplesStartParameter.ActualValue.Value, TestSamplesEndParameter.ActualValue.Value,
        ResultsParameter.ActualValue);
      return base.Apply();
    }

    /// <summary>
    /// Evaluates every tree on the training and test partitions and appends the min/avg/max of the mean squared error,
    /// mean relative error (MAPE) and Pearson's R² over all trees to the data tables in <paramref name="results"/>.
    /// </summary>
    /// <param name="trees">The trees to evaluate. If empty, nothing is recorded.</param>
    /// <param name="interpreter">Interpreter used to compute the tree outputs.</param>
    /// <param name="upperEstimationLimit">Upper clipping bound for estimates; NaN estimates are replaced by this value.</param>
    /// <param name="lowerEstimationLimit">Lower clipping bound for estimates.</param>
    /// <param name="problemData">Problem data providing the dataset and target variable.</param>
    /// <param name="trainingStart">First row of the training partition.</param>
    /// <param name="trainingEnd">End of the training partition (used as exclusive bound for the evaluated rows).</param>
    /// <param name="testStart">First row of the test partition.</param>
    /// <param name="testEnd">End of the test partition (used as exclusive bound for the evaluated rows).</param>
    /// <param name="results">Result collection that receives/holds the data tables.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the number of estimated values does not match the number of target values of a partition.
    /// </exception>
    public static void Analyze(IEnumerable<SymbolicExpressionTree> trees, ISymbolicExpressionTreeInterpreter interpreter,
      double upperEstimationLimit, double lowerEstimationLimit,
      DataAnalysisProblemData problemData, int trainingStart, int trainingEnd, int testStart, int testEnd, ResultCollection results) {
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
      // Materialize the target values once instead of re-reading them from the dataset for every tree.
      double[] originalTrainingValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, trainingStart, trainingEnd).ToArray();
      double[] originalTestValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, testStart, testEnd).ToArray();
      // Enumerable.Range is re-enumerable, so the row sets can be shared by all trees.
      IEnumerable<int> trainingRows = Enumerable.Range(trainingStart, trainingEnd - trainingStart);
      IEnumerable<int> testRows = Enumerable.Range(testStart, testEnd - testStart);

      List<double> trainingMse = new List<double>();
      List<double> trainingR2 = new List<double>();
      List<double> trainingRelErr = new List<double>();
      List<double> testMse = new List<double>();
      List<double> testR2 = new List<double>();
      List<double> testRelErr = new List<double>();

      // The evaluators are reused (and reset) for every partition of every tree.
      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();

      foreach (var tree in trees) {
        #region training
        var estimatedTrainingValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, trainingRows);
        CalculateQualities(estimatedTrainingValues, originalTrainingValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        trainingMse.Add(mseEvaluator.MeanSquaredError);
        trainingR2.Add(r2Evaluator.RSquared);
        trainingRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
        #region test
        var estimatedTestValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, testRows);
        CalculateQualities(estimatedTestValues, originalTestValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        testMse.Add(mseEvaluator.MeanSquaredError);
        testR2.Add(r2Evaluator.RSquared);
        testRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
      }

      // Min/Average/Max are undefined for an empty sequence; with no trees there is nothing to record.
      if (trainingMse.Count == 0) return;

      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (training)", trainingMse.Min(), trainingMse.Average(), trainingMse.Max());
      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (test)", testMse.Min(), testMse.Average(), testMse.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (training)", trainingRelErr.Min(), trainingRelErr.Average(), trainingRelErr.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (test)", testRelErr.Min(), testRelErr.Average(), testRelErr.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (training)", trainingR2.Min(), trainingR2.Average(), trainingR2.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (test)", testR2.Min(), testR2.Average(), testR2.Max());
    }

    /// <summary>
    /// Resets the given online evaluators and feeds them every (original, estimated) pair of one partition.
    /// Estimates are clipped to [<paramref name="lowerEstimationLimit"/>, <paramref name="upperEstimationLimit"/>];
    /// NaN estimates are replaced by <paramref name="upperEstimationLimit"/>. Afterwards the evaluators hold the
    /// quality values of this partition.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the two sequences differ in length.</exception>
    private static void CalculateQualities(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues,
      double upperEstimationLimit, double lowerEstimationLimit,
      OnlineMeanSquaredErrorEvaluator mseEvaluator, OnlinePearsonsRSquaredEvaluator r2Evaluator,
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator) {
      mseEvaluator.Reset();
      r2Evaluator.Reset();
      relErrEvaluator.Reset();
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator())
      using (IEnumerator<double> originalEnumerator = originalValues.GetEnumerator()) {
        bool estimatedHasValue, originalHasValue;
        // The non-short-circuit '&' advances both enumerators on every iteration. The flags remember the
        // results of exactly those MoveNext calls, so a length difference of one element is also detected
        // (calling MoveNext again after the loop would skip the element that was just consumed).
        while ((estimatedHasValue = estimatedEnumerator.MoveNext()) & (originalHasValue = originalEnumerator.MoveNext())) {
          double original = originalEnumerator.Current;
          double estimated = estimatedEnumerator.Current;
          if (double.IsNaN(estimated)) estimated = upperEstimationLimit;
          else estimated = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, estimated));
          mseEvaluator.Add(original, estimated);
          r2Evaluator.Add(original, estimated);
          relErrEvaluator.Add(original, estimated);
        }
        if (estimatedHasValue || originalHasValue) {
          throw new InvalidOperationException("Number of elements in estimated and original enumeration doesn't match.");
        }
      }
    }

    /// <summary>
    /// Appends a min, avg and max value for <paramref name="valueName"/> to the data table named
    /// <paramref name="tableName"/> in <paramref name="results"/>, creating the table if it does not exist yet.
    /// </summary>
    private static void AddResultTableValues(ResultCollection results, string tableName, string valueName, double minValue, double avgValue, double maxValue) {
      if (!results.ContainsKey(tableName)) {
        results.Add(new Result(tableName, new DataTable(tableName)));
      }
      DataTable table = (DataTable)results[tableName].Value;
      AddValue(table, minValue, "Min. " + valueName, string.Empty);
      AddValue(table, avgValue, "Avg. " + valueName, string.Empty);
      AddValue(table, maxValue, "Max. " + valueName, string.Empty);
    }

    /// <summary>
    /// Appends <paramref name="data"/> to the row <paramref name="name"/> of <paramref name="table"/>.
    /// A missing row is created (already holding the value) and then added to the table.
    /// </summary>
    private static void AddValue(DataTable table, double data, string name, string description) {
      DataRow row;
      table.Rows.TryGetValue(name, out row);
      if (row == null) {
        row = new DataRow(name, description);
        row.Values.Add(data);
        table.Rows.Add(row);
      } else {
        row.Values.Add(data);
      }
    }

    /// <summary>
    /// Stores <paramref name="value"/> as a <see cref="DoubleValue"/> result named <paramref name="name"/>,
    /// replacing an existing entry. (Not called anywhere in this class.)
    /// </summary>
    private static void SetResultValue(ResultCollection results, string name, double value) {
      if (results.ContainsKey(name))
        results[name].Value = new DoubleValue(value);
      else
        results.Add(new Result(name, new DoubleValue(value)));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.