Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3057_DynamicALPS/TestProblems/oesr-alps-master/HeuristicLab.Algorithms.OESRALPS/Analyzers/SymbolicDataAnalysisSingleObjectiveOverfittingSlidingWindowAnalyzer.cs @ 17479

Last change on this file since 17479 was 17479, checked in by kyang, 4 years ago

#3057

  1. upload the latest version of ALPS with SMS-EMOA
  2. upload the related dynamic test problems (dynamic, single-objective symbolic regression), written by David Daninel.
File size: 7.4 KB
Line 
1using HEAL.Attic;
2using HeuristicLab.Common;
3using HeuristicLab.Core;
4using HeuristicLab.Data;
5using HeuristicLab.Parameters;
6using HeuristicLab.Problems.DataAnalysis;
7using HeuristicLab.Problems.DataAnalysis.Symbolic;
8using System;
9using System.Collections.Generic;
10using HeuristicLab.Analysis;
11using System.Linq;
12using System.Text;
13using System.Threading.Tasks;
14using HeuristicLab.Optimization;
15
16namespace HeuristicLab.Algorithms.OESRALPS.Analyzers
17{
18    [Item("SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer", "Calculates and tracks correlation of training and validation fitness of symbolic regression models.")]
19    [StorableType("AE1F0B73-BEB1-47AF-8ECF-DBCFA32AA5B9")]
20    public abstract class SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer<T, U>
21        : SymbolicDataAnalysisSingleObjectiveLayerValidationAnalyzer<T, U>
22        where T : class, ISymbolicDataAnalysisSingleObjectiveEvaluator<U>
23        where U : class, IDataAnalysisProblemData
24    {
25        private const string TrainingValidationCorrelationParameterName = "Training and validation fitness correlation";
26        private const string TrainingValidationCorrelationTableParameterName = "Training and validation fitness correlation table";
27        private const string LowerCorrelationThresholdParameterName = "LowerCorrelationThreshold";
28        private const string UpperCorrelationThresholdParameterName = "UpperCorrelationThreshold";
29        private const string OverfittingParameterName = "IsOverfitting";
30
31        #region parameter properties
32        public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
33            get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
34        }
35        public ILookupParameter<DataTable> TrainingValidationQualityCorrelationTableParameter {
36            get { return (ILookupParameter<DataTable>)Parameters[TrainingValidationCorrelationTableParameterName]; }
37        }
38        public IValueLookupParameter<DoubleValue> LowerCorrelationThresholdParameter {
39            get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationThresholdParameterName]; }
40        }
41        public IValueLookupParameter<DoubleValue> UpperCorrelationThresholdParameter {
42            get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationThresholdParameterName]; }
43        }
44        public ILookupParameter<BoolValue> OverfittingParameter {
45            get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
46        }
47        #endregion
48
49        [StorableConstructor]
50        protected SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer(StorableConstructorFlag _) : base(_) { }
51        protected SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer(SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer<T, U> original, Cloner cloner) : base(original, cloner) { }
52        public SymbolicDataAnalysisSingleObjectiveOverfittingAnalyzer()
53          : base()
54        {
55            Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName, "Correlation of training and validation fitnesses"));
56            Parameters.Add(new LookupParameter<DataTable>(TrainingValidationCorrelationTableParameterName, "Data table of training and validation fitness correlation values over the whole run."));
57            Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationThresholdParameterName, "Lower threshold for correlation value that marks the boundary from non-overfitting to overfitting.", new DoubleValue(0.65)));
58            Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationThresholdParameterName, "Upper threshold for correlation value that marks the boundary from overfitting to non-overfitting.", new DoubleValue(0.75)));
59            Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName, "Boolean indicator for overfitting."));
60        }
61
62        public override IOperation Apply()
63        {
64            IEnumerable<int> rows = GenerateRowsToEvaluate();
65            if (!rows.Any()) return base.Apply();
66
67            double[] trainingQuality = QualityParameter.ActualValue.Select(x => x.Value).ToArray();
68            var problemData = ProblemDataParameter.ActualValue;
69            var evaluator = EvaluatorParameter.ActualValue;
70
71            // evaluate on validation partition
72            IExecutionContext childContext = (IExecutionContext)ExecutionContext.CreateChildOperation(evaluator);
73            double[] validationQuality = SymbolicExpressionTree
74              .Select(t => evaluator.Evaluate(childContext, t, problemData, rows))
75              .ToArray();
76
77            double r = 0.0;
78            try
79            {
80                r = alglib.spearmancorr2(trainingQuality, validationQuality);
81            }
82            catch (alglib.alglibexception)
83            {
84                r = 0.0;
85            }
86
87            var results = ResultCollection;
88            #region Add Parameters
89            if (!results.ContainsKey(TrainingValidationQualityCorrelationTableParameter.Name))
90                ResultCollectionParameter.ActualValue.Add(new Result(TrainingValidationQualityCorrelationTableParameter.Name, TrainingValidationQualityCorrelationTableParameter.Description, typeof(DataTable)));
91            if (!results.ContainsKey(OverfittingParameter.Name))
92                results.Add(new Result(OverfittingParameter.Name, OverfittingParameter.Description, typeof(BoolValue)));
93            #endregion
94
95            TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);
96
97            if (TrainingValidationQualityCorrelationTableParameter.ActualValue == null)
98            {
99                var dataTable = new DataTable(TrainingValidationQualityCorrelationTableParameter.Name, TrainingValidationQualityCorrelationTableParameter.Description);
100                dataTable.Rows.Add(new DataRow(TrainingValidationQualityCorrelationParameter.Name, TrainingValidationQualityCorrelationParameter.Description));
101                dataTable.Rows[TrainingValidationQualityCorrelationParameter.Name].VisualProperties.StartIndexZero = true;
102                TrainingValidationQualityCorrelationTableParameter.ActualValue = dataTable;
103            }
104
105            TrainingValidationQualityCorrelationTableParameter.ActualValue.Rows[TrainingValidationQualityCorrelationParameter.Name].Values.Add(r);
106
107            if (OverfittingParameter.ActualValue != null && OverfittingParameter.ActualValue.Value)
108            {
109                // overfitting == true
110                // => r must reach the upper threshold to switch back to non-overfitting state
111                OverfittingParameter.ActualValue = new BoolValue(r < UpperCorrelationThresholdParameter.ActualValue.Value);
112            }
113            else
114            {
115                // overfitting == false
116                // => r must drop below lower threshold to switch to overfitting state
117                OverfittingParameter.ActualValue = new BoolValue(r < LowerCorrelationThresholdParameter.ActualValue.Value);
118            }
119
120            results[TrainingValidationQualityCorrelationTableParameter.Name].Value = TrainingValidationQualityCorrelationTableParameter.ActualValue;
121            results[OverfittingParameter.Name].Value = OverfittingParameter.ActualValue;
122
123            return base.Apply();
124        }
125    }
126}
Note: See TracBrowser for help on using the repository browser.