Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/OverfittingAnalyzer.cs @ 5477

Last change on this file since 5477 was 5010, checked in by gkronber, 14 years ago

commit of local changes in data-analysis feature exploration branch. #1142

File size: 18.4 KB
RevLine 
[4271]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
[4272]25using HeuristicLab.Common;
[4271]26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis.Evaluators;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using System;
36
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Detects overfitting by correlating the training and validation qualities of the current population.
  /// </summary>
  /// <remarks>
  /// In every call the Spearman rank correlation between the training qualities ("Quality") and the
  /// validation qualities ("ValidationQuality") of the best fraction ("Percentile") of solutions is computed.
  /// The population is flagged as overfitting when the average training quality has improved relative to the
  /// first analyzed generation and the correlation is below the current correlation limit. The limit uses
  /// hysteresis: entering the overfitting state requires the correlation to fall below "LowerCorrelationLimit",
  /// leaving it requires the correlation to reach "UpperCorrelationLimit".
  /// NOTE(review): the random, tree, interpreter, evaluator, problem data, sample range, relative sample count
  /// and estimation limit parameters are declared but not read by <see cref="Apply"/>.
  /// </remarks>
  [Item("OverfittingAnalyzer", "")]
  [StorableClass]
  public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ValidationSamplesStartParameterName = "SamplesStart";
    private const string ValidationSamplesEndParameterName = "SamplesEnd";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["Quality"]; }
    }
    public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["ValidationQuality"]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
    }
    public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
      get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
    }

    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public ILookupParameter<PercentValue> RelativeValidationQualityParameter {
      get { return (ILookupParameter<PercentValue>)Parameters["RelativeValidationQuality"]; }
    }
    public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["TrainingValidationCorrelation"]; }
    }
    public IValueLookupParameter<DoubleValue> LowerCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["LowerCorrelationLimit"]; }
    }
    public IValueLookupParameter<DoubleValue> UpperCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["UpperCorrelationLimit"]; }
    }
    public ILookupParameter<BoolValue> OverfittingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters["Overfitting"]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters["Results"]; }
    }
    public ILookupParameter<DoubleValue> InitialTrainingQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["InitialTrainingQuality"]; }
    }
    public ILookupParameter<ItemList<DoubleMatrix>> TrainingAndValidationQualitiesParameter {
      get { return (ILookupParameter<ItemList<DoubleMatrix>>)Parameters["TrainingAndValidationQualities"]; }
    }
    public IValueLookupParameter<DoubleValue> PercentileParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["Percentile"]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DataAnalysisProblemData ProblemData {
      get { return ProblemDataParameter.ActualValue; }
    }
    // NOTE: the misspelled name is part of the public surface and is kept for compatibility.
    public IntValue ValidiationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    public IntValue ValidationSamplesEnd {
      get { return ValidationSamplesEndParameter.ActualValue; }
    }
    public PercentValue RelativeNumberOfEvaluatedSamples {
      get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
    }

    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    public OverfittingAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality"));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
      Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new LookupParameter<PercentValue>("RelativeValidationQuality"));
      Parameters.Add(new LookupParameter<DoubleValue>("TrainingValidationCorrelation"));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75)));
      Parameters.Add(new LookupParameter<BoolValue>("Overfitting"));
      Parameters.Add(new LookupParameter<ResultCollection>("Results"));
      Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality"));
      Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities"));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1)));
    }

    [StorableConstructor]
    private OverfittingAnalyzer(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Adds parameters that were introduced after older versions of this operator were persisted,
    /// so that deserialized instances expose the same parameter set as freshly constructed ones.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey("InitialTrainingQuality")) {
        Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality"));
      }
      if (!Parameters.ContainsKey("TrainingAndValidationQualities")) {
        Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities"));
      }
      if (!Parameters.ContainsKey("Percentile")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1)));
      }
      if (!Parameters.ContainsKey("ValidationQuality")) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality"));
      }
      if (!Parameters.ContainsKey("LowerCorrelationLimit")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65)));
      }
      if (!Parameters.ContainsKey("UpperCorrelationLimit")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75)));
      }
    }

    /// <summary>
    /// Computes the training/validation rank correlation of the current population, updates the
    /// relative validation quality and the overfitting state, and appends the analyzed
    /// (training, validation) quality pairs to "TrainingAndValidationQualities".
    /// </summary>
    public override IOperation Apply() {
      ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
      ItemArray<DoubleValue> validationQualities = ValidationQualityParameter.ActualValue;
      bool maximization = Maximization.Value;

      // Hysteresis: once overfitting, the correlation has to recover up to the upper limit before the state
      // is cleared; while not overfitting, it has to drop below the lower limit before overfitting is signaled.
      bool wasOverfitting = OverfittingParameter.ActualValue != null && OverfittingParameter.ActualValue.Value;
      double correlationLimit = wasOverfitting
        ? UpperCorrelationLimitParameter.ActualValue.Value
        : LowerCorrelationLimitParameter.ActualValue.Value;

      double avgTrainingQuality = qualities.Select(x => x.Value).Average();
      double avgValidationQuality = validationQualities.Select(x => x.Value).Average();
      // Ratio-based difference oriented so that a negative value means the validation quality is worse
      // than the training quality (for positive quality values).
      double relativeValidationQuality = maximization
        ? avgValidationQuality / avgTrainingQuality - 1
        : avgTrainingQuality / avgValidationQuality - 1;
      RelativeValidationQualityParameter.ActualValue = new PercentValue(relativeValidationQuality);

      double[] trainingArr, validationArr;
      SelectBestQualityPairs(qualities, validationQualities, maximization, PercentileParameter.ActualValue.Value,
                             out trainingArr, out validationArr);
      int n = trainingArr.Length;

      double r = alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, n);
      TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

      // The first analyzed generation is the reference point for training progress.
      if (InitialTrainingQualityParameter.ActualValue == null)
        InitialTrainingQualityParameter.ActualValue = new DoubleValue(avgTrainingQuality);
      double initialTrainingQuality = InitialTrainingQualityParameter.ActualValue.Value;
      bool trainingImproved = maximization
        ? avgTrainingQuality > initialTrainingQuality
        : avgTrainingQuality < initialTrainingQuality;
      bool overfitting = trainingImproved && r < correlationLimit;
      OverfittingParameter.ActualValue = new BoolValue(overfitting);

      // Column 0: training quality, column 1: validation quality, rows ordered best-first.
      double[,] qualitiesArr = new double[n, 2];
      for (int i = 0; i < n; i++) {
        qualitiesArr[i, 0] = trainingArr[i];
        qualitiesArr[i, 1] = validationArr[i];
      }
      ItemList<DoubleMatrix> history = TrainingAndValidationQualitiesParameter.ActualValue;
      if (history == null) {
        history = new ItemList<DoubleMatrix>();
        TrainingAndValidationQualitiesParameter.ActualValue = history;
      }
      history.Add(new DoubleMatrix(qualitiesArr));
      return base.Apply();
    }

    /// <summary>
    /// Pairs training and validation qualities by solution index, discards solutions whose training quality
    /// is not strictly positive, orders the remaining pairs best-first by training quality and keeps the
    /// first round(<paramref name="percentile"/> * count) pairs, clamped to [0, count].
    /// </summary>
    /// <param name="qualities">Training qualities of the solutions.</param>
    /// <param name="validationQualities">Validation qualities, index-aligned with <paramref name="qualities"/>.</param>
    /// <param name="maximization">True if larger qualities are better.</param>
    /// <param name="percentile">Fraction of the best pairs to keep.</param>
    /// <param name="training">Receives the selected training qualities, best first.</param>
    /// <param name="validation">Receives the validation qualities matching <paramref name="training"/>.</param>
    private static void SelectBestQualityPairs(ItemArray<DoubleValue> qualities, ItemArray<DoubleValue> validationQualities,
                                               bool maximization, double percentile,
                                               out double[] training, out double[] validation) {
      // NOTE(review): the "> 0.0" filter presumably removes invalid or degenerate solutions; confirm it is
      // appropriate for the evaluator in use.
      var pairs = from index in Enumerable.Range(0, qualities.Length)
                  where qualities[index].Value > 0.0
                  select new { Training = qualities[index].Value, Validation = validationQualities[index].Value };
      // Best first: descending for maximization, ascending for minimization.
      var ordered = (maximization
                      ? pairs.OrderByDescending(x => x.Training)
                      : pairs.OrderBy(x => x.Training)).ToList();

      // Clamp so that percentile values outside [0, 1] can neither index past the available pairs
      // nor produce a negative array length.
      int n = (int)Math.Round(percentile * ordered.Count);
      n = Math.Max(0, Math.Min(n, ordered.Count));

      training = new double[n];
      validation = new double[n];
      for (int i = 0; i < n; i++) {
        training[i] = ordered[i].Training;
        validation[i] = ordered[i].Validation;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.