Free cookie consent management tool by TermsFeed Policy Generator

source: tags/3.3.0/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/SymbolicRegressionScaledMeanSquaredErrorEvaluator.cs @ 13398

Last change on this file since 13398 was 3807, checked in by gkronber, 14 years ago

Made linear scaling operator more numerically stable. #938

File size: 7.7 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Drawing;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.PluginInfrastructure;
33using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
34using HeuristicLab.Problems.DataAnalysis;
35using HeuristicLab.Operators;
36using HeuristicLab.Problems.DataAnalysis.Evaluators;
37using HeuristicLab.Problems.DataAnalysis.Symbolic;
38
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic {
  /// <summary>
  /// Symbolic regression evaluator that linearly scales the model output
  /// (scaled = estimated * beta + alpha) before calculating the mean squared error
  /// against the target variable. Alpha and beta are fitted by ordinary least squares
  /// on the evaluated rows and written to the "Alpha" and "Beta" lookup parameters.
  /// </summary>
  [Item("SymbolicRegressionScaledMeanSquaredErrorEvaluator", "Calculates the mean squared error of a linearly scaled symbolic regression solution.")]
  [StorableClass]
  public class SymbolicRegressionScaledMeanSquaredErrorEvaluator : SymbolicRegressionMeanSquaredErrorEvaluator {
    /// <summary>
    /// Values whose magnitude is at least this bound (as well as NaN and infinities)
    /// are ignored when the scaling parameters are fitted.
    /// </summary>
    private const double MaxScalingValueMagnitude = 1.0E07;
    /// <summary>
    /// Minimum number of valid (target, estimate) pairs required to fit the scaling parameters;
    /// with fewer pairs the identity scaling (alpha = 0, beta = 1) is used.
    /// </summary>
    private const int MinScalingSampleCount = 3;

    #region parameter properties
    public ILookupParameter<DoubleValue> AlphaParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["Alpha"]; }
    }
    public ILookupParameter<DoubleValue> BetaParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["Beta"]; }
    }
    #endregion
    #region properties
    public DoubleValue Alpha {
      get { return AlphaParameter.ActualValue; }
      set { AlphaParameter.ActualValue = value; }
    }
    public DoubleValue Beta {
      get { return BetaParameter.ActualValue; }
      set { BetaParameter.ActualValue = value; }
    }
    #endregion

    public SymbolicRegressionScaledMeanSquaredErrorEvaluator()
      : base() {
      Parameters.Add(new LookupParameter<DoubleValue>("Alpha", "Alpha parameter for linear scaling of the estimated values."));
      Parameters.Add(new LookupParameter<DoubleValue>("Beta", "Beta parameter for linear scaling of the estimated values."));
    }

    /// <summary>
    /// Calculates the mean squared error of the linearly scaled model output on the rows
    /// [samplesStart, samplesEnd) and stores the fitted scaling parameters in the
    /// Alpha and Beta lookup parameters.
    /// </summary>
    protected override double Evaluate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, Dataset dataset, StringValue targetVariable, IntValue samplesStart, IntValue samplesEnd) {
      double alpha, beta;
      double mse = Calculate(interpreter, solution, LowerEstimationLimit.Value, UpperEstimationLimit.Value, dataset, targetVariable.Value, samplesStart.Value, samplesEnd.Value, out beta, out alpha);
      AlphaParameter.ActualValue = new DoubleValue(alpha);
      BetaParameter.ActualValue = new DoubleValue(beta);
      return mse;
    }

    /// <summary>
    /// Fits the linear scaling parameters on the rows [start, end), scales and bounds the
    /// estimated values and returns their mean squared error with respect to the target variable.
    /// </summary>
    /// <param name="interpreter">Interpreter used to evaluate the tree.</param>
    /// <param name="solution">The symbolic expression tree to evaluate.</param>
    /// <param name="lowerEstimationLimit">Scaled estimates below this value are clamped to it.</param>
    /// <param name="upperEstimationLimit">Scaled estimates above this value are clamped to it; NaN estimates are replaced by it.</param>
    /// <param name="dataset">Dataset providing the input and target values.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="start">First evaluated row (inclusive).</param>
    /// <param name="end">End of the evaluated rows (exclusive).</param>
    /// <param name="beta">Receives the fitted multiplicative scaling factor.</param>
    /// <param name="alpha">Receives the fitted additive scaling offset.</param>
    /// <returns>The result of SimpleMSEEvaluator.Calculate on the target values and the scaled, bounded estimates.</returns>
    public static double Calculate(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, int start, int end, out double beta, out double alpha) {
      var estimatedValues = CalculateScaledEstimatedValues(interpreter, solution, dataset, targetVariable, start, end, out beta, out alpha)
        .Select(x => LimitEstimatedValue(x, lowerEstimationLimit, upperEstimationLimit));
      var originalValues = dataset.GetVariableValues(targetVariable, start, end);
      return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
    }

    /// <summary>
    /// Scales the model output on the rows [start, end) with the given (previously fitted)
    /// parameters, bounds it and returns its mean squared error with respect to the target variable.
    /// </summary>
    /// <param name="interpreter">Interpreter used to evaluate the tree.</param>
    /// <param name="solution">The symbolic expression tree to evaluate.</param>
    /// <param name="lowerEstimationLimit">Scaled estimates below this value are clamped to it.</param>
    /// <param name="upperEstimationLimit">Scaled estimates above this value are clamped to it; NaN estimates are replaced by it.</param>
    /// <param name="dataset">Dataset providing the input and target values.</param>
    /// <param name="targetVariable">Name of the target variable.</param>
    /// <param name="start">First evaluated row (inclusive).</param>
    /// <param name="end">End of the evaluated rows (exclusive).</param>
    /// <param name="beta">Multiplicative scaling factor applied to each estimate.</param>
    /// <param name="alpha">Additive scaling offset applied to each estimate.</param>
    /// <returns>The result of SimpleMSEEvaluator.Calculate on the target values and the scaled, bounded estimates.</returns>
    public static double CalculateWithScaling(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, Dataset dataset, string targetVariable, int start, int end, double beta, double alpha) {
      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start))
        .Select(x => LimitEstimatedValue(x * beta + alpha, lowerEstimationLimit, upperEstimationLimit));
      var originalValues = dataset.GetVariableValues(targetVariable, start, end);
      return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
    }

    /// <summary>
    /// Clamps an estimated value to [lowerEstimationLimit, upperEstimationLimit] and maps NaN to the upper limit.
    /// </summary>
    private static double LimitEstimatedValue(double x, double lowerEstimationLimit, double upperEstimationLimit) {
      double boundedX = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, x));
      // Math.Max/Math.Min return NaN when an argument is NaN, so a NaN estimate survives the clamping and is replaced here
      return double.IsNaN(boundedX) ? upperEstimationLimit : boundedX;
    }

    /// <summary>
    /// Evaluates the tree on the rows [start, end), fits the scaling parameters against the
    /// target variable and returns the scaled (but not yet bounded) estimated values.
    /// </summary>
    private static IEnumerable<double> CalculateScaledEstimatedValues(ISymbolicExpressionTreeInterpreter interpreter, SymbolicExpressionTree solution, Dataset dataset, string targetVariable, int start, int end, out double beta, out double alpha) {
      var originalValues = dataset.GetVariableValues(targetVariable, start, end);
      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, dataset, Enumerable.Range(start, end - start)).ToArray();
      CalculateScalingParameters(originalValues, estimatedValues, out beta, out alpha);
      for (int i = 0; i < estimatedValues.Length; i++) {
        estimatedValues[i] = estimatedValues[i] * beta + alpha;
      }
      return estimatedValues;
    }

    /// <summary>
    /// Fits original ≈ beta * estimated + alpha by ordinary least squares, using only the rows
    /// in which both the original and the estimated value are valid (finite and of magnitude
    /// below <see cref="MaxScalingValueMagnitude"/>).
    /// </summary>
    /// <param name="original">Target values.</param>
    /// <param name="estimated">Estimated values; must have the same number of elements as <paramref name="original"/>.</param>
    /// <param name="beta">Receives the fitted slope; 1 if fewer than <see cref="MinScalingSampleCount"/> valid rows exist or the valid estimates are (almost) constant.</param>
    /// <param name="alpha">Receives the fitted offset; 0 if fewer than <see cref="MinScalingSampleCount"/> valid rows exist.</param>
    /// <exception cref="ArgumentException">The two sequences have different lengths.</exception>
    public static void CalculateScalingParameters(IEnumerable<double> original, IEnumerable<double> estimated, out double beta, out double alpha) {
      double[] originalValues = original.ToArray();
      double[] estimatedValues = estimated.ToArray();
      if (originalValues.Length != estimatedValues.Length)
        throw new ArgumentException("The number of estimated values (" + estimatedValues.Length + ") does not match the number of original values (" + originalValues.Length + ").", "estimated");

      // keep only rows in which both values are valid and order them by increasing |target|
      // so that small values are accumulated before large values
      var validRows = (from row in Enumerable.Range(0, originalValues.Length)
                       let t = originalValues[row]
                       let e = estimatedValues[row]
                       where IsValidValue(t) && IsValidValue(e)
                       select new { Estimation = e, Target = t })
                      .OrderBy(x => Math.Abs(x.Target))
                      .ToArray();

      int n = validRows.Length;
      if (n < MinScalingSampleCount) {
        // too few valid rows to fit the scaling parameters: use the identity scaling
        alpha = 0.0;
        beta = 1.0;
        return;
      }

      double[] targets = validRows.Select(x => x.Target).ToArray();
      double[] estimates = validRows.Select(x => x.Estimation).ToArray();
      double tMean = targets.Average();
      double xMean = estimates.Average();

      // two-pass computation: center the values on their means before accumulating the sums
      double sumXT = 0.0;
      double sumXX = 0.0;
      for (int i = 0; i < n; i++) {
        double xDev = estimates[i] - xMean;
        double tDev = targets[i] - tMean;
        sumXT += xDev * tDev;
        sumXX += xDev * xDev;
      }

      // least squares slope cov(x, t) / var(x); fall back to beta = 1 if the estimates are (almost) constant
      beta = sumXX.IsAlmost(0.0) ? 1.0 : sumXT / sumXX;
      alpha = tMean - beta * xMean;
    }

    /// <summary>
    /// Returns true if <paramref name="d"/> is finite and its magnitude is below
    /// <see cref="MaxScalingValueMagnitude"/>; only such values are used to fit the scaling parameters.
    /// </summary>
    private static bool IsValidValue(double d) {
      return !double.IsInfinity(d) && !double.IsNaN(d) && d > -MaxScalingValueMagnitude && d < MaxScalingValueMagnitude;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.