Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionOverfittingAnalyzer.cs @ 5304

Last change on this file since 5304 was 5197, checked in by gkronber, 14 years ago

Introduced base class for operators that evaluate symbolic regression models on a validation set. #1356

File size: 7.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Evaluators;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using System;

namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that computes the Spearman rank correlation between the training
  /// qualities and the validation qualities of the analyzed symbolic regression
  /// models, records that correlation in a data table, and derives a boolean
  /// overfitting indicator from it.
  /// </summary>
  /// <remarks>
  /// The overfitting indicator uses two thresholds (hysteresis): while the indicator
  /// is not set, the correlation must drop below the lower threshold to set it; once
  /// it is set, the correlation must reach the upper threshold to clear it again.
  /// </remarks>
  [Item("SymbolicRegressionOverfittingAnalyzer", "Calculates and tracks correlation of training and validation fitness of symbolic regression models.")]
  [StorableClass]
  public sealed class SymbolicRegressionOverfittingAnalyzer : SymbolicRegressionValidationAnalyzer, ISymbolicRegressionAnalyzer {
    // Names under which the parameters are registered in the Parameters collection.
    private const string MaximizationParameterName = "Maximization";
    private const string QualityParameterName = "Quality";
    private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation";
    private const string TrainingValidationCorrelationTableParameterName = "TrainingValidationCorrelationTable";
    private const string LowerCorrelationThresholdParameterName = "LowerCorrelationThreshold";
    private const string UpperCorrelationThresholdParameterName = "UpperCorrelationThreshold";
    private const string OverfittingParameterName = "IsOverfitting";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    /// <summary>Training fitness values, looked up across the scope tree.</summary>
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    /// <summary>The direction of optimization.</summary>
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    /// <summary>Most recently computed correlation of training and validation fitness.</summary>
    public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
    }
    /// <summary>Data table collecting the correlation values over the whole run.</summary>
    public ILookupParameter<DataTable> TrainingValidationQualityCorrelationTableParameter {
      get { return (ILookupParameter<DataTable>)Parameters[TrainingValidationCorrelationTableParameterName]; }
    }
    /// <summary>Threshold below which a non-overfitting run is flagged as overfitting (default 0.65).</summary>
    public IValueLookupParameter<DoubleValue> LowerCorrelationThresholdParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationThresholdParameterName]; }
    }
    /// <summary>Threshold an overfitting run has to reach to be flagged as non-overfitting again (default 0.75).</summary>
    public IValueLookupParameter<DoubleValue> UpperCorrelationThresholdParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationThresholdParameterName]; }
    }
    /// <summary>Boolean indicator for overfitting; read and rewritten on every analysis.</summary>
    public ILookupParameter<BoolValue> OverfittingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
    }
    /// <summary>The results collection that receives the correlation data table.</summary>
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Actual value of <see cref="MaximizationParameter"/>.</summary>
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    private SymbolicRegressionOverfittingAnalyzer(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; all copying is delegated to the base class.</summary>
    private SymbolicRegressionOverfittingAnalyzer(SymbolicRegressionOverfittingAnalyzer original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the analyzer and registers its parameters. The lower and upper
    /// correlation thresholds default to 0.65 and 0.75 respectively.
    /// </summary>
    public SymbolicRegressionOverfittingAnalyzer() {
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "Training fitness"));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName, "Correlation of training and validation fitnesses"));
      Parameters.Add(new LookupParameter<DataTable>(TrainingValidationCorrelationTableParameterName, "Data table of training and validation fitness correlation values over the whole run."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationThresholdParameterName, "Lower threshold for correlation value that marks the boundary from non-overfitting to overfitting.", new DoubleValue(0.65)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationThresholdParameterName, "Upper threshold for correlation value that marks the boundary from overfitting to non-overfitting.", new DoubleValue(0.75)));
      Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName, "Boolean indicator for overfitting."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>Post-deserialization hook; currently performs no work.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    /// <summary>Creates a copy of this analyzer via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionOverfittingAnalyzer(this, cloner);
    }

    /// <summary>
    /// Correlates training and validation quality, appends the correlation to the
    /// tracking table (creating and publishing the table on first use) and updates
    /// the overfitting indicator.
    /// </summary>
    /// <param name="trees">The analyzed trees; not used by this method.</param>
    /// <param name="validationQuality">
    /// Validation quality values that are correlated with the training qualities.
    /// NOTE(review): presumably supplied by the base class in the same order as the
    /// training qualities - confirm against SymbolicRegressionValidationAnalyzer.
    /// </param>
    protected override void Analyze(SymbolicExpressionTree[] trees, double[] validationQuality) {
      // The same row name is used when the row is created and on every later lookup.
      const string correlationRowName = "Training and validation fitness correlation";

      double[] trainingQuality = (from quality in QualityParameter.ActualValue
                                  select quality.Value).ToArray();

      double correlation = alglib.spearmancorr2(trainingQuality, validationQuality);
      TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(correlation);

      // Create the table lazily and publish it in the results the first time around.
      DataTable correlationTable = TrainingValidationQualityCorrelationTableParameter.ActualValue;
      if (correlationTable == null) {
        correlationTable = new DataTable("Training and validation fitness correlation table", "Data table of training and validation fitness correlation values over the whole run.");
        correlationTable.Rows.Add(new DataRow(correlationRowName, "Training and validation fitness correlation values"));
        TrainingValidationQualityCorrelationTableParameter.ActualValue = correlationTable;
        ResultsParameter.ActualValue.Add(new Result(TrainingValidationCorrelationTableParameterName, correlationTable));
      }
      correlationTable.Rows[correlationRowName].Values.Add(correlation);

      // Hysteresis: an already overfitting run has to reach the upper threshold to
      // switch back, a non-overfitting run has to fall below the lower threshold.
      BoolValue previousState = OverfittingParameter.ActualValue;
      bool wasOverfitting = previousState != null && previousState.Value;
      double correlationThreshold = wasOverfitting
        ? UpperCorrelationThresholdParameter.ActualValue.Value
        : LowerCorrelationThresholdParameter.ActualValue.Value;

      OverfittingParameter.ActualValue = new BoolValue(correlation < correlationThreshold);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.