#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Analysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis.Evaluators;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using System;

namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that flags a symbolic regression run as overfitting.
  /// It compares the training and validation qualities stored in the sub-scopes
  /// ("Quality" / "ValidationQuality").
  /// The flag is set when the average training quality has improved over the first analyzed
  /// generation AND the Spearman rank correlation between training and validation quality of the
  /// best individuals drops below a hysteresis-controlled limit.
  /// </summary>
  [Item("OverfittingAnalyzer", "")]
  [StorableClass]
  public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string ProblemDataParameterName = "ProblemData";
    private const string ValidationSamplesStartParameterName = "SamplesStart";
    private const string ValidationSamplesEndParameterName = "SamplesEnd";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    /// <summary>Training quality of each individual, looked up in the sub-scopes.</summary>
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["Quality"]; }
    }
    /// <summary>Validation quality of each individual, looked up in the sub-scopes.</summary>
    public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["ValidationQuality"]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
    }
    public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
      get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    /// <summary>
    /// Output: relative gap between average validation and training quality.
    /// Negative values mean validation is worse than training.
    /// </summary>
    public ILookupParameter<PercentValue> RelativeValidationQualityParameter {
      get { return (ILookupParameter<PercentValue>)Parameters["RelativeValidationQuality"]; }
    }
    /// <summary>Output: Spearman rank correlation of training vs. validation quality.</summary>
    public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["TrainingValidationCorrelation"]; }
    }
    /// <summary>Correlation threshold for switching from non-overfitting to overfitting.</summary>
    public IValueLookupParameter<DoubleValue> LowerCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["LowerCorrelationLimit"]; }
    }
    /// <summary>Correlation threshold for switching from overfitting back to non-overfitting.</summary>
    public IValueLookupParameter<DoubleValue> UpperCorrelationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["UpperCorrelationLimit"]; }
    }
    /// <summary>Output (and persisted state): whether the run is currently considered overfitting.</summary>
    public ILookupParameter<BoolValue> OverfittingParameter {
      get { return (ILookupParameter<BoolValue>)Parameters["Overfitting"]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters["Results"]; }
    }
    /// <summary>Average training quality recorded on the first call; used as the improvement reference.</summary>
    public ILookupParameter<DoubleValue> InitialTrainingQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters["InitialTrainingQuality"]; }
    }
    /// <summary>History of (training, validation) quality matrices, one entry appended per call.</summary>
    public ILookupParameter<ItemList<DoubleMatrix>> TrainingAndValidationQualitiesParameter {
      get { return (ILookupParameter<ItemList<DoubleMatrix>>)Parameters["TrainingAndValidationQualities"]; }
    }
    /// <summary>Fraction of the best (by training quality) individuals that enter the correlation.</summary>
    public IValueLookupParameter<DoubleValue> PercentileParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters["Percentile"]; }
    }
    #endregion

    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DataAnalysisProblemData ProblemData {
      get { return ProblemDataParameter.ActualValue; }
    }
    // Misspelled name kept for backward compatibility with existing callers.
    public IntValue ValidiationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    // Correctly spelled alias of ValidiationSamplesStart.
    public IntValue ValidationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    public IntValue ValidationSamplesEnd {
      get { return ValidationSamplesEndParameter.ActualValue; }
    }
    public PercentValue RelativeNumberOfEvaluatedSamples {
      get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    #endregion

    public OverfittingAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality"));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
      Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new LookupParameter<PercentValue>("RelativeValidationQuality"));
      Parameters.Add(new LookupParameter<DoubleValue>("TrainingValidationCorrelation"));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75)));
      Parameters.Add(new LookupParameter<BoolValue>("Overfitting"));
      Parameters.Add(new LookupParameter<ResultCollection>("Results"));
      Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality"));
      Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities"));
      Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1)));
    }

    [StorableConstructor]
    private OverfittingAnalyzer(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Adds parameters that were introduced after older versions of this operator were persisted,
    /// so that deserialized instances have the same parameter set as freshly constructed ones.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey("InitialTrainingQuality")) {
        Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality"));
      }
      if (!Parameters.ContainsKey("TrainingAndValidationQualities")) {
        Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities"));
      }
      if (!Parameters.ContainsKey("Percentile")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1)));
      }
      if (!Parameters.ContainsKey("ValidationQuality")) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality"));
      }
      if (!Parameters.ContainsKey("LowerCorrelationLimit")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65)));
      }
      if (!Parameters.ContainsKey("UpperCorrelationLimit")) {
        Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75)));
      }
    }

    /// <summary>
    /// Computes the relative validation quality and the training/validation rank correlation.
    /// Updates the overfitting flag and appends the (training, validation) quality matrix to the history.
    /// </summary>
    public override IOperation Apply() {
      ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
      ItemArray<DoubleValue> validationQualities = ValidationQualityParameter.ActualValue;

      // Hysteresis: while overfitting, the correlation must climb above the upper limit to leave that
      // state; while not overfitting, it must fall below the lower limit to enter it.
      BoolValue currentlyOverfitting = OverfittingParameter.ActualValue;
      double correlationLimit = (currentlyOverfitting != null && currentlyOverfitting.Value)
        ? UpperCorrelationLimitParameter.ActualValue.Value
        : LowerCorrelationLimitParameter.ActualValue.Value;

      double avgTrainingQuality = qualities.Select(x => x.Value).Average();
      double avgValidationQuality = validationQualities.Select(x => x.Value).Average();
      // Relative gap oriented so that a negative value means "validation worse than training"
      // for both optimization directions.
      if (Maximization.Value) {
        RelativeValidationQualityParameter.ActualValue = new PercentValue(avgValidationQuality / avgTrainingQuality - 1);
      } else {
        RelativeValidationQualityParameter.ActualValue = new PercentValue(avgTrainingQuality / avgValidationQuality - 1);
      }

      // (training, validation) pairs of individuals with positive training quality, best training first.
      // NOTE(review): the positive-quality filter and descending order assume maximization.
      var pairs = (from index in Enumerable.Range(0, qualities.Length)
                   where qualities[index].Value > 0.0
                   select new { Training = qualities[index].Value, Validation = validationQualities[index].Value })
                  .OrderByDescending(x => x.Training)
                  .ToList();

      // Only the best `Percentile` fraction enters the correlation. Clamp so that a percentile outside
      // [0, 1] can neither index past the list nor request a negative array length.
      int n = (int)Math.Round(PercentileParameter.ActualValue.Value * pairs.Count);
      n = Math.Max(0, Math.Min(n, pairs.Count));

      double[] trainingArr = new double[n];
      double[] validationArr = new double[n];
      double[,] qualitiesArr = new double[n, 2];
      for (int i = 0; i < n; i++) {
        trainingArr[i] = pairs[i].Training;
        validationArr[i] = pairs[i].Validation;
        qualitiesArr[i, 0] = trainingArr[i];
        qualitiesArr[i, 1] = validationArr[i];
      }

      double r = alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, n);
      TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r);

      // The reference for "training has improved" is the average TRAINING quality of the first
      // analyzed generation (previously this was seeded with the validation average by mistake).
      if (InitialTrainingQualityParameter.ActualValue == null)
        InitialTrainingQualityParameter.ActualValue = new DoubleValue(avgTrainingQuality);

      bool overfitting =
        avgTrainingQuality > InitialTrainingQualityParameter.ActualValue.Value && // better on training than in initial generation
        r < correlationLimit;                                                    // training and validation ranks diverge
      OverfittingParameter.ActualValue = new BoolValue(overfitting);

      ItemList<DoubleMatrix> history = TrainingAndValidationQualitiesParameter.ActualValue;
      if (history == null) {
        history = new ItemList<DoubleMatrix>();
        TrainingAndValidationQualitiesParameter.ActualValue = history;
      }
      history.Add(new DoubleMatrix(qualitiesArr));

      return base.Apply();
    }
  }
}