- Timestamp:
- 01/02/11 23:48:45 (14 years ago)
- File:
-
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionOverfittingAnalyzer.cs
r5188 r5192 36 36 37 37 namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers { 38 [Item(" OverfittingAnalyzer", "")]38 [Item("SymbolicRegressionOverfittingAnalyzer", "Calculates and tracks correlation of training and validation fitness of symbolic regression models.")] 39 39 [StorableClass] 40 public sealed class OverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {40 public sealed class SymbolicRegressionOverfittingAnalyzer : SingleSuccessorOperator, ISymbolicRegressionAnalyzer { 41 41 private const string RandomParameterName = "Random"; 42 42 private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree"; 43 private const string MaximizationParameterName = "Maximization"; 44 private const string QualityParameterName = "Quality"; 45 private const string ValidationQualityParameterName = "ValidationQuality"; 46 private const string TrainingValidationCorrelationParameterName = "TrainingValidationCorrelation"; 47 private const string TrainingValidationCorrelationTableParameterName = "TrainingValidationCorrelationTable"; 48 private const string LowerCorrelationThresholdParameterName = "LowerCorrelationThreshold"; 49 private const string UpperCorrelationThresholdParameterName = "UpperCorrelationThreshold"; 50 private const string OverfittingParameterName = "IsOverfitting"; 51 private const string ResultsParameterName = "Results"; 52 private const string EvaluatorParameterName = "Evaluator"; 43 53 private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter"; 44 54 private const string ProblemDataParameterName = "ProblemData"; 45 private const string ValidationSamplesStartParameterName = "SamplesStart"; 46 private const string ValidationSamplesEndParameterName = "SamplesEnd"; 55 private const string ValidationSamplesStartParameterName = "ValidationSamplesStart"; 56 private const string ValidationSamplesEndParameterName = "ValidationSamplesEnd"; 57 private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples"; 47 58 private const string UpperEstimationLimitParameterName = "UpperEstimationLimit"; 48 59 private const string LowerEstimationLimitParameterName = "LowerEstimationLimit"; 49 private const string EvaluatorParameterName = "Evaluator";50 private const string MaximizationParameterName = "Maximization";51 private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";52 60 53 61 #region parameter properties … … 59 67 } 60 68 public ScopeTreeLookupParameter<DoubleValue> QualityParameter { 61 get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[ "Quality"]; }69 get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; } 62 70 } 63 71 public ScopeTreeLookupParameter<DoubleValue> ValidationQualityParameter { 64 get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["ValidationQuality"]; } 72 get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[ValidationQualityParameterName]; } 73 } 74 public ILookupParameter<BoolValue> MaximizationParameter { 75 get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; } 65 76 } 66 77 public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter { … … 70 81 get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; } 71 82 } 72 public ILookupParameter<BoolValue> MaximizationParameter {73 get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }74 }75 83 public IValueLookupParameter<DataAnalysisProblemData> ProblemDataParameter { 76 84 get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[ProblemDataParameterName]; } … … 85 93 get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; } 86 94 } 87 88 95 public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter { 89 96 get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; } … … 92 99 get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; } 93 100 } 94 public ILookupParameter<PercentValue> RelativeValidationQualityParameter {95 get { return (ILookupParameter<PercentValue>)Parameters["RelativeValidationQuality"]; }96 }97 //public IValueLookupParameter<PercentValue> RelativeValidationQualityLowerLimitParameter {98 // get { return (IValueLookupParameter<PercentValue>)Parameters["RelativeValidationQualityLowerLimit"]; }99 //}100 //public IValueLookupParameter<PercentValue> RelativeValidationQualityUpperLimitParameter {101 // get { return (IValueLookupParameter<PercentValue>)Parameters["RelativeValidationQualityUpperLimit"]; }102 //}103 101 public ILookupParameter<DoubleValue> TrainingValidationQualityCorrelationParameter { 104 get { return (ILookupParameter<DoubleValue>)Parameters["TrainingValidationCorrelation"]; } 105 } 106 public IValueLookupParameter<DoubleValue> LowerCorrelationLimitParameter { 107 get { return (IValueLookupParameter<DoubleValue>)Parameters["LowerCorrelationLimit"]; } 108 } 109 public IValueLookupParameter<DoubleValue> UpperCorrelationLimitParameter { 110 get { return (IValueLookupParameter<DoubleValue>)Parameters["UpperCorrelationLimit"]; } 102 get { return (ILookupParameter<DoubleValue>)Parameters[TrainingValidationCorrelationParameterName]; } 103 } 104 public ILookupParameter<DataTable> TrainingValidationQualityCorrelationTableParameter { 105 get { return (ILookupParameter<DataTable>)Parameters[TrainingValidationCorrelationTableParameterName]; } 106 } 107 public IValueLookupParameter<DoubleValue> LowerCorrelationThresholdParameter { 108 get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerCorrelationThresholdParameterName]; } 109 } 110 public IValueLookupParameter<DoubleValue> UpperCorrelationThresholdParameter { 111 get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperCorrelationThresholdParameterName]; } 111 112 } 112 113 public ILookupParameter<BoolValue> OverfittingParameter { 113 get { return (ILookupParameter<BoolValue>)Parameters[ "Overfitting"]; }114 get { return (ILookupParameter<BoolValue>)Parameters[OverfittingParameterName]; } 114 115 } 115 116 public ILookupParameter<ResultCollection> ResultsParameter { 116 get { return (ILookupParameter<ResultCollection>)Parameters["Results"]; } 117 } 118 public ILookupParameter<DoubleValue> InitialTrainingQualityParameter { 119 get { return (ILookupParameter<DoubleValue>)Parameters["InitialTrainingQuality"]; } 120 } 121 public ILookupParameter<ItemList<DoubleMatrix>> TrainingAndValidationQualitiesParameter { 122 get { return (ILookupParameter<ItemList<DoubleMatrix>>)Parameters["TrainingAndValidationQualities"]; } 123 } 124 public IValueLookupParameter<DoubleValue> PercentileParameter { 125 get { return (IValueLookupParameter<DoubleValue>)Parameters["Percentile"]; } 117 get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; } 126 118 } 127 119 #endregion … … 130 122 get { return RandomParameter.ActualValue; } 131 123 } 132 public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree{133 get { return SymbolicExpressionTreeParameter.ActualValue; }124 public BoolValue Maximization { 125 get { return MaximizationParameter.ActualValue; } 134 126 } 135 127 public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter { … … 139 131 get { return EvaluatorParameter.ActualValue; } 140 132 } 141 public BoolValue Maximization {142 get { return MaximizationParameter.ActualValue; }143 }144 133 public DataAnalysisProblemData ProblemData { 145 134 get { return ProblemDataParameter.ActualValue; } … … 163 152 #endregion 164 153 165 public OverfittingAnalyzer() 154 [StorableConstructor] 155 private SymbolicRegressionOverfittingAnalyzer(bool deserializing) : base(deserializing) { } 156 private SymbolicRegressionOverfittingAnalyzer(SymbolicRegressionOverfittingAnalyzer original, Cloner cloner) : base(original, cloner) { } 157 public SymbolicRegressionOverfittingAnalyzer() 166 158 : base() { 167 Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use.")); 159 Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use.")); 160 Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "Training fitness")); 161 Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization.")); 162 163 Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze.")); 168 164 Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set.")); 169 Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));170 Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));171 Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality"));172 Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));173 165 Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees.")); 174 166 Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(ProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution.")); … … 178 170 Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees.")); 179 171 Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees.")); 180 Parameters.Add(new LookupParameter<PercentValue>("RelativeValidationQuality")); 181 //Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeValidationQualityUpperLimit", new PercentValue(0.05))); 182 //Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeValidationQualityLowerLimit", new PercentValue(-0.05))); 183 Parameters.Add(new LookupParameter<DoubleValue>("TrainingValidationCorrelation")); 184 Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65))); 185 Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75))); 186 Parameters.Add(new LookupParameter<BoolValue>("Overfitting")); 187 Parameters.Add(new LookupParameter<ResultCollection>("Results")); 188 Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality")); 189 Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities")); 190 Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1))); 191 192 } 193 194 [StorableConstructor] 195 private OverfittingAnalyzer(bool deserializing) : base(deserializing) { } 172 173 Parameters.Add(new LookupParameter<DoubleValue>(TrainingValidationCorrelationParameterName, "Correlation of training and validation fitnesses")); 174 Parameters.Add(new LookupParameter<DataTable>(TrainingValidationCorrelationTableParameterName, "Data table of training and validation fitness correlation values over the whole run.")); 175 Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerCorrelationThresholdParameterName, "Lower threshold for correlation value that marks the boundary from non-overfitting to overfitting.", new DoubleValue(0.65))); 176 Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperCorrelationThresholdParameterName, "Upper threshold for correlation value that marks the boundary from overfitting to non-overfitting.", new DoubleValue(0.75))); 177 Parameters.Add(new LookupParameter<BoolValue>(OverfittingParameterName, "Boolean indicator for overfitting.")); 178 Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection.")); 179 } 196 180 197 181 [StorableHook(HookType.AfterDeserialization)] 198 182 private void AfterDeserialization() { 199 if (!Parameters.ContainsKey("InitialTrainingQuality")) { 200 Parameters.Add(new LookupParameter<DoubleValue>("InitialTrainingQuality")); 183 } 184 185 public override IDeepCloneable Clone(Cloner cloner) { 186 return new SymbolicRegressionOverfittingAnalyzer(this, cloner); 187 } 188 189 public override IOperation Apply() { 190 ItemArray<DoubleValue> qualities = QualityParameter.ActualValue; 191 double[] trainingArr = qualities.Select(x => x.Value).ToArray(); 192 double[] validationArr = new double[trainingArr.Length]; 193 194 #region calculate validation fitness 195 string targetVariable = ProblemData.TargetVariable.Value; 196 197 // select a random subset of rows in the validation set 198 int validationStart = ValidiationSamplesStart.Value; 199 int validationEnd = ValidationSamplesEnd.Value; 200 int seed = Random.Next(); 201 int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value); 202 if (count == 0) count = 1; 203 IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count) 204 .Where(row => row < ProblemData.TestSamplesStart.Value || ProblemData.TestSamplesEnd.Value <= row); 205 206 double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity; 207 double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity; 208 209 var trees = SymbolicExpressionTreeParameter.ActualValue; 210 211 for (int i = 0; i < validationArr.Length; i++) { 212 var tree = trees[i]; 213 double quality = Evaluator.Evaluate(SymbolicExpressionTreeInterpreter, tree, 214 lowerEstimationLimit, upperEstimationLimit, 215 ProblemData.Dataset, targetVariable, 216 rows); 217 validationArr[i] = quality; 201 218 } 202 //if (!Parameters.ContainsKey("RelativeValidationQualityUpperLimit")) { 203 // Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeValidationQualityUpperLimit", new PercentValue(0.05))); 204 //} 205 //if (!Parameters.ContainsKey("RelativeValidationQualityLowerLimit")) { 206 // Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeValidationQualityLowerLimit", new PercentValue(-0.05))); 207 //} 208 if (!Parameters.ContainsKey("TrainingAndValidationQualities")) { 209 Parameters.Add(new LookupParameter<ItemList<DoubleMatrix>>("TrainingAndValidationQualities")); 219 220 #endregion 221 222 223 double r = alglib.spearmancorr2(trainingArr, validationArr); 224 225 TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r); 226 227 if (TrainingValidationQualityCorrelationTableParameter.ActualValue == null) { 228 var dataTable = new DataTable("Training and validation fitness correlation table", "Data table of training and validation fitness correlation values over the whole run."); 229 dataTable.Rows.Add(new DataRow("Training and validation fitness correlation", "Training and validation fitness correlation values")); 230 TrainingValidationQualityCorrelationTableParameter.ActualValue = dataTable; 231 ResultsParameter.ActualValue.Add(new Result(TrainingValidationCorrelationTableParameterName, dataTable)); 210 232 } 211 if (!Parameters.ContainsKey("Percentile")) { 212 Parameters.Add(new ValueLookupParameter<DoubleValue>("Percentile", new DoubleValue(1))); 233 234 TrainingValidationQualityCorrelationTableParameter.ActualValue.Rows["Training and validation fitness correlation"].Values.Add(r); 235 236 double correlationThreshold; 237 if (OverfittingParameter.ActualValue != null && OverfittingParameter.ActualValue.Value) { 238 // if is already overfitting => have to reach the upper threshold to switch back to non-overfitting state 239 correlationThreshold = UpperCorrelationThresholdParameter.ActualValue.Value; 240 } else { 241 // if currently in non-overfitting state => have to reach to lower threshold to switch to overfitting state 242 correlationThreshold = LowerCorrelationThresholdParameter.ActualValue.Value; 213 243 } 214 if (!Parameters.ContainsKey("ValidationQuality")) { 215 Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("ValidationQuality")); 216 } 217 if (!Parameters.ContainsKey("LowerCorrelationLimit")) { 218 Parameters.Add(new ValueLookupParameter<DoubleValue>("LowerCorrelationLimit", new DoubleValue(0.65))); 219 } 220 if (!Parameters.ContainsKey("UpperCorrelationLimit")) { 221 Parameters.Add(new ValueLookupParameter<DoubleValue>("UpperCorrelationLimit", new DoubleValue(0.75))); 222 } 223 224 } 225 226 public override IOperation Apply() { 227 var trees = SymbolicExpressionTree; 228 ItemArray<DoubleValue> qualities = QualityParameter.ActualValue; 229 ItemArray<DoubleValue> validationQualities = ValidationQualityParameter.ActualValue; 230 231 double correlationLimit; 232 if (OverfittingParameter.ActualValue != null && OverfittingParameter.ActualValue.Value) { 233 // if is already overfitting have to reach the upper limit to switch back to non-overfitting state 234 correlationLimit = UpperCorrelationLimitParameter.ActualValue.Value; 235 } else { 236 // if currently in non-overfitting state have to reach to lower limit to switch to overfitting state 237 correlationLimit = LowerCorrelationLimitParameter.ActualValue.Value; 238 } 239 //string targetVariable = ProblemData.TargetVariable.Value; 240 241 //// select a random subset of rows in the validation set 242 //int validationStart = ValidiationSamplesStart.Value; 243 //int validationEnd = ValidationSamplesEnd.Value; 244 //int seed = Random.Next(); 245 //int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value); 246 //if (count == 0) count = 1; 247 //IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count); 248 249 //double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity; 250 //double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity; 251 252 //double bestQuality = Maximization.Value ? double.NegativeInfinity : double.PositiveInfinity; 253 //SymbolicExpressionTree bestTree = null; 254 255 //List<double> validationQualities = new List<double>(); 256 //foreach (var tree in trees) { 257 // double quality = Evaluator.Evaluate(SymbolicExpressionTreeInterpreter, tree, 258 // lowerEstimationLimit, upperEstimationLimit, 259 // ProblemData.Dataset, targetVariable, 260 // rows); 261 // validationQualities.Add(quality); 262 // //if ((Maximization.Value && quality > bestQuality) || 263 // // (!Maximization.Value && quality < bestQuality)) { 264 // // bestQuality = quality; 265 // // bestTree = tree; 266 // //} 267 //} 268 269 //if (RelativeValidationQualityParameter.ActualValue == null) { 270 // first call initialize the relative quality using the difference between average training and validation quality 271 double avgTrainingQuality = qualities.Select(x => x.Value).Average(); 272 double avgValidationQuality = validationQualities.Select(x => x.Value).Average(); 273 274 if (Maximization.Value) 275 RelativeValidationQualityParameter.ActualValue = new PercentValue(avgValidationQuality / avgTrainingQuality - 1); 276 else { 277 RelativeValidationQualityParameter.ActualValue = new PercentValue(avgTrainingQuality / avgValidationQuality - 1); 278 } 279 //} 280 281 // best first (only for maximization 282 var orderedDistinctPairs = (from index in Enumerable.Range(0, qualities.Length) 283 where qualities[index].Value > 0.0 284 select new { Training = qualities[index].Value, Validation = validationQualities[index].Value }) 285 .OrderBy(x => -x.Training) 286 .ToList(); 287 288 int n = (int)Math.Round(PercentileParameter.ActualValue.Value * orderedDistinctPairs.Count); 289 290 double[] validationArr = new double[n]; 291 double[] trainingArr = new double[n]; 292 double[,] qualitiesArr = new double[n, 2]; 293 for (int i = 0; i < n; i++) { 294 validationArr[i] = orderedDistinctPairs[i].Validation; 295 trainingArr[i] = orderedDistinctPairs[i].Training; 296 297 qualitiesArr[i, 0] = trainingArr[i]; 298 qualitiesArr[i, 1] = validationArr[i]; 299 } 300 double r = alglib.correlation.spearmanrankcorrelation(trainingArr, validationArr, n); 301 TrainingValidationQualityCorrelationParameter.ActualValue = new DoubleValue(r); 302 if (InitialTrainingQualityParameter.ActualValue == null) 303 InitialTrainingQualityParameter.ActualValue = new DoubleValue(avgValidationQuality); 304 bool overfitting = 305 avgTrainingQuality > InitialTrainingQualityParameter.ActualValue.Value && // better on training than in initial generation 306 // RelativeValidationQualityParameter.ActualValue.Value < 0.0 && // validation quality is worse than training quality 307 r < correlationLimit; 308 244 bool overfitting = r < correlationThreshold; 309 245 310 246 OverfittingParameter.ActualValue = new BoolValue(overfitting); 311 ItemList<DoubleMatrix> list = TrainingAndValidationQualitiesParameter.ActualValue; 312 if (list == null) { 313 TrainingAndValidationQualitiesParameter.ActualValue = new ItemList<DoubleMatrix>(); 314 } 315 TrainingAndValidationQualitiesParameter.ActualValue.Add(new DoubleMatrix(qualitiesArr)); 247 316 248 return base.Apply(); 317 }318 319 [StorableHook(HookType.AfterDeserialization)]320 private void Initialize() { }321 322 private static void AddValue(DataTable table, double data, string name, string description) {323 DataRow row;324 table.Rows.TryGetValue(name, out row);325 if (row == null) {326 row = new DataRow(name, description);327 row.Values.Add(data);328 table.Rows.Add(row);329 } else {330 row.Values.Add(data);331 }332 249 } 333 250 }
Note: See TracChangeset
for help on using the changeset viewer.