Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/OverfittingAnalyzer.cs @ 4272

Last change on this file since 4272 was 4272, checked in by gkronber, 14 years ago

Worked on overfitting analyzer and CPP. #1142

File size: 16.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis.Evaluators;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using System;
36
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that estimates whether a symbolic regression run is overfitting.
  /// Every tree of the current population is evaluated on a random subset of the validation
  /// partition. The resulting validation qualities are compared with the training qualities
  /// (parameter "Quality"), and the outcome is written to the "RelativeValidationQuality",
  /// "TrainingValidationCorrelation", "InitialTrainingQuality" and "Overfitting" parameters.
  /// </summary>
  [Item("OverfittingAnalyzer", "")]
  [StorableClass]
  public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string QualityParameterName = "Quality";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ValidationSamplesStartParameterName = "SamplesStart";
    private const string ValidationSamplesEndParameterName = "SamplesEnd";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";
    private const string RelativeValidationQualityParameterName = "RelativeValidationQuality";
    private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation";
    private const string CorrelationLimitParameterName = "CorrelationLimit";
    private const string OverfittingParameterName = "Overfitting";
    private const string ResultsParameterName = "Results";
    private const string InitialTrainingQualityParameterName = "InitialTrainingQuality";

    // Training/validation quality pairs where either value is at or below this threshold
    // are excluded from the rank correlation (see CalculateTrainingValidationCorrelation).
    private const double QualityCutOffValue = 0.05;

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    /// <summary>Training qualities of the trees; compared against the validation qualities computed in <see cref="Apply"/>.</summary>
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
    }
    public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
      get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
    }

    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    /// <summary>Relative deviation of the median validation quality from the median training quality (output).</summary>
    public ILookupParameter<PercentValue> RelativeValidationQualityParameter {
      get { return (ILookupParameter<PercentValue>)Parameters[RelativeValidationQualityParameterName]; }
    }
    /// <summary>Spearman rank correlation between training and validation qualities (output).</summary>
    public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
    }
    /// <summary>Correlation below which overfitting is reported (together with an improved training quality).</summary>
    public IValueLookupParameter<DoubleValue> CorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[CorrelationLimitParameterName]; }
    }
    /// <summary>Whether overfitting was detected in the last call of <see cref="Apply"/> (output).</summary>
    public ILookupParameter<BoolValue> OverfittingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    /// <summary>Median training quality of the first analyzed generation; the reference for "improved on training".</summary>
    public ILookupParameter<DoubleValue> InitialTrainingQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[InitialTrainingQualityParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DataAnalysisProblemData ProblemData {
      get { return ProblemDataParameter.ActualValue; }
    }
    // NOTE: the misspelled name is part of the public surface and is kept for compatibility.
    public IntValue ValidiationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    public IntValue ValidationSamplesEnd {
      get { return ValidationSamplesEndParameter.ActualValue; }
    }
    public PercentValue RelativeNumberOfEvaluatedSamples {
      get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
    }

    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    public OverfittingAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
      Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new LookupParameter<PercentValue>(RelativeValidationQualityParameterName));
      Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(CorrelationLimitParameterName, new DoubleValue(0.65)));
      Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName));
      Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
    }

    [StorableConstructor]
    private OverfittingAnalyzer(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Adds the "InitialTrainingQuality" parameter to instances persisted before it existed.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(InitialTrainingQualityParameterName)) {
        Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
      }
    }

    /// <summary>
    /// Evaluates all trees on a random subset of the validation partition and writes
    /// RelativeValidationQuality, TrainingValidationCorrelation, InitialTrainingQuality (first call only)
    /// and Overfitting. Overfitting is reported when the median training quality has improved
    /// relative to the first analyzed generation and the training/validation rank correlation
    /// is below CorrelationLimit.
    /// </summary>
    public override IOperation Apply() {
      ItemArray<SymbolicExpressionTree> trees = SymbolicExpressionTree;
      double[] trainingQualities = QualityParameter.ActualValue.Select(x => x.Value).ToArray();

      List<double> validationQualities = EvaluateOnValidationSamples(trees);
      bool maximization = Maximization.Value;

      // medians are used as robust central values of the population's qualities
      double medianTrainingQuality = trainingQualities.Median();
      double medianValidationQuality = validationQualities.Median();
      RelativeValidationQualityParameter.ActualValue =
        new PercentValue(CalculateRelativeValidationQuality(medianTrainingQuality, medianValidationQuality, maximization));

      double r = CalculateTrainingValidationCorrelation(trainingQualities, validationQualities, QualityCutOffValue);
      TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

      // The first analyzed generation defines the reference training quality.
      // (Previously the validation median was stored here, which compared training against validation.)
      if (InitialTrainingQualityParameter.ActualValue == null)
        InitialTrainingQualityParameter.ActualValue = new DoubleValue(medianTrainingQuality);
      double initialTrainingQuality = InitialTrainingQualityParameter.ActualValue.Value;

      // "better" depends on the direction of optimization
      bool improvedOnTraining = maximization
        ? medianTrainingQuality > initialTrainingQuality
        : medianTrainingQuality < initialTrainingQuality;
      bool overfitting =
        improvedOnTraining &&                             // better on training than in initial generation
        r < CorrelationLimitParameter.ActualValue.Value;  // low correlation between training and validation quality

      OverfittingParameter.ActualValue = new BoolValue(overfitting);
      return base.Apply();
    }

    /// <summary>
    /// Evaluates each tree with the configured evaluator on a random sample of rows drawn between
    /// ValidationSamplesStart and ValidationSamplesEnd (sample size = RelativeNumberOfEvaluatedSamples
    /// of the partition, at least one row when the product truncates to zero).
    /// Missing estimation limits default to +/- infinity.
    /// </summary>
    /// <returns>One validation quality per tree, in the order of <paramref name="trees"/>.</returns>
    private List<double> EvaluateOnValidationSamples(IEnumerable<SymbolicExpressionTree> trees) {
      string targetVariable = ProblemData.TargetVariable.Value;

      int validationStart = ValidiationSamplesStart.Value;
      int validationEnd = ValidationSamplesEnd.Value;
      int seed = Random.Next();
      int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value);
      if (count == 0) count = 1;
      // materialize once; otherwise the lazy sampling sequence is re-run for every evaluated tree
      List<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count).ToList();

      double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
      double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity;

      ISymbolicRegressionEvaluator evaluator = Evaluator;
      ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;
      var dataset = ProblemData.Dataset;

      var validationQualities = new List<double>();
      foreach (var tree in trees) {
        validationQualities.Add(evaluator.Evaluate(interpreter, tree,
          lowerEstimationLimit, upperEstimationLimit,
          dataset, targetVariable,
          rows));
      }
      return validationQualities;
    }

    /// <summary>
    /// Relative deviation of the validation quality from the training quality:
    /// validation / training - 1 when maximizing, training / validation - 1 when minimizing.
    /// For positive quality values a negative result means validation is worse than training.
    /// NOTE(review): a zero denominator yields an infinite or NaN result, as in the original code.
    /// </summary>
    private static double CalculateRelativeValidationQuality(double trainingQuality, double validationQuality, bool maximization) {
      return maximization
        ? validationQuality / trainingQuality - 1
        : trainingQuality / validationQuality - 1;
    }

    /// <summary>
    /// Spearman rank correlation between training and validation qualities. Only index pairs where
    /// both values are strictly greater than <paramref name="cutOffValue"/> are used, because values
    /// around 0.0 (e.g. R² of 0.0) are strong outliers that would weaken the correlation.
    /// NOTE(review): the fixed cut-off assumes larger qualities are better (e.g. R²); for minimization
    /// criteria it would drop the best trees instead — confirm before using with such evaluators.
    /// NOTE(review): the result for fewer than two retained pairs is whatever alglib returns — not verified here.
    /// </summary>
    private static double CalculateTrainingValidationCorrelation(IList<double> trainingQualities, IList<double> validationQualities, double cutOffValue) {
      int n = validationQualities.Count;
      double[] trainingArr = new double[n];
      double[] validationArr = new double[n];
      int pairs = 0;
      for (int i = 0; i < n; i++) {
        if (validationQualities[i] > cutOffValue &&
            trainingQualities[i] > cutOffValue) {
          trainingArr[pairs] = trainingQualities[i];
          validationArr[pairs] = validationQualities[i];
          pairs++;
        }
      }
      return alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, pairs);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.