#region License Information
/* HeuristicLab
* Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Evaluators;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using System;
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
/// <summary>
/// Population-level overfitting detector for symbolic regression.
/// </summary>
/// <remarks>
/// Each call of <see cref="Apply"/>:
/// (1) stores the relative difference between the average validation and the average training
///     quality in "RelativeValidationQuality" (direction-aware via "Maximization");
/// (2) computes Spearman's rank correlation between training and validation quality over the best
///     "Percentile" fraction of the solutions whose training quality is positive, and stores it in
///     "TrainingValidationCorrelation";
/// (3) sets "Overfitting" when the average training quality has improved over the value recorded on
///     the first call ("InitialTrainingQuality") AND the correlation is below a limit. The limit has
///     hysteresis: while not overfitting the correlation must drop below "LowerCorrelationLimit" to
///     switch state; once overfitting it must reach "UpperCorrelationLimit" to switch back;
/// (4) appends the (training, validation) pairs used for the correlation as an n x 2 matrix to
///     "TrainingAndValidationQualities".
/// Several parameters (Random, SymbolicExpressionTree, SymbolicExpressionTreeInterpreter, Evaluator,
/// ProblemData, SamplesStart/End, RelativeNumberOfEvaluatedSamples, estimation limits, Results) are
/// declared but not read by <see cref="Apply"/>; they are kept so existing operator graphs and
/// persisted instances keep wiring up.
/// </remarks>
[Item("OverfittingAnalyzer", "")]
[StorableClass]
public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
  private const string RandomParameterName = "Random";
  private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
  private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
  private const string ProblemDataParameterName = "ProblemData";
  private const string ValidationSamplesStartParameterName = "SamplesStart";
  private const string ValidationSamplesEndParameterName = "SamplesEnd";
  private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
  private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
  private const string EvaluatorParameterName = "Evaluator";
  private const string MaximizationParameterName = "Maximization";
  private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";
  private const string QualityParameterName = "Quality";
  private const string ValidationQualityParameterName = "ValidationQuality";
  private const string RelativeValidationQualityParameterName = "RelativeValidationQuality";
  private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation";
  private const string LowerCorrelationLimitParameterName = "LowerCorrelationLimit";
  private const string UpperCorrelationLimitParameterName = "UpperCorrelationLimit";
  private const string OverfittingParameterName = "Overfitting";
  private const string ResultsParameterName = "Results";
  private const string InitialTrainingQualityParameterName = "InitialTrainingQuality";
  private const string TrainingAndValidationQualitiesParameterName = "TrainingAndValidationQualities";
  private const string PercentileParameterName = "Percentile";

  #region parameter properties
  public ILookupParameter<IRandom> RandomParameter {
    get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
  }
  public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
    get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
  }
  /// <summary>Training quality of every solution in the population.</summary>
  public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
    get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
  }
  /// <summary>Validation quality of every solution, index-aligned with <see cref="QualityParameter"/>.</summary>
  public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter {
    get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[ValidationQualityParameterName]; }
  }
  public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
    get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
  }
  public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
    get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
  }
  public ILookupParameter<BoolValue> MaximizationParameter {
    get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
  }
  public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
    get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
  }
  public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
    get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
  }
  public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
    get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
  }
  public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
    get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
  }
  public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
  }
  public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
  }
  /// <summary>Output: relative difference of average validation vs. average training quality.</summary>
  public ILookupParameter<PercentValue> RelativeValidationQualityParameter {
    get { return (ILookupParameter<PercentValue>)Parameters[RelativeValidationQualityParameterName]; }
  }
  /// <summary>Output: Spearman rank correlation between training and validation quality.</summary>
  public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
    get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; }
  }
  /// <summary>Correlation threshold used while the previous state was "not overfitting".</summary>
  public IValueLookupParameter<DoubleValue> LowerCorrelationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationLimitParameterName]; }
  }
  /// <summary>Correlation threshold used while the previous state was "overfitting".</summary>
  public IValueLookupParameter<DoubleValue> UpperCorrelationLimitParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationLimitParameterName]; }
  }
  /// <summary>Input/output: overfitting state (read as the previous state, then overwritten).</summary>
  public ILookupParameter<BoolValue> OverfittingParameter {
    get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; }
  }
  public ILookupParameter<ResultCollection> ResultsParameter {
    get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
  }
  /// <summary>Average training quality recorded on the first call of <see cref="Apply"/>.</summary>
  public ILookupParameter<DoubleValue> InitialTrainingQualityParameter {
    get { return (ILookupParameter<DoubleValue>)Parameters[InitialTrainingQualityParameterName]; }
  }
  /// <summary>History of (training, validation) quality matrices, one appended per call.</summary>
  public ILookupParameter<ItemList<DoubleMatrix>> TrainingAndValidationQualitiesParameter {
    get { return (ILookupParameter<ItemList<DoubleMatrix>>)Parameters[TrainingAndValidationQualitiesParameterName]; }
  }
  /// <summary>Fraction (best first) of the eligible solutions used for the correlation.</summary>
  public IValueLookupParameter<DoubleValue> PercentileParameter {
    get { return (IValueLookupParameter<DoubleValue>)Parameters[PercentileParameterName]; }
  }
  #endregion

  #region properties
  public IRandom Random {
    get { return RandomParameter.ActualValue; }
  }
  public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
    get { return SymbolicExpressionTreeParameter.ActualValue; }
  }
  public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
    get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
  }
  public ISymbolicRegressionEvaluator Evaluator {
    get { return EvaluatorParameter.ActualValue; }
  }
  public BoolValue Maximization {
    get { return MaximizationParameter.ActualValue; }
  }
  public DataAnalysisProblemData ProblemData {
    get { return ProblemDataParameter.ActualValue; }
  }
  // NOTE: the misspelled name is part of the public API and is kept for compatibility.
  public IntValue ValidiationSamplesStart {
    get { return ValidationSamplesStartParameter.ActualValue; }
  }
  public IntValue ValidationSamplesEnd {
    get { return ValidationSamplesEndParameter.ActualValue; }
  }
  public PercentValue RelativeNumberOfEvaluatedSamples {
    get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
  }
  public DoubleValue UpperEstimationLimit {
    get { return UpperEstimationLimitParameter.ActualValue; }
  }
  public DoubleValue LowerEstimationLimit {
    get { return LowerEstimationLimitParameter.ActualValue; }
  }
  #endregion

  public OverfittingAnalyzer()
    : base() {
    Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
    Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
    Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
    Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
    Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(ValidationQualityParameterName));
    Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
    Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
    Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
    Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
    Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
    Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
    Parameters.Add(new LookupParameter<PercentValue>(RelativeValidationQualityParameterName));
    Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationLimitParameterName, new DoubleValue(0.65)));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationLimitParameterName, new DoubleValue(0.75)));
    Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName));
    Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName));
    Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
    Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>(TrainingAndValidationQualitiesParameterName));
    Parameters.Add(new ValueLookupParameter<DoubleValue>(PercentileParameterName, new DoubleValue(1)));
  }

  [StorableConstructor]
  private OverfittingAnalyzer(bool deserializing) : base(deserializing) { }

  /// <summary>
  /// Adds parameters that were introduced after older instances were persisted, using the same
  /// defaults as the public constructor.
  /// </summary>
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    if (!Parameters.ContainsKey(InitialTrainingQualityParameterName)) {
      Parameters.Add(new LookupParameter<DoubleValue>(InitialTrainingQualityParameterName));
    }
    if (!Parameters.ContainsKey(TrainingAndValidationQualitiesParameterName)) {
      Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>(TrainingAndValidationQualitiesParameterName));
    }
    if (!Parameters.ContainsKey(PercentileParameterName)) {
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PercentileParameterName, new DoubleValue(1)));
    }
    if (!Parameters.ContainsKey(ValidationQualityParameterName)) {
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(ValidationQualityParameterName));
    }
    if (!Parameters.ContainsKey(LowerCorrelationLimitParameterName)) {
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationLimitParameterName, new DoubleValue(0.65)));
    }
    if (!Parameters.ContainsKey(UpperCorrelationLimitParameterName)) {
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationLimitParameterName, new DoubleValue(0.75)));
    }
  }

  /// <summary>
  /// Computes the training/validation statistics of the current population and updates the
  /// overfitting state. See the class remarks for the exact outputs.
  /// </summary>
  /// <remarks>
  /// Requires non-empty Quality/ValidationQuality arrays (LINQ <c>Average</c> throws on empty input).
  /// </remarks>
  public override IOperation Apply() {
    ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
    ItemArray<DoubleValue> validationQualities = ValidationQualityParameter.ActualValue;

    // Hysteresis: once overfitting, the correlation must climb to the upper limit to switch back;
    // otherwise it must fall below the lower limit to switch to the overfitting state.
    BoolValue previousOverfitting = OverfittingParameter.ActualValue;
    double correlationLimit = (previousOverfitting != null && previousOverfitting.Value)
      ? UpperCorrelationLimitParameter.ActualValue.Value
      : LowerCorrelationLimitParameter.ActualValue.Value;

    double avgTrainingQuality = qualities.Select(x => x.Value).Average();
    double avgValidationQuality = validationQualities.Select(x => x.Value).Average();
    bool maximization = Maximization.Value;

    // Positive value => validation is better than training, in the direction of optimization.
    RelativeValidationQualityParameter.ActualValue = maximization
      ? new PercentValue(avgValidationQuality / avgTrainingQuality - 1)
      : new PercentValue(avgTrainingQuality / avgValidationQuality - 1);

    // Pair up training/validation quality per solution, skipping solutions whose training quality is
    // not positive (also drops NaN), and sort best-first so that Percentile keeps the best solutions.
    var pairs = Enumerable.Range(0, qualities.Length)
      .Where(i => qualities[i].Value > 0.0)
      .Select(i => new { Training = qualities[i].Value, Validation = validationQualities[i].Value });
    var orderedPairs = (maximization
        ? pairs.OrderByDescending(p => p.Training)
        : pairs.OrderBy(p => p.Training))
      .ToList();

    // Clamp so that a Percentile outside [0, 1] cannot index past the available pairs.
    int n = (int)Math.Round(PercentileParameter.ActualValue.Value * orderedPairs.Count);
    n = Math.Max(0, Math.Min(n, orderedPairs.Count));

    double[] trainingArr = new double[n];
    double[] validationArr = new double[n];
    double[,] qualitiesArr = new double[n, 2];
    for (int i = 0; i < n; i++) {
      trainingArr[i] = orderedPairs[i].Training;
      validationArr[i] = orderedPairs[i].Validation;
      qualitiesArr[i, 0] = trainingArr[i];
      qualitiesArr[i, 1] = validationArr[i];
    }

    double r = alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, n);
    TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

    // The baseline is the average *training* quality of the first analyzed generation.
    DoubleValue initialTrainingQuality = InitialTrainingQualityParameter.ActualValue;
    if (initialTrainingQuality == null) {
      initialTrainingQuality = new DoubleValue(avgTrainingQuality);
      InitialTrainingQualityParameter.ActualValue = initialTrainingQuality;
    }

    bool improvedOnTraining = maximization
      ? avgTrainingQuality > initialTrainingQuality.Value
      : avgTrainingQuality < initialTrainingQuality.Value;
    // Overfitting: training keeps improving while training and validation quality decorrelate.
    bool overfitting = improvedOnTraining && r < correlationLimit;
    OverfittingParameter.ActualValue = new BoolValue(overfitting);

    ItemList<DoubleMatrix> history = TrainingAndValidationQualitiesParameter.ActualValue;
    if (history == null) {
      history = new ItemList<DoubleMatrix>();
      TrainingAndValidationQualitiesParameter.ActualValue = history;
    }
    history.Add(new DoubleMatrix(qualitiesArr));

    return base.Apply();
  }
}
}