Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/OverfittingAnalyzer.cs @ 4326

Last change on this file since 4326 was 4326, checked in by gkronber, 14 years ago

Changed OverfittingAnalyzer to make overfitting boundaries more fuzzy through upper and lower limits for correlations instead of a hard limit. #1142

File size: 18.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis.Evaluators;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using System;
36
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Flags a symbolic regression run as overfitting by computing the Spearman rank correlation
  /// between the training qualities ("Quality") and validation qualities ("ValidationQuality")
  /// of the analyzed solutions.
  /// </summary>
  /// <remarks>
  /// The verdict is written to the "Overfitting" parameter. Two conditions must both hold:
  /// <list type="bullet">
  /// <item>the average training quality is better (in the direction given by "Maximization") than
  /// the average training quality recorded on the first call;</item>
  /// <item>the rank correlation is below the active correlation limit.</item>
  /// </list>
  /// The limit uses hysteresis to avoid flickering between states:
  /// <list type="bullet">
  /// <item>while not overfitting, the correlation has to drop below "LowerCorrelationLimit" to
  /// raise the flag;</item>
  /// <item>while overfitting, the flag stays raised as long as the correlation is below
  /// "UpperCorrelationLimit".</item>
  /// </list>
  /// </remarks>
  [Item("OverfittingAnalyzer", "")]
  [StorableClass]
  public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ValidationSamplesStartParameterName = "SamplesStart";
    private const string ValidationSamplesEndParameterName = "SamplesEnd";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";
    private const string QualityParameterName = "Quality";
    private const string ValidationQualityParameterName = "ValidationQuality";
    private const string RelativeValidationQualityParameterName = "RelativeValidationQuality";
    private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation";
    private const string LowerCorrelationLimitParameterName = "LowerCorrelationLimit";
    private const string UpperCorrelationLimitParameterName = "UpperCorrelationLimit";
    private const string OverfittingParameterName = "Overfitting";
    private const string ResultsParameterName = "Results";
    private const string InitialTrainingQualityParameterName = "InitialTrainingQuality";
    private const string TrainingAndValidationQualitiesParameterName = "TrainingAndValidationQualities";
    private const string PercentileParameterName = "Percentile";

    // Default values shared by the constructor and the after-deserialization upgrade path,
    // so both code paths stay in sync.
    private const double DefaultLowerCorrelationLimit = 0.65;
    private const double DefaultUpperCorrelationLimit = 0.75;
    private const double DefaultPercentile = 1.0;

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    /// <summary>Training qualities of the analyzed solutions (one per sub-scope).</summary>
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    /// <summary>Validation qualities of the analyzed solutions, index-aligned with <see cref="QualityParameter"/>.</summary>
    public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[ValidationQualityParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
    }
    public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
      get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
    }

    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    /// <summary>Output: relative difference between average validation and average training quality.</summary>
    public ILookupParameter<PercentValue> RelativeValidationQualityParameter {
      get { return (ILookupParameter<PercentValue>)Parameters[RelativeValidationQualityParameterName]; }
    }
    /// <summary>Output: Spearman rank correlation between training and validation qualities.</summary>
    public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
    }
    /// <summary>Correlation limit used while the run is currently NOT flagged as overfitting.</summary>
    public IValueLookupParameter<DoubleValue> LowerCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationLimitParameterName]; }
    }
    /// <summary>Correlation limit used while the run is currently flagged as overfitting.</summary>
    public IValueLookupParameter<DoubleValue> UpperCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationLimitParameterName]; }
    }
    /// <summary>Input/output: the overfitting flag (read to select the limit, then overwritten).</summary>
    public ILookupParameter<BoolValue> OverfittingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    /// <summary>Average training quality recorded on the first call of <see cref="Apply"/>.</summary>
    public ILookupParameter<DoubleValue> InitialTrainingQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[InitialTrainingQualityParameterName]; }
    }
    public ILookupParameter<DoubleMatrix> TrainingAndValidationQualitiesParameter {
      get { return (ILookupParameter<DoubleMatrix>)Parameters[TrainingAndValidationQualitiesParameterName]; }
    }
    /// <summary>
    /// Fraction of the best solutions (by training quality) used for the correlation.
    /// A value of 1 means all solutions; values outside [0, 1] are clamped when applied.
    /// </summary>
    public IValueLookupParameter<DoubleValue> PercentileParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PercentileParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DataAnalysisProblemData ProblemData {
      get { return ProblemDataParameter.ActualValue; }
    }
    // NOTE: the property name is misspelled ("Validiation"); it is kept as-is for backward compatibility.
    public IntValue ValidiationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    public IntValue ValidationSamplesEnd {
      get { return ValidationSamplesEndParameter.ActualValue; }
    }
    public PercentValue RelativeNumberOfEvaluatedSamples {
      get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
    }

    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    public OverfittingAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(ValidationQualityParameterName));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
      Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new LookupParameter<PercentValue>(RelativeValidationQualityParameterName));
      Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationLimitParameterName, new DoubleValue(DefaultLowerCorrelationLimit)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationLimitParameterName, new DoubleValue(DefaultUpperCorrelationLimit)));
      Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName));
      Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
      Parameters.Add(new LookupParameter<DoubleMatrix>(TrainingAndValidationQualitiesParameterName));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PercentileParameterName, new DoubleValue(DefaultPercentile)));
    }

    [StorableConstructor]
    private OverfittingAnalyzer(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Adds parameters that were introduced after older versions of this operator were persisted,
    /// so that deserialized instances expose the same parameters as freshly constructed ones.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(InitialTrainingQualityParameterName)) {
        Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
      }
      if (!Parameters.ContainsKey(TrainingAndValidationQualitiesParameterName)) {
        Parameters.Add(new LookupParameter<DoubleMatrix>(TrainingAndValidationQualitiesParameterName));
      }
      if (!Parameters.ContainsKey(PercentileParameterName)) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>(PercentileParameterName, new DoubleValue(DefaultPercentile)));
      }
      if (!Parameters.ContainsKey(ValidationQualityParameterName)) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(ValidationQualityParameterName));
      }
      if (!Parameters.ContainsKey(LowerCorrelationLimitParameterName)) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationLimitParameterName, new DoubleValue(DefaultLowerCorrelationLimit)));
      }
      if (!Parameters.ContainsKey(UpperCorrelationLimitParameterName)) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationLimitParameterName, new DoubleValue(DefaultUpperCorrelationLimit)));
      }
    }

    /// <summary>
    /// Computes the training/validation rank correlation and updates these parameters:
    /// "RelativeValidationQuality", "TrainingValidationCorrelation", "InitialTrainingQuality"
    /// (first call only) and "Overfitting".
    /// </summary>
    public override IOperation Apply() {
      ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
      ItemArray<DoubleValue> validationQualities = ValidationQualityParameter.ActualValue;
      bool maximization = Maximization.Value;

      // Hysteresis: the previous state selects which limit applies.
      // - If already overfitting, the correlation has to reach the upper limit to switch back.
      // - Otherwise it has to fall below the lower limit to switch to the overfitting state.
      BoolValue previousOverfitting = OverfittingParameter.ActualValue;
      double correlationLimit = (previousOverfitting != null && previousOverfitting.Value)
        ? UpperCorrelationLimitParameter.ActualValue.Value
        : LowerCorrelationLimitParameter.ActualValue.Value;

      double avgTrainingQuality = qualities.Select(x => x.Value).Average();
      double avgValidationQuality = validationQualities.Select(x => x.Value).Average();

      // Relative validation quality, oriented by the optimization direction. Assuming strictly
      // positive qualities, a negative value means validation is worse than training.
      RelativeValidationQualityParameter.ActualValue = maximization
        ? new PercentValue(avgValidationQuality / avgTrainingQuality - 1)
        : new PercentValue(avgTrainingQuality / avgValidationQuality - 1);

      // (training, validation) pairs, best training quality first, so that the Percentile
      // selection below keeps the best solutions. Training qualities <= 0 (and NaN) are excluded
      // from the correlation.
      var pairs = Enumerable.Range(0, qualities.Length)
        .Where(index => qualities[index].Value > 0.0)
        .Select(index => new { Training = qualities[index].Value, Validation = validationQualities[index].Value });
      var orderedPairs = (maximization
          ? pairs.OrderByDescending(p => p.Training)
          : pairs.OrderBy(p => p.Training))
        .ToList();

      // Number of best pairs entering the correlation. Clamp it so that a Percentile outside
      // [0, 1] (or NaN) cannot produce an invalid array size or index past the list.
      int n = (int)Math.Round(PercentileParameter.ActualValue.Value * orderedPairs.Count);
      n = Math.Max(0, Math.Min(n, orderedPairs.Count));

      double[] trainingArr = orderedPairs.Take(n).Select(p => p.Training).ToArray();
      double[] validationArr = orderedPairs.Take(n).Select(p => p.Validation).ToArray();
      double r = alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, n);
      TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

      // Remember the average *training* quality of the first analyzed generation. It is the
      // reference for "has training quality improved since the start".
      if (InitialTrainingQualityParameter.ActualValue == null)
        InitialTrainingQualityParameter.ActualValue = new DoubleValue(avgTrainingQuality);
      double initialTrainingQuality = InitialTrainingQualityParameter.ActualValue.Value;

      bool improvedOnTraining = maximization
        ? avgTrainingQuality > initialTrainingQuality
        : avgTrainingQuality < initialTrainingQuality;
      // Overfitting: training keeps improving while training and validation ranks stop agreeing.
      bool overfitting = improvedOnTraining && r < correlationLimit;

      OverfittingParameter.ActualValue = new BoolValue(overfitting);
      return base.Apply();
    }

    /// <summary>
    /// Appends <paramref name="data"/> to the row called <paramref name="name"/> of
    /// <paramref name="table"/>, creating the row (with <paramref name="description"/>) if it
    /// does not exist yet. Currently not called from within this class.
    /// </summary>
    private static void AddValue(DataTable table, double data, string name, string description) {
      DataRow row;
      table.Rows.TryGetValue(name, out row);
      if (row == null) {
        row = new DataRow(name, description);
        row.Values.Add(data);
        table.Rows.Add(row);
      } else {
        row.Values.Add(data);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.