Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionModelQualityAnalyzer.cs @ 7624

Last change on this file since 7624 was 5275, checked in by gkronber, 14 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 15.6 KB
RevLine 
[3652]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[4068]22using System;
23using System.Collections.Generic;
[3652]24using System.Linq;
[4068]25using HeuristicLab.Analysis;
[5275]26using HeuristicLab.Common;
[3652]27using HeuristicLab.Core;
28using HeuristicLab.Data;
[4068]29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[3652]30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[4068]34using HeuristicLab.Problems.DataAnalysis.Evaluators;
[3652]35using HeuristicLab.Problems.DataAnalysis.Symbolic;
36
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// An operator for analyzing the quality of symbolic regression solutions with a symbolic expression tree encoding.
  /// For every analyzed tree the mean squared error, the mean absolute percentage error and Pearson's R² are
  /// calculated on the training and on the test partition; the minimum, average and maximum of each measure
  /// over all trees are appended to data tables in the result collection.
  /// </summary>
  [Item("SymbolicRegressionModelQualityAnalyzer", "An operator for analyzing the quality of symbolic regression solutions symbolic expression tree encoding.")]
  [StorableClass]
  public sealed class SymbolicRegressionModelQualityAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ResultsParameterName = "Results";

    private const string TrainingMeanSquaredErrorQualityParameterName = "Mean squared error (training)";
    private const string MinTrainingMeanSquaredErrorQualityParameterName = "Min mean squared error (training)";
    private const string MaxTrainingMeanSquaredErrorQualityParameterName = "Max mean squared error (training)";
    private const string AverageTrainingMeanSquaredErrorQualityParameterName = "Average mean squared error (training)";
    private const string BestTrainingMeanSquaredErrorQualityParameterName = "Best mean squared error (training)";

    private const string TrainingAverageRelativeErrorQualityParameterName = "Average relative error (training)";
    private const string MinTrainingAverageRelativeErrorQualityParameterName = "Min average relative error (training)";
    private const string MaxTrainingAverageRelativeErrorQualityParameterName = "Max average relative error (training)";
    private const string AverageTrainingAverageRelativeErrorQualityParameterName = "Average average relative error (training)";
    private const string BestTrainingAverageRelativeErrorQualityParameterName = "Best average relative error (training)";

    private const string TrainingRSquaredQualityParameterName = "R² (training)";
    private const string MinTrainingRSquaredQualityParameterName = "Min R² (training)";
    private const string MaxTrainingRSquaredQualityParameterName = "Max R² (training)";
    private const string AverageTrainingRSquaredQualityParameterName = "Average R² (training)";
    private const string BestTrainingRSquaredQualityParameterName = "Best R² (training)";

    private const string TestMeanSquaredErrorQualityParameterName = "Mean squared error (test)";
    private const string MinTestMeanSquaredErrorQualityParameterName = "Min mean squared error (test)";
    private const string MaxTestMeanSquaredErrorQualityParameterName = "Max mean squared error (test)";
    private const string AverageTestMeanSquaredErrorQualityParameterName = "Average mean squared error (test)";
    private const string BestTestMeanSquaredErrorQualityParameterName = "Best mean squared error (test)";

    private const string TestAverageRelativeErrorQualityParameterName = "Average relative error (test)";
    private const string MinTestAverageRelativeErrorQualityParameterName = "Min average relative error (test)";
    private const string MaxTestAverageRelativeErrorQualityParameterName = "Max average relative error (test)";
    private const string AverageTestAverageRelativeErrorQualityParameterName = "Average average relative error (test)";
    private const string BestTestAverageRelativeErrorQualityParameterName = "Best average relative error (test)";

    private const string TestRSquaredQualityParameterName = "R² (test)";
    private const string MinTestRSquaredQualityParameterName = "Min R² (test)";
    private const string MaxTestRSquaredQualityParameterName = "Max R² (test)";
    private const string AverageTestRSquaredQualityParameterName = "Average R² (test)";
    private const string BestTestRSquaredQualityParameterName = "Best R² (test)";

    // Names of the result data tables that collect min/avg/max values per analysis step.
    private const string RSquaredValuesParameterName = "R²";
    private const string MeanSquaredErrorValuesParameterName = "Mean squared error";
    private const string RelativeErrorValuesParameterName = "Average relative error";

    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";

    #region parameter properties
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicRegressionModelQualityAnalyzer(bool deserializing) : base(deserializing) { }
    private SymbolicRegressionModelQualityAnalyzer(SymbolicRegressionModelQualityAnalyzer original, Cloner cloner) : base(original, cloner) { }
    public SymbolicRegressionModelQualityAnalyzer()
      : base() {
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used to calculate the output values of the symbolic expression tree."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data containing the input variables for the symbolic regression problem."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataTable>(MeanSquaredErrorValuesParameterName, "The data table to collect mean squared error values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RSquaredValuesParameterName, "The data table to collect R² correlation coefficient values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RelativeErrorValuesParameterName, "The data table to collect relative error values."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionModelQualityAnalyzer(this, cloner);
    }

    /// <summary>
    /// Resolves the parameter values of the current scope, runs <see cref="Analyze"/> and continues with the successor.
    /// </summary>
    public override IOperation Apply() {
      Analyze(SymbolicExpressionTreeParameter.ActualValue, SymbolicExpressionTreeInterpreterParameter.ActualValue,
        UpperEstimationLimit.Value, LowerEstimationLimit.Value, ProblemDataParameter.ActualValue,
        ResultsParameter.ActualValue);
      return base.Apply();
    }

    /// <summary>
    /// Evaluates every tree on the training and test partition and appends the min/avg/max of the
    /// mean squared error, mean absolute percentage error and Pearson's R² to result data tables.
    /// </summary>
    /// <param name="trees">The trees to evaluate. If empty, nothing is recorded.</param>
    /// <param name="interpreter">The interpreter used to calculate the tree outputs.</param>
    /// <param name="upperEstimationLimit">Upper cut-off for estimated values; NaN estimates are also replaced by this value.</param>
    /// <param name="lowerEstimationLimit">Lower cut-off for estimated values.</param>
    /// <param name="problemData">Provides the dataset, the target variable and the training/test rows.</param>
    /// <param name="results">The result collection that receives (or already contains) the data tables.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the number of estimated values differs from the number of target values of a partition.
    /// </exception>
    public static void Analyze(IEnumerable<SymbolicExpressionTree> trees, ISymbolicExpressionTreeInterpreter interpreter,
      double upperEstimationLimit, double lowerEstimationLimit,
      DataAnalysisProblemData problemData, ResultCollection results) {
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
      // The target values do not depend on the tree, so read them once instead of once per tree.
      double[] originalTrainingValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, problemData.TrainingIndizes).ToArray();
      double[] originalTestValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, problemData.TestIndizes).ToArray();
      List<double> trainingMse = new List<double>();
      List<double> trainingR2 = new List<double>();
      List<double> trainingRelErr = new List<double>();
      List<double> testMse = new List<double>();
      List<double> testR2 = new List<double>();
      List<double> testRelErr = new List<double>();

      // The evaluators are reused for every tree and partition; EvaluatePartition resets them before use.
      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();

      foreach (var tree in trees) {
        #region training
        var estimatedTrainingValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TrainingIndizes);
        EvaluatePartition(estimatedTrainingValues, originalTrainingValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        trainingMse.Add(mseEvaluator.MeanSquaredError);
        trainingR2.Add(r2Evaluator.RSquared);
        trainingRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
        #region test
        var estimatedTestValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TestIndizes);
        EvaluatePartition(estimatedTestValues, originalTestValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        testMse.Add(mseEvaluator.MeanSquaredError);
        testR2.Add(r2Evaluator.RSquared);
        testRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
      }

      // Min()/Average()/Max() throw on empty sequences; without any tree there is nothing to record.
      if (trainingMse.Count == 0) return;

      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (training)", trainingMse.Min(), trainingMse.Average(), trainingMse.Max());
      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (test)", testMse.Min(), testMse.Average(), testMse.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (training)", trainingRelErr.Min(), trainingRelErr.Average(), trainingRelErr.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (test)", testRelErr.Min(), testRelErr.Average(), testRelErr.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (training)", trainingR2.Min(), trainingR2.Average(), trainingR2.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (test)", testR2.Min(), testR2.Average(), testR2.Max());
    }

    /// <summary>
    /// Resets the evaluators and feeds them all (original, limited estimated) value pairs of one partition.
    /// After the call the evaluators hold the quality measures of that partition.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when both sequences do not have the same length.</exception>
    private static void EvaluatePartition(IEnumerable<double> estimatedValues, IEnumerable<double> originalValues,
      double upperEstimationLimit, double lowerEstimationLimit,
      OnlineMeanSquaredErrorEvaluator mseEvaluator, OnlinePearsonsRSquaredEvaluator r2Evaluator,
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator) {
      mseEvaluator.Reset();
      r2Evaluator.Reset();
      relErrEvaluator.Reset();
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator())
      using (IEnumerator<double> originalEnumerator = originalValues.GetEnumerator()) {
        bool estimatedHasNext = false;
        bool originalHasNext = false;
        // Non-short-circuit '&' advances both enumerators in every iteration. The MoveNext results are
        // captured so that a length difference of exactly one element is detected as well; calling
        // MoveNext again after the loop would skip the element already consumed by the final iteration.
        while ((estimatedHasNext = estimatedEnumerator.MoveNext()) & (originalHasNext = originalEnumerator.MoveNext())) {
          double original = originalEnumerator.Current;
          double estimated = LimitEstimatedValue(estimatedEnumerator.Current, upperEstimationLimit, lowerEstimationLimit);
          mseEvaluator.Add(original, estimated);
          r2Evaluator.Add(original, estimated);
          relErrEvaluator.Add(original, estimated);
        }
        // After the loop at least one sequence is exhausted; if the other is not, the lengths differ.
        if (estimatedHasNext || originalHasNext) {
          throw new InvalidOperationException("Number of elements in estimated and original enumeration doesn't match.");
        }
      }
    }

    /// <summary>
    /// Replaces NaN by the upper estimation limit and clamps all other values to [lower, upper].
    /// </summary>
    private static double LimitEstimatedValue(double estimated, double upperEstimationLimit, double lowerEstimationLimit) {
      if (double.IsNaN(estimated)) return upperEstimationLimit;
      return Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, estimated));
    }

    /// <summary>
    /// Appends min, average and max of one measure as three rows ("Min. ", "Avg. ", "Max. " + valueName)
    /// to the data table <paramref name="tableName"/>, creating the table result if it does not exist yet.
    /// </summary>
    private static void AddResultTableValues(ResultCollection results, string tableName, string valueName, double minValue, double avgValue, double maxValue) {
      if (!results.ContainsKey(tableName)) {
        results.Add(new Result(tableName, new DataTable(tableName)));
      }
      DataTable table = (DataTable)results[tableName].Value;
      AddValue(table, minValue, "Min. " + valueName, string.Empty);
      AddValue(table, avgValue, "Avg. " + valueName, string.Empty);
      AddValue(table, maxValue, "Max. " + valueName, string.Empty);
    }

    /// <summary>
    /// Appends <paramref name="data"/> to the row <paramref name="name"/>; the row is created (with the value
    /// already added) and inserted into the table when it does not exist yet.
    /// </summary>
    private static void AddValue(DataTable table, double data, string name, string description) {
      DataRow row;
      table.Rows.TryGetValue(name, out row);
      if (row == null) {
        row = new DataRow(name, description);
        row.Values.Add(data);
        table.Rows.Add(row);
      } else {
        row.Values.Add(data);
      }
    }

    /// <summary>
    /// Stores <paramref name="value"/> as a <see cref="DoubleValue"/> result, replacing an existing one.
    /// NOTE(review): not called anywhere in this class.
    /// </summary>
    private static void SetResultValue(ResultCollection results, string name, double value) {
      if (results.ContainsKey(name))
        results[name].Value = new DoubleValue(value);
      else
        results.Add(new Result(name, new DoubleValue(value)));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.