Changeset 13948 for branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs
- Timestamp:
- 06/29/16 10:36:52 (8 years ago)
- Location:
- branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis (added) merged: 13889,13891,13895,13898,13917,13921-13922,13941
-
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GBM/GradientBoostingRegressionAlgorithm.cs
r13724 r13948 64 64 private const string SetSeedRandomlyParameterName = "SetSeedRandomly"; 65 65 private const string CreateSolutionParameterName = "CreateSolution"; 66 private const string StoreRunsParameterName = "StoreRuns"; 66 67 private const string RegressionAlgorithmSolutionResultParameterName = "RegressionAlgorithmResult"; 67 68 … … 106 107 get { return (IFixedValueParameter<BoolValue>)Parameters[CreateSolutionParameterName]; } 107 108 } 109 public IFixedValueParameter<BoolValue> StoreRunsParameter { 110 get { return (IFixedValueParameter<BoolValue>)Parameters[StoreRunsParameterName]; } 111 } 108 112 109 113 #endregion … … 144 148 get { return CreateSolutionParameter.Value.Value; } 145 149 set { CreateSolutionParameter.Value.Value = value; } 150 } 151 152 public bool StoreRuns { 153 get { return StoreRunsParameter.Value.Value; } 154 set { StoreRunsParameter.Value.Value = value; } 146 155 } 147 156 … … 178 187 var regressionAlgs = new ItemSet<IAlgorithm>(new IAlgorithm[] { 179 188 new RandomForestRegression(), 180 sgp, 189 sgp, 181 190 mctsSymbReg 182 191 }); … … 206 215 "Flag that indicates if a solution should be produced at the end of the run", new BoolValue(true))); 207 216 Parameters[CreateSolutionParameterName].Hidden = true; 217 Parameters.Add(new FixedValueParameter<BoolValue>(StoreRunsParameterName, 218 "Flag that indicates if the results of the individual runs should be stored for detailed analysis", new BoolValue(false))); 219 Parameters[StoreRunsParameterName].Hidden = true; 208 220 } 209 221 … … 218 230 219 231 var table = new DataTable("Qualities"); 220 table.Rows.Add(new DataRow(" Loss(train)"));221 table.Rows.Add(new DataRow(" Loss(test)"));232 table.Rows.Add(new DataRow("R² (train)")); 233 table.Rows.Add(new DataRow("R² (test)")); 222 234 Results.Add(new Result("Qualities", table)); 223 235 var curLoss = new DoubleValue(); 224 236 var curTestLoss = new DoubleValue(); 225 Results.Add(new Result(" Loss(train)", curLoss));226 Results.Add(new 
Result(" Loss(test)", curTestLoss));237 Results.Add(new Result("R² (train)", curLoss)); 238 Results.Add(new Result("R² (test)", curTestLoss)); 227 239 var runCollection = new RunCollection(); 228 Results.Add(new Result("Runs", runCollection)); 240 if (StoreRuns) 241 Results.Add(new Result("Runs", runCollection)); 229 242 230 243 // init 231 244 var problemData = Problem.ProblemData; 232 var targetVarName = Problem.ProblemData.TargetVariable;245 var targetVarName = problemData.TargetVariable; 233 246 var activeVariables = problemData.AllowedInputVariables.Concat(new string[] { problemData.TargetVariable }); 234 247 var modifiableDataset = new ModifiableDataset( … … 273 286 IRegressionModel model; 274 287 IRun run; 288 275 289 // try to find a model. The algorithm might fail to produce a model. In this case we just retry until the iterations are exhausted 276 if (TryExecute(alg, RegressionAlgorithmResult, out model, out run)) {290 if (TryExecute(alg, rand.Next(), RegressionAlgorithmResult, out model, out run)) { 277 291 int row = 0; 278 292 // update predictions for training and test … … 303 317 } 304 318 305 runCollection.Add(run); 306 table.Rows["Loss (train)"].Values.Add(curLoss.Value); 307 table.Rows["Loss (test)"].Values.Add(curTestLoss.Value); 319 if (StoreRuns) 320 runCollection.Add(run); 321 table.Rows["R² (train)"].Values.Add(curLoss.Value); 322 table.Rows["R² (test)"].Values.Add(curTestLoss.Value); 308 323 iterations.Value = i + 1; 309 324 } … … 317 332 // just produce an ensemble solution for now (TODO: correct scaling or linear regression for ensemble model weights) 318 333 319 var ensembleModel = new RegressionEnsembleModel(models) { AverageModelEstimates = false }; 320 var ensembleSolution = ensembleModel.CreateRegressionSolution((IRegressionProblemData)problemData.Clone()); 334 var ensembleSolution = CreateEnsembleSolution(models, (IRegressionProblemData)problemData.Clone()); 321 335 Results.Add(new Result("EnsembleSolution", ensembleSolution)); 322 
336 } … … 326 340 alg.Prepare(true); 327 341 } 342 } 343 344 private static IRegressionEnsembleSolution CreateEnsembleSolution(List<IRegressionModel> models, 345 IRegressionProblemData problemData) { 346 var rows = problemData.TrainingPartition.Size; 347 var features = models.Count; 348 double[,] inputMatrix = new double[rows, features + 1]; 349 //add model estimates 350 for (int m = 0; m < models.Count; m++) { 351 var model = models[m]; 352 var estimates = model.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices); 353 int estimatesCounter = 0; 354 foreach (var estimate in estimates) { 355 inputMatrix[estimatesCounter, m] = estimate; 356 estimatesCounter++; 357 } 358 } 359 360 //add target 361 var targets = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices); 362 int targetCounter = 0; 363 foreach (var target in targets) { 364 inputMatrix[targetCounter, models.Count] = target; 365 targetCounter++; 366 } 367 368 alglib.linearmodel lm = new alglib.linearmodel(); 369 alglib.lrreport ar = new alglib.lrreport(); 370 double[] coefficients; 371 int retVal = 1; 372 alglib.lrbuildz(inputMatrix, rows, features, out retVal, out lm, out ar); 373 if (retVal != 1) throw new ArgumentException("Error in calculation of linear regression solution"); 374 375 alglib.lrunpack(lm, out coefficients, out features); 376 377 var ensembleModel = new RegressionEnsembleModel(models, coefficients.Take(models.Count)) { AverageModelEstimates = false }; 378 var ensembleSolution = (IRegressionEnsembleSolution)ensembleModel.CreateRegressionSolution(problemData); 379 return ensembleSolution; 328 380 } 329 381 … … 391 443 progRootNode.AddSubtree(startNode); 392 444 var t = new SymbolicExpressionTree(progRootNode); 393 var combinedModel = new SymbolicRegressionModel( t, interpreter, lowerLimit, upperLimit);445 var combinedModel = new SymbolicRegressionModel(problemData.TargetVariable, t, interpreter, lowerLimit, upperLimit); 394 446 var sol = 
new SymbolicRegressionSolution(combinedModel, problemData); 395 447 return sol; … … 405 457 } 406 458 407 private static bool TryExecute(IAlgorithm alg, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) {459 private static bool TryExecute(IAlgorithm alg, int seed, string regressionAlgorithmResultName, out IRegressionModel model, out IRun run) { 408 460 model = null; 461 SetSeed(alg, seed); 409 462 using (var wh = new AutoResetEvent(false)) { 410 EventHandler<EventArgs<Exception>> handler = (sender, args) => wh.Set(); 463 Exception ex = null; 464 EventHandler<EventArgs<Exception>> handler = (sender, args) => { 465 ex = args.Value; 466 wh.Set(); 467 }; 411 468 EventHandler handler2 = (sender, args) => wh.Set(); 412 469 alg.ExceptionOccurred += handler; … … 417 474 wh.WaitOne(); 418 475 476 if (ex != null) throw new AggregateException(ex); 419 477 run = alg.Runs.Last(); 478 alg.Runs.Clear(); 420 479 var sols = alg.Results.Select(r => r.Value).OfType<IRegressionSolution>(); 421 480 if (!sols.Any()) return false; … … 444 503 return model != null; 445 504 } 505 506 private static void SetSeed(IAlgorithm alg, int seed) { 507 // no common interface for algs that use a PRNG -> use naming convention to set seed 508 var paramItem = alg as IParameterizedItem; 509 510 if (paramItem.Parameters.ContainsKey("SetSeedRandomly")) { 511 ((BoolValue)paramItem.Parameters["SetSeedRandomly"].ActualValue).Value = false; 512 ((IntValue)paramItem.Parameters["Seed"].ActualValue).Value = seed; 513 } else { 514 throw new ArgumentException("Base learner does not have a seed parameter (algorithm {0})", alg.Name); 515 } 516 517 } 446 518 } 447 519 }
Note: See TracChangeset for help on using the changeset viewer.