Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionOverfittingAnalyzer.cs @ 5192

Last change on this file since 5192 was 5192, checked in by gkronber, 13 years ago

Copied overfitting analyzer for symbolic regression from feature exploration branch. #1356

File size: 14.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis.Evaluators;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using System;
36
37namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
/// <summary>
/// Analyzer that evaluates every symbolic expression tree of the current population on a random
/// sample of the validation partition, computes the Spearman rank correlation between the training
/// qualities and the resulting validation qualities, records that correlation in a data table and
/// derives a boolean overfitting flag from it using a lower and an upper correlation threshold.
/// </summary>
[Item("SymbolicRegressionOverfittingAnalyzer", "Calculates and tracks correlation of training and validation fitness of symbolic regression models.")]
[StorableClass]
public sealed class SymbolicRegressionOverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
  // Names under which the parameters are registered in the parameter collection.
  private const string RandomParameterName = "Random";
  private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
  private const string MaximizationParameterName = "Maximization";
  private const string QualityParameterName = "Quality";
  private const string ValidationQualityParameterName = "ValidationQuality";
  private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation";
  private const string TrainingValidationCorrelationTableParameterName = "TrainingValidationCorrelationTable";
  private const string LowerCorrelationThresholdParameterName = "LowerCorrelationThreshold";
  private const string UpperCorrelationThresholdParameterName = "UpperCorrelationThreshold";
  private const string OverfittingParameterName = "IsOverfitting";
  private const string ResultsParameterName = "Results";
  private const string EvaluatorParameterName = "Evaluator";
  private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
  private const string ProblemDataParameterName = "ProblemData";
  private const string ValidationSamplesStartParameterName = "ValidationSamplesStart";
  private const string ValidationSamplesEndParameterName = "ValidationSamplesEnd";
  private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";
  private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
  private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";

  #region parameter properties
  /// <summary>The random generator used to draw the seed for the validation row sample.</summary>
  public ILookupParameter<IRandom> RandomParameter {
    get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
  }
  /// <summary>The symbolic expression trees to analyze.</summary>
  public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
    get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
  }
  /// <summary>The training fitness values of the trees.</summary>
  public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
    get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
  }
  // NOTE(review): the constructor of this class does not register a parameter named
  // "ValidationQuality", so this getter presumably fails unless the parameter is added
  // elsewhere - confirm before using it.
  public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter {
    get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[ValidationQualityParameterName]; }
  }
  /// <summary>The direction of optimization.</summary>
  public ILookupParameter<BoolValue> MaximizationParameter {
    get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
  }
  /// <summary>The interpreter passed to the evaluator for each tree.</summary>
  public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
    get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
  }
  /// <summary>The evaluator used to evaluate each tree on the validation rows.</summary>
  public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
    get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
  }
  /// <summary>The problem data (dataset, target variable, test partition bounds).</summary>
  public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
    get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
  }
  /// <summary>The first index of the validation partition of the data set.</summary>
  public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
    get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
  }
  /// <summary>The last index of the validation partition of the data set.</summary>
  public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
    get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
  }
  /// <summary>The relative number of validation samples that are randomly chosen for evaluation.</summary>
  public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
    get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
  }
  /// <summary>The upper estimation limit passed to the evaluator (optional).</summary>
  public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
  }
  /// <summary>The lower estimation limit passed to the evaluator (optional).</summary>
  public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
  }
  /// <summary>The most recently calculated training/validation correlation (written by <see cref="Apply"/>).</summary>
  public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
    get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
  }
  /// <summary>The data table that collects one correlation value per call of <see cref="Apply"/>.</summary>
  public ILookupParameter<DataTable> TrainingValidationQualityCorrelationTableParameter {
    get { return (ILookupParameter<DataTable>)Parameters[TrainingValidationCorrelationTableParameterName]; }
  }
  /// <summary>The threshold used while the overfitting flag is not set (default value 0.65).</summary>
  public IValueLookupParameter<DoubleValue> LowerCorrelationThresholdParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationThresholdParameterName]; }
  }
  /// <summary>The threshold used while the overfitting flag is set (default value 0.75).</summary>
  public IValueLookupParameter<DoubleValue> UpperCorrelationThresholdParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationThresholdParameterName]; }
  }
  /// <summary>The boolean overfitting indicator (read and written by <see cref="Apply"/>).</summary>
  public ILookupParameter<BoolValue> OverfittingParameter {
    get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
  }
  /// <summary>The results collection to which the correlation table is added.</summary>
  public ILookupParameter<ResultCollection> ResultsParameter {
    get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
  }
  #endregion
  #region properties
  public IRandom Random {
    get { return RandomParameter.ActualValue; }
  }
  public BoolValue Maximization {
    get { return MaximizationParameter.ActualValue; }
  }
  public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
    get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
  }
  public ISymbolicRegressionEvaluator Evaluator {
    get { return EvaluatorParameter.ActualValue; }
  }
  public DataAnalysisProblemData ProblemData {
    get { return ProblemDataParameter.ActualValue; }
  }
  /// <summary>The actual value of <see cref="ValidationSamplesStartParameter"/>.</summary>
  public IntValue ValidationSamplesStart {
    get { return ValidationSamplesStartParameter.ActualValue; }
  }
  // Misspelled name of ValidationSamplesStart; kept so that existing callers continue to compile.
  public IntValue ValidiationSamplesStart {
    get { return ValidationSamplesStart; }
  }
  public IntValue ValidationSamplesEnd {
    get { return ValidationSamplesEndParameter.ActualValue; }
  }
  public PercentValue RelativeNumberOfEvaluatedSamples {
    get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
  }

  public DoubleValue UpperEstimationLimit {
    get { return UpperEstimationLimitParameter.ActualValue; }
  }
  public DoubleValue LowerEstimationLimit {
    get { return LowerEstimationLimitParameter.ActualValue; }
  }
  #endregion

  [StorableConstructor]
  private SymbolicRegressionOverfittingAnalyzer(bool deserializing) : base(deserializing) { }
  private SymbolicRegressionOverfittingAnalyzer(SymbolicRegressionOverfittingAnalyzer original, Cloner cloner) : base(original, cloner) { }
  /// <summary>
  /// Creates a new analyzer and registers its parameters. The relative number of evaluated samples
  /// defaults to 1, the lower correlation threshold to 0.65 and the upper one to 0.75.
  /// </summary>
  public SymbolicRegressionOverfittingAnalyzer()
    : base() {
    Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
    Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "Training fitness"));
    Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));

    Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
    Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
    Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
    Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
    Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
    Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
    Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));

    Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName, "Correlation of training and validation fitnesses"));
    Parameters.Add(new LookupParameter<DataTable>(TrainingValidationCorrelationTableParameterName, "Data table of training and validation fitness correlation values over the whole run."));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationThresholdParameterName, "Lower threshold for correlation value that marks the boundary from non-overfitting to overfitting.", new DoubleValue(0.65)));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationThresholdParameterName, "Upper threshold for correlation value that marks the boundary from overfitting to non-overfitting.", new DoubleValue(0.75)));
    Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName, "Boolean indicator for overfitting."));
    Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
  }

  // Intentionally empty: nothing has to be restored after deserialization.
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicRegressionOverfittingAnalyzer(this, cloner);
  }

  /// <summary>
  /// Evaluates all trees on a random sample of the validation partition, stores the Spearman
  /// correlation of training and validation qualities, appends it to the correlation table
  /// (creating the table and its result entry on first use) and updates the overfitting flag.
  /// </summary>
  /// <returns>The operation returned by the base implementation.</returns>
  public override IOperation Apply() {
    const string correlationRowName = "Training and validation fitness correlation";

    ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
    double[] trainingArr = qualities.Select(x => x.Value).ToArray();
    double[] validationArr = new double[trainingArr.Length];

    #region calculate validation fitness
    // resolve all looked-up values once instead of once per tree / per sampled row
    DataAnalysisProblemData problemData = ProblemData;
    string targetVariable = problemData.TargetVariable.Value;
    int testStart = problemData.TestSamplesStart.Value;
    int testEnd = problemData.TestSamplesEnd.Value;

    // select a random subset of rows in the validation set
    int validationStart = ValidationSamplesStart.Value;
    int validationEnd = ValidationSamplesEnd.Value;
    int seed = Random.Next();
    int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value);
    if (count == 0) count = 1;
    // Rows that lie inside the test partition are excluded. The sample is materialized here so
    // that sampling and filtering are done once and not repeated for every tree.
    int[] rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count)
      .Where(row => row < testStart || testEnd <= row)
      .ToArray();

    // missing estimation limits mean that the estimated values are not bounded
    double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
    double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity;

    var trees = SymbolicExpressionTreeParameter.ActualValue;
    var evaluator = Evaluator;
    var interpreter = SymbolicExpressionTreeInterpreter;
    var dataset = problemData.Dataset;

    for (int i = 0; i < validationArr.Length; i++) {
      validationArr[i] = evaluator.Evaluate(interpreter, trees[i],
          lowerEstimationLimit, upperEstimationLimit,
          dataset, targetVariable,
          rows);
    }
    #endregion

    double r = alglib.spearmancorr2(trainingArr, validationArr);

    TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

    // create the table lazily and publish it in the results the first time the analyzer runs
    DataTable correlationTable = TrainingValidationQualityCorrelationTableParameter.ActualValue;
    if (correlationTable == null) {
      correlationTable = new DataTable("Training and validation fitness correlation table", "Data table of training and validation fitness correlation values over the whole run.");
      correlationTable.Rows.Add(new DataRow(correlationRowName, "Training and validation fitness correlation values"));
      TrainingValidationQualityCorrelationTableParameter.ActualValue = correlationTable;
      ResultsParameter.ActualValue.Add(new Result(TrainingValidationCorrelationTableParameterName, correlationTable));
    }

    correlationTable.Rows[correlationRowName].Values.Add(r);

    BoolValue wasOverfitting = OverfittingParameter.ActualValue;
    double correlationThreshold;
    if (wasOverfitting != null && wasOverfitting.Value) {
      // if is already overfitting => have to reach the upper threshold to switch back to non-overfitting state
      correlationThreshold = UpperCorrelationThresholdParameter.ActualValue.Value;
    } else {
      // if currently in non-overfitting state => have to fall below the lower threshold to switch to overfitting state
      correlationThreshold = LowerCorrelationThresholdParameter.ActualValue.Value;
    }
    bool overfitting = r < correlationThreshold;

    OverfittingParameter.ActualValue = new BoolValue(overfitting);

    return base.Apply();
  }
}
251}
Note: See TracBrowser for help on using the repository browser.