Changeset 13977 for stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs
 Timestamp:
 07/02/16 09:08:09 (5 years ago)
 Location:
 stable
 Files:

 3 edited
 1 copied
Legend:
 Unmodified
 Added
 Removed

stable
 Property svn:mergeinfo changed
/trunk/sources merged: 13646,13653,13655,13699,13703,13707,13889,13898,13917
 Property svn:mergeinfo changed

stable/HeuristicLab.Algorithms.DataAnalysis
 Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis merged: 13646,13653,13655,13699,13703,13707,13889,13898,13917
 Property svn:mergeinfo changed

stable/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs
r13646 r13977 37 37 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression; 38 38 using HeuristicLab.Random; 39 using HeuristicLab.Selection; 39 40 40 41 namespace HeuristicLab.Algorithms.DataAnalysis.MctsSymbolicRegression { … … 63 64 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 64 65 private const string CreateSolutionParameterName = "CreateSolution"; 66 private const string StoreRunsParameterName = "StoreRuns"; 65 67 private const string RegressionAlgorithmSolutionResultParameterName = "RegressionAlgorithmResult"; 66 68 … … 105 107 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 106 108 } 109 public IFixedValueParameter<BoolValue> StoreRunsParameter { 110 get { return (IFixedValueParameter<BoolValue>)Parameters[StoreRunsParameterName]; } 111 } 107 112 108 113 #endregion … … 145 150 } 146 151 152 public bool StoreRuns { 153 get { return StoreRunsParameter.Value.Value; } 154 set { StoreRunsParameter.Value.Value = value; } 155 } 156 147 157 public IAlgorithm RegressionAlgorithm { 148 158 get { return RegressionAlgorithmParameter.Value; } … … 172 182 Problem = new RegressionProblem(); // default problem 173 183 var mctsSymbReg = new MctsSymbolicRegressionAlgorithm(); 174 // var sgp = CreateSGP(); 184 mctsSymbReg.Iterations = 10000; 185 mctsSymbReg.StoreAlgorithmInEachRun = false; 186 var sgp = CreateOSGP(); 175 187 var regressionAlgs = new ItemSet<IAlgorithm>(new IAlgorithm[] { 176 new LinearRegression(), new RandomForestRegression(), new NearestNeighbourRegression(),177 // sgp,188 new RandomForestRegression(), 189 sgp, 178 190 mctsSymbReg 179 191 }); … … 203 215 "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 204 216 Parameters[CreateSolutionParameterName].Hidden = true; 217 Parameters.Add(new FixedValueParameter<BoolValue>(StoreRunsParameterName, 218 "Flag that indicates if the results of the individual runs should be stored for detailed analysis", new BoolValue(false))); 219 Parameters[StoreRunsParameterName].Hidden = true; 205 220 } 206 221 … … 215 230 216 231 var table = new DataTable("Qualities"); 217 table.Rows.Add(new DataRow(" Loss(train)"));218 table.Rows.Add(new DataRow(" Loss(test)"));232 table.Rows.Add(new DataRow("R² (train)")); 233 table.Rows.Add(new DataRow("R² (test)")); 219 234 Results.Add(new Result("Qualities", table)); 220 235 var curLoss = new DoubleValue(); 221 236 var curTestLoss = new DoubleValue(); 222 Results.Add(new Result(" Loss(train)", curLoss));223 Results.Add(new Result(" Loss(test)", curTestLoss));237 Results.Add(new Result("R² (train)", curLoss)); 238 Results.Add(new Result("R² (test)", curTestLoss)); 224 239 var runCollection = new RunCollection(); 225 Results.Add(new Result("Runs", runCollection)); 240 if (StoreRuns) 241 Results.Add(new Result("Runs", runCollection)); 226 242 227 243 // init 228 244 var problemData = Problem.ProblemData; 229 var targetVarName = Problem.ProblemData.TargetVariable; 245 var targetVarName = problemData.TargetVariable; 246 var activeVariables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 230 247 var modifiableDataset = new ModifiableDataset( 231 problemData.Dataset.VariableNames,232 problemData.Dataset.VariableNames.Select(v => problemData.Dataset.GetDoubleValues(v).ToList()));248 activeVariables, 249 activeVariables.Select(v => problemData.Dataset.GetDoubleValues(v).ToList())); 233 250 234 251 var trainingRows = problemData.TrainingIndices; … … 269 286 IRegressionModel model; 270 287 IRun run; 288 271 289 // try to find a model. The algorithm might fail to produce a model. In this case we just retry until the iterations are exhausted 272 if (TryExecute(alg, RegressionAlgorithmResult, out model, out run)) {290 if (TryExecute(alg, rand.Next(), RegressionAlgorithmResult, out model, out run)) { 273 291 int row = 0; 274 292 // update predictions for training and test … … 299 317 } 300 318 301 runCollection.Add(run); 302 table.Rows["Loss (train)"].Values.Add(curLoss.Value); 303 table.Rows["Loss (test)"].Values.Add(curTestLoss.Value); 319 if (StoreRuns) 320 runCollection.Add(run); 321 table.Rows["R² (train)"].Values.Add(curLoss.Value); 322 table.Rows["R² (test)"].Values.Add(curTestLoss.Value); 304 323 iterations.Value = i + 1; 305 324 } … … 312 331 } 313 332 // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights) 314 Results.Add(new Result("EnsembleSolution", new RegressionEnsembleSolution(models, (IRegressionProblemData)problemData.Clone()))); 333 334 var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone()); 335 Results.Add(new Result("EnsembleSolution", ensembleSolution)); 315 336 } 316 } finally { 337 } 338 finally { 317 339 // reset everything 318 340 alg.Prepare(true); … … 320 342 } 321 343 322 // this is probably slow as hell 344 private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models, 345 IRegressionProblemData problemData) { 346 var rows = problemData.TrainingPartition.Size; 347 var features = models.Count; 348 double[,] inputMatrix = new double[rows, features + 1]; 349 //add model estimates 350 for (int m = 0; m < models.Count; m++) { 351 var model = models[m]; 352 var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices); 353 int estimatesCounter = 0; 354 foreach (var estimate in estimates) { 355 inputMatrix[estimatesCounter, m] = estimate; 356 estimatesCounter++; 357 } 358 } 359 360 //add target 361 var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 362 int targetCounter = 0; 363 foreach (var target in targets) { 364 inputMatrix[targetCounter, models.Count] = target; 365 targetCounter++; 366 } 367 368 alglib.linearmodel lm = new alglib.linearmodel(); 369 alglib.lrreport ar = new alglib.lrreport(); 370 double[] coefficients; 371 int retVal = 1; 372 alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar); 373 if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution"); 374 375 alglib.lrunpack(lm, out coefficients, out features); 376 377 var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false }; 378 var ensembleSolution = ensembleModel.CreateRegressionSolution(problemData); 379 return ensembleSolution; 380 } 381 382 383 private IAlgorithm CreateOSGP() { 384 // configure strict osgp 385 var alg = new OffspringSelectionGeneticAlgorithm.OffspringSelectionGeneticAlgorithm(); 386 var prob = new SymbolicRegressionSingleObjectiveProblem(); 387 prob.MaximumSymbolicExpressionTreeDepth.Value = 7; 388 prob.MaximumSymbolicExpressionTreeLength.Value = 15; 389 alg.Problem = prob; 390 alg.SuccessRatio.Value = 1.0; 391 alg.ComparisonFactorLowerBound.Value = 1.0; 392 alg.ComparisonFactorUpperBound.Value = 1.0; 393 alg.MutationProbability.Value = 0.15; 394 alg.PopulationSize.Value = 200; 395 alg.MaximumSelectionPressure.Value = 100; 396 alg.MaximumEvaluatedSolutions.Value = 20000; 397 alg.SelectorParameter.Value = alg.SelectorParameter.ValidValues.OfType<GenderSpecificSelector>().First(); 398 alg.MutatorParameter.Value = alg.MutatorParameter.ValidValues.OfType<MultiSymbolicExpressionTreeManipulator>().First(); 399 alg.StoreAlgorithmInEachRun = false; 400 return alg; 401 } 402 323 403 private void SampleTrainingData(MersenneTwister rand, ModifiableDataset ds, int rRows, 324 404 IDataset sourceDs, double[] curTarget, string targetVarName, IEnumerable<int> trainingIndices) { … … 374 454 prob.ProblemDataParameter.Value = problemData; 375 455 return true; 376 } else if (alg.Problem != null) { 377 // a problem is set and it is not compatible 378 return false; 379 } else { 380 try { 381 // we try to use a symbolic regression problem (works for simple regression algs and GP) 382 alg.Problem = new SymbolicRegressionSingleObjectiveProblem(); 383 } catch (Exception) { 384 return false; 385 } 386 return true; 387 } 388 } 389 390 private static bool TryExecute(IAlgorithm alg, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) { 456 } else return false; 457 } 458 459 private static bool TryExecute(IAlgorithm alg, int seed, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) { 391 460 model = null; 461 SetSeed(alg, seed); 392 462 using (var wh = new AutoResetEvent(false)) { 393 EventHandler<EventArgs<Exception>> handler = (sender, args) => wh.Set(); 463 Exception ex = null; 464 EventHandler<EventArgs<Exception>> handler = (sender, args) => { 465 ex = args.Value; 466 wh.Set(); 467 }; 394 468 EventHandler handler2 = (sender, args) => wh.Set(); 395 469 alg.ExceptionOccurred += handler; … … 400 474 wh.WaitOne(); 401 475 476 if (ex != null) throw new AggregateException(ex); 402 477 run = alg.Runs.Last(); 478 alg.Runs.Clear(); 403 479 var sols = alg.Results.Select(r => r.Value).OfType<IRegressionSolution>(); 404 480 if (!sols.Any()) return false; … … 419 495 model = sol.Model; 420 496 } 421 } finally { 497 } 498 finally { 422 499 alg.ExceptionOccurred = handler; 423 500 alg.Stopped = handler2; … … 426 503 return model != null; 427 504 } 505 506 private static void SetSeed(IAlgorithm alg, int seed) { 507 // no common interface for algs that use a PRNG > use naming convention to set seed 508 var paramItem = alg as IParameterizedItem; 509 510 if (paramItem.Parameters.ContainsKey("SetSeedRandomly")) { 511 ((BoolValue)paramItem.Parameters["SetSeedRandomly"].ActualValue).Value = false; 512 ((IntValue)paramItem.Parameters["Seed"].ActualValue).Value = seed; 513 } else { 514 throw new ArgumentException("Base learner does not have a seed parameter (algorithm {0})", alg.Name); 515 } 516 517 } 428 518 } 429 519 }
Note: See TracChangeset
for help on using the changeset viewer.