Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionModelQualityAnalyzer.cs @ 7214

Last change on this file since 7214 was 7214, checked in by ascheibe, 13 years ago

#1706 adapted outdated plugins to changes in IAnalyzer

File size: 15.9 KB
RevLine 
[3652]1#region License Information
2/* HeuristicLab
[5445]3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[3652]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[4068]22using System;
23using System.Collections.Generic;
[3652]24using System.Linq;
[4068]25using HeuristicLab.Analysis;
[4722]26using HeuristicLab.Common;
[3652]27using HeuristicLab.Core;
28using HeuristicLab.Data;
[4068]29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[3652]30using HeuristicLab.Operators;
31using HeuristicLab.Optimization;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[5863]34using HeuristicLab.PluginInfrastructure;
[4068]35using HeuristicLab.Problems.DataAnalysis.Evaluators;
[3652]36using HeuristicLab.Problems.DataAnalysis.Symbolic;
37
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// An operator that analyzes the quality of symbolic regression solutions given in symbolic expression tree encoding.
  /// For every analyzed tree the mean squared error, the mean absolute percentage error and Pearson's R² are
  /// calculated on the training and on the test partition. The minimum, average and maximum of each measure over all
  /// trees are appended to data tables stored in the result collection.
  /// </summary>
  [Item("SymbolicRegressionModelQualityAnalyzer", "An operator for analyzing the quality of symbolic regression solutions symbolic expression tree encoding.")]
  [StorableClass]
  [NonDiscoverableType]
  public sealed class SymbolicRegressionModelQualityAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    // names of the parameters registered in the default constructor
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ResultsParameterName = "Results";

    // NOTE(review): the per-measure constants below are not referenced anywhere in this class;
    // they are kept unchanged in case they are still needed for compatibility.
    private const string TrainingMeanSquaredErrorQualityParameterName = "Mean squared error (training)";
    private const string MinTrainingMeanSquaredErrorQualityParameterName = "Min mean squared error (training)";
    private const string MaxTrainingMeanSquaredErrorQualityParameterName = "Max mean squared error (training)";
    private const string AverageTrainingMeanSquaredErrorQualityParameterName = "Average mean squared error (training)";
    private const string BestTrainingMeanSquaredErrorQualityParameterName = "Best mean squared error (training)";

    private const string TrainingAverageRelativeErrorQualityParameterName = "Average relative error (training)";
    private const string MinTrainingAverageRelativeErrorQualityParameterName = "Min average relative error (training)";
    private const string MaxTrainingAverageRelativeErrorQualityParameterName = "Max average relative error (training)";
    private const string AverageTrainingAverageRelativeErrorQualityParameterName = "Average average relative error (training)";
    private const string BestTrainingAverageRelativeErrorQualityParameterName = "Best average relative error (training)";

    private const string TrainingRSquaredQualityParameterName = "R² (training)";
    private const string MinTrainingRSquaredQualityParameterName = "Min R² (training)";
    private const string MaxTrainingRSquaredQualityParameterName = "Max R² (training)";
    private const string AverageTrainingRSquaredQualityParameterName = "Average R² (training)";
    private const string BestTrainingRSquaredQualityParameterName = "Best R² (training)";

    private const string TestMeanSquaredErrorQualityParameterName = "Mean squared error (test)";
    private const string MinTestMeanSquaredErrorQualityParameterName = "Min mean squared error (test)";
    private const string MaxTestMeanSquaredErrorQualityParameterName = "Max mean squared error (test)";
    private const string AverageTestMeanSquaredErrorQualityParameterName = "Average mean squared error (test)";
    private const string BestTestMeanSquaredErrorQualityParameterName = "Best mean squared error (test)";

    private const string TestAverageRelativeErrorQualityParameterName = "Average relative error (test)";
    private const string MinTestAverageRelativeErrorQualityParameterName = "Min average relative error (test)";
    private const string MaxTestAverageRelativeErrorQualityParameterName = "Max average relative error (test)";
    private const string AverageTestAverageRelativeErrorQualityParameterName = "Average average relative error (test)";
    private const string BestTestAverageRelativeErrorQualityParameterName = "Best average relative error (test)";

    private const string TestRSquaredQualityParameterName = "R² (test)";
    private const string MinTestRSquaredQualityParameterName = "Min R² (test)";
    private const string MaxTestRSquaredQualityParameterName = "Max R² (test)";
    private const string AverageTestRSquaredQualityParameterName = "Average R² (test)";
    private const string BestTestRSquaredQualityParameterName = "Best R² (test)";

    // names of the data table parameters; the same names are used as keys of the tables in the result collection
    private const string RSquaredValuesParameterName = "R²";
    private const string MeanSquaredErrorValuesParameterName = "Mean squared error";
    private const string RelativeErrorValuesParameterName = "Average relative error";

    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";

    /// <summary>
    /// Always returns <c>true</c>.
    /// </summary>
    public bool EnabledByDefault {
      get { return true; }
    }

    #region parameter properties
    /// <summary>The symbolic expression trees to analyze.</summary>
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    /// <summary>The interpreter used to calculate the output values of the symbolic expression trees.</summary>
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    /// <summary>The problem data of the symbolic regression problem.</summary>
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    /// <summary>The upper cut-off value for the output values of the symbolic expression trees.</summary>
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    /// <summary>The lower cut-off value for the output values of the symbolic expression trees.</summary>
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    /// <summary>The result collection that receives the quality data tables.</summary>
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    /// <summary>The actual value of <see cref="UpperEstimationLimitParameter"/>.</summary>
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    /// <summary>The actual value of <see cref="LowerEstimationLimitParameter"/>.</summary>
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicRegressionModelQualityAnalyzer(bool deserializing) : base(deserializing) { }
    private SymbolicRegressionModelQualityAnalyzer(SymbolicRegressionModelQualityAnalyzer original, Cloner cloner) : base(original, cloner) { }
    /// <summary>
    /// Creates a new analyzer and registers all of its parameters.
    /// </summary>
    public SymbolicRegressionModelQualityAnalyzer()
      : base() {
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used to calculate the output values of the symbolic expression tree."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data containing the input varaibles for the symbolic regression problem."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower limit that should be used as cut off value for the output values of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataTable>(MeanSquaredErrorValuesParameterName, "The data table to collect mean squared error values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RSquaredValuesParameterName, "The data table to collect R² correlation coefficient values."));
      Parameters.Add(new ValueLookupParameter<DataTable>(RelativeErrorValuesParameterName, "The data table to collect relative error values."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best symbolic regression solution should be stored."));
    }

    /// <summary>
    /// Creates a copy of this analyzer via the cloning constructor.
    /// </summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionModelQualityAnalyzer(this, cloner);
    }

    /// <summary>
    /// Resolves the parameter values and runs <see cref="Analyze"/> on them.
    /// A missing estimation limit is treated as unbounded (positive / negative infinity).
    /// </summary>
    public override IOperation Apply() {
      double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
      double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity;
      Analyze(SymbolicExpressionTreeParameter.ActualValue, SymbolicExpressionTreeInterpreterParameter.ActualValue,
        upperEstimationLimit, lowerEstimationLimit, ProblemDataParameter.ActualValue,
        ResultsParameter.ActualValue);
      return base.Apply();
    }

    /// <summary>
    /// Calculates the mean squared error, the mean absolute percentage error and Pearson's R² of every tree on the
    /// training and on the test partition and appends the minimum, average and maximum of each measure to the
    /// corresponding data table in <paramref name="results"/>.
    /// </summary>
    /// <param name="trees">The symbolic expression trees to analyze. If empty, nothing is recorded.</param>
    /// <param name="interpreter">The interpreter used to calculate the estimated values of each tree.</param>
    /// <param name="upperEstimationLimit">Upper cut-off for estimated values; also substituted for NaN estimates.</param>
    /// <param name="lowerEstimationLimit">Lower cut-off for estimated values.</param>
    /// <param name="problemData">Provides the dataset, the target variable and the training / test rows.</param>
    /// <param name="results">The result collection that holds the data tables.</param>
    /// <exception cref="InvalidOperationException">
    /// Thrown if the number of estimated values differs from the number of original target values.
    /// </exception>
    public static void Analyze(IEnumerable<SymbolicExpressionTree> trees, ISymbolicExpressionTreeInterpreter interpreter,
      double upperEstimationLimit, double lowerEstimationLimit,
      DataAnalysisProblemData problemData, ResultCollection results) {
      int targetVariableIndex = problemData.Dataset.GetVariableIndex(problemData.TargetVariable.Value);
      IEnumerable<double> originalTrainingValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, problemData.TrainingIndizes);
      IEnumerable<double> originalTestValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariableIndex, problemData.TestIndizes);
      List<double> trainingMse = new List<double>();
      List<double> trainingR2 = new List<double>();
      List<double> trainingRelErr = new List<double>();
      List<double> testMse = new List<double>();
      List<double> testR2 = new List<double>();
      List<double> testRelErr = new List<double>();

      // the evaluators are reset and reused for every tree and partition
      OnlineMeanSquaredErrorEvaluator mseEvaluator = new OnlineMeanSquaredErrorEvaluator();
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator = new OnlineMeanAbsolutePercentageErrorEvaluator();
      OnlinePearsonsRSquaredEvaluator r2Evaluator = new OnlinePearsonsRSquaredEvaluator();

      foreach (var tree in trees) {
        #region training
        var estimatedTrainingValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TrainingIndizes);
        AccumulateErrors(originalTrainingValues, estimatedTrainingValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        trainingMse.Add(mseEvaluator.MeanSquaredError);
        trainingR2.Add(r2Evaluator.RSquared);
        trainingRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
        #region test
        var estimatedTestValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TestIndizes);
        AccumulateErrors(originalTestValues, estimatedTestValues, upperEstimationLimit, lowerEstimationLimit,
          mseEvaluator, r2Evaluator, relErrEvaluator);
        testMse.Add(mseEvaluator.MeanSquaredError);
        testR2.Add(r2Evaluator.RSquared);
        testRelErr.Add(relErrEvaluator.MeanAbsolutePercentageError);
        #endregion
      }

      // Min() / Average() / Max() throw on empty sequences; without any tree there is nothing to record
      if (trainingMse.Count == 0) return;

      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (training)", trainingMse.Min(), trainingMse.Average(), trainingMse.Max());
      AddResultTableValues(results, MeanSquaredErrorValuesParameterName, "mean squared error (test)", testMse.Min(), testMse.Average(), testMse.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (training)", trainingRelErr.Min(), trainingRelErr.Average(), trainingRelErr.Max());
      AddResultTableValues(results, RelativeErrorValuesParameterName, "mean relative error (test)", testRelErr.Min(), testRelErr.Average(), testRelErr.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (training)", trainingR2.Min(), trainingR2.Average(), trainingR2.Max());
      AddResultTableValues(results, RSquaredValuesParameterName, "Pearson's R² (test)", testR2.Min(), testR2.Average(), testR2.Max());
    }

    /// <summary>
    /// Resets the three evaluators and feeds them with all pairs of original and (limited) estimated values.
    /// NaN estimates are replaced by <paramref name="upperEstimationLimit"/>, all other estimates are clamped to
    /// [<paramref name="lowerEstimationLimit"/>, <paramref name="upperEstimationLimit"/>].
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown if the two sequences differ in length.</exception>
    private static void AccumulateErrors(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues,
      double upperEstimationLimit, double lowerEstimationLimit,
      OnlineMeanSquaredErrorEvaluator mseEvaluator, OnlinePearsonsRSquaredEvaluator r2Evaluator,
      OnlineMeanAbsolutePercentageErrorEvaluator relErrEvaluator) {
      mseEvaluator.Reset();
      r2Evaluator.Reset();
      relErrEvaluator.Reset();
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator())
      using (IEnumerator<double> originalEnumerator = originalValues.GetEnumerator()) {
        // Both enumerators are advanced in lock-step and the results are remembered. Checking the remembered
        // flags after the loop also detects a length difference of exactly one element, which would be lost
        // if MoveNext() were called again on an enumerator that has already consumed its surplus element.
        bool hasEstimated = estimatedEnumerator.MoveNext();
        bool hasOriginal = originalEnumerator.MoveNext();
        while (hasEstimated && hasOriginal) {
          double original = originalEnumerator.Current;
          double estimated = estimatedEnumerator.Current;
          if (double.IsNaN(estimated)) estimated = upperEstimationLimit;
          else estimated = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, estimated));
          mseEvaluator.Add(original, estimated);
          r2Evaluator.Add(original, estimated);
          relErrEvaluator.Add(original, estimated);
          hasEstimated = estimatedEnumerator.MoveNext();
          hasOriginal = originalEnumerator.MoveNext();
        }
        if (hasEstimated || hasOriginal) {
          throw new InvalidOperationException("Number of elements in estimated and original enumeration doesn't match.");
        }
      }
    }

    /// <summary>
    /// Appends the minimum, average and maximum value to the rows "Min. / Avg. / Max. <paramref name="valueName"/>"
    /// of the data table <paramref name="tableName"/>; the table is created in <paramref name="results"/> if missing.
    /// </summary>
    private static void AddResultTableValues(ResultCollection results, string tableName, string valueName, double minValue, double avgValue, double maxValue) {
      if (!results.ContainsKey(tableName)) {
        results.Add(new Result(tableName, new DataTable(tableName)));
      }
      DataTable table = (DataTable)results[tableName].Value;
      AddValue(table, minValue, "Min. " + valueName, string.Empty);
      AddValue(table, avgValue, "Avg. " + valueName, string.Empty);
      AddValue(table, maxValue, "Max. " + valueName, string.Empty);
    }

    /// <summary>
    /// Appends <paramref name="data"/> to the row <paramref name="name"/> of <paramref name="table"/>.
    /// A missing row is created with the given description.
    /// </summary>
    private static void AddValue(DataTable table, double data, string name, string description) {
      DataRow row;
      if (table.Rows.TryGetValue(name, out row)) {
        row.Values.Add(data);
      } else {
        // a new row receives its first value before it is added to the table
        row = new DataRow(name, description);
        row.Values.Add(data);
        table.Rows.Add(row);
      }
    }

    /// <summary>
    /// Stores <paramref name="value"/> under <paramref name="name"/> in <paramref name="results"/>, replacing the
    /// value of an existing result or adding a new one. Currently not called from within this class.
    /// </summary>
    private static void SetResultValue(ResultCollection results, string name, double value) {
      if (results.ContainsKey(name))
        results[name].Value = new DoubleValue(value);
      else
        results.Add(new Result(name, new DoubleValue(value)));
    }
  }
}
Note: See TracBrowser for help on using the repository browser.