1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  * and the BEACON Center for the Study of Evolution in Action.


5  *


6  * This file is part of HeuristicLab.


7  *


8  * HeuristicLab is free software: you can redistribute it and/or modify


9  * it under the terms of the GNU General Public License as published by


10  * the Free Software Foundation, either version 3 of the License, or


11  * (at your option) any later version.


12  *


13  * HeuristicLab is distributed in the hope that it will be useful,


14  * but WITHOUT ANY WARRANTY; without even the implied warranty of


15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


16  * GNU General Public License for more details.


17  *


18  * You should have received a copy of the GNU General Public License


19  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


20  */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using System.Diagnostics.Contracts;


26  using System.Linq;


27  using HeuristicLab.Problems.DataAnalysis;


28  using HeuristicLab.Random;


29 


30  namespace HeuristicLab.Algorithms.DataAnalysis {


31  public static class GradientBoostedTreesAlgorithmStatic {


32  #region static API


33 


/// <summary>
/// Handle to the state of a gradient boosting run. Instances are obtained from
/// CreateGbmState and advanced one iteration at a time with MakeStep.
/// </summary>
public interface IGbmState {
  /// <summary>Returns the loss of the current predictions on the training rows.</summary>
  double GetTrainLoss();

  /// <summary>Returns the loss of the current predictions on the test rows.</summary>
  double GetTestLoss();

  /// <summary>
  /// Returns relevance values as (string, double) pairs; the string key is presumably
  /// the variable name (confirm against the tree builder that produces them).
  /// </summary>
  IEnumerable<KeyValuePair<string, double>> GetVariableRelevance();

  /// <summary>Returns a regression model for the ensemble accumulated so far.</summary>
  IRegressionModel GetModel();
}


40 


// Concrete boosting state; instances are created through the CreateGbmState factory method.
// GbmState details are private: API users can only use the methods declared by IGbmState.
private class GbmState : IGbmState {
  // problem definition and algorithm parameters, all fixed at construction
  internal IRegressionProblemData problemData { get; private set; }
  internal ILossFunction lossFunction { get; private set; }
  internal int maxSize { get; private set; }
  internal double nu { get; private set; }
  internal double r { get; private set; }
  internal double m { get; private set; }
  internal int[] trainingRows { get; private set; }
  internal int[] testRows { get; private set; }
  internal RegressionTreeBuilder treeBuilder { get; private set; }

  // the seed is kept so that it can be handed to the model surrogate in GetModel()
  private readonly uint randSeed;
  private MersenneTwister random { get; set; }

  // array members (allocate only once)
  internal double[] pred;      // current predictions for the training rows
  internal double[] predTest;  // current predictions for the test rows
  internal double[] y;         // target values of the training rows
  internal int[] activeIdx;    // indices 0 .. nRows - 1 into the training arrays
  internal double[] pseudoRes; // buffer for the loss gradient, refilled in every step

  // models[i] belongs to weights[i]; element 0 is always the constant model
  private readonly IList<IRegressionModel> models;
  private readonly IList<double> weights;

  // Initializes the state: materializes the training/test partitions and target values,
  // determines the initial constant prediction via the loss function's line search and
  // registers it as the first model of the ensemble (weight 1.0).
  public GbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize, double r, double m, double nu) {
    // algorithm parameters are stored exactly as supplied by the caller
    this.maxSize = maxSize;
    this.nu = nu;
    this.r = r;
    this.m = m;

    this.randSeed = randSeed;
    random = new MersenneTwister(randSeed);
    this.problemData = problemData;
    this.trainingRows = problemData.TrainingIndices.ToArray();
    this.testRows = problemData.TestIndices.ToArray();
    this.lossFunction = lossFunction;

    int nRows = trainingRows.Length;

    y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingRows).ToArray();

    treeBuilder = new RegressionTreeBuilder(problemData, random);

    activeIdx = Enumerable.Range(0, nRows).ToArray();

    // a freshly allocated double[] is already filled with 0.0
    var zeros = new double[nRows];
    // the last two arguments are the first and last position (0 and nRows - 1) within activeIdx
    double f0 = lossFunction.LineSearch(y, zeros, activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    pred = Enumerable.Repeat(f0, nRows).ToArray();
    predTest = Enumerable.Repeat(f0, testRows.Length).ToArray();
    pseudoRes = new double[nRows];

    models = new List<IRegressionModel>();
    weights = new List<double>();
    // add constant model
    models.Add(new ConstantModel(f0));
    weights.Add(1.0);
  }

  // Wraps the models and weights collected so far into a model surrogate.
  public IRegressionModel GetModel() {
#pragma warning disable 618
    var model = new GradientBoostedTreesModel(models, weights);
#pragma warning restore 618
    // we don't know the number of iterations here but the number of weights is equal
    // to the number of iterations + 1 (for the constant model)
    int iterations = weights.Count - 1;
    // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary
    return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, iterations, maxSize, r, m, nu, model);
  }

  // Delegates to the tree builder, which provides the relevance values.
  public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
    return treeBuilder.GetVariableRelevance();
  }

  // Loss of the current training predictions divided by the number of training rows.
  public double GetTrainLoss() {
    int nRows = y.Length;
    return lossFunction.GetLoss(y, pred) / nRows;
  }

  // Loss of the current test predictions divided by the number of test rows.
  // The test targets are read from the dataset on every call (they are not cached).
  public double GetTestLoss() {
    var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, testRows);
    var nRows = testRows.Length;
    return lossFunction.GetLoss(yTest, predTest) / nRows;
  }

  // Appends a model together with its weight to the ensemble.
  internal void AddModel(IRegressionModel m, double weight) {
    models.Add(m);
    weights.Add(weight);
  }
}


130 


// simple interface
/// <summary>
/// Runs <paramref name="maxIterations"/> boosting steps on a freshly created state and
/// returns a solution holding the resulting model and a clone of the problem data.
/// </summary>
/// <param name="problemData">Regression problem to train on.</param>
/// <param name="lossFunction">Loss function used for the line search and the gradients.</param>
/// <param name="maxSize">Forwarded to the tree builder in every step.</param>
/// <param name="nu">Factor applied to the estimates of each new tree; asserted to lie in (0, 1].</param>
/// <param name="r">Forwarded to the tree builder in every step; asserted to lie in (0, 1].</param>
/// <param name="m">Forwarded to the tree builder in every step.</param>
/// <param name="maxIterations">Number of boosting steps; no step is made for values below 1.</param>
/// <param name="randSeed">Seed for the random number generator of the state.</param>
public static GradientBoostedTreesSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {
  Contract.Assert(r > 0);
  Contract.Assert(r <= 1.0);
  Contract.Assert(nu > 0);
  Contract.Assert(nu <= 1.0);

  // the factory result is used through its interface; MakeStep recovers the concrete state itself
  IGbmState boostingState = CreateGbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu);

  for (int remaining = maxIterations; remaining > 0; remaining--) {
    MakeStep(boostingState);
  }

  return new GradientBoostedTreesSolution(boostingState.GetModel(), (IRegressionProblemData)problemData.Clone());
}


147 


// for custom stepping & termination
/// <summary>
/// Factory method for the boosting state. The returned state already contains the
/// initial constant model; trees are added by calling MakeStep on it.
/// </summary>
public static IGbmState CreateGbmState(IRegressionProblemData problemData, ILossFunction lossFunction, uint randSeed, int maxSize = 3, double r = 0.66, double m = 0.5, double nu = 0.01) {
  IGbmState freshState = new GbmState(problemData, lossFunction, randSeed, maxSize, r, m, nu);
  return freshState;
}


152 


// use default settings for maxSize, nu, r from state
/// <summary>
/// Performs one boosting step using the maxSize, nu, r and m values stored in the state.
/// </summary>
/// <param name="state">A state obtained from CreateGbmState.</param>
/// <exception cref="ArgumentException">
/// <paramref name="state"/> is null or is not a state created by CreateGbmState.
/// </exception>
public static void MakeStep(IGbmState state) {
  var gbmState = state as GbmState;
  // the (message, paramName) overload is used so that ParamName is set on the exception
  if (gbmState == null) throw new ArgumentException("The state must be a non-null instance created by CreateGbmState.", "state");

  MakeStep(gbmState, gbmState.maxSize, gbmState.nu, gbmState.r, gbmState.m);
}


160 


// allow dynamic adaptation of maxSize, nu and r (even though this is not used)
/// <summary>
/// Performs one boosting step: computes the loss gradient for the current training
/// predictions, builds a regression tree for it, adds nu times the tree's estimates to the
/// training and test predictions and appends the tree with weight nu to the state.
/// </summary>
/// <param name="state">A state obtained from CreateGbmState.</param>
/// <param name="maxSize">Passed to the tree builder for this step.</param>
/// <param name="nu">Factor applied to the tree's estimates; also stored as the tree's weight.</param>
/// <param name="r">Passed to the tree builder for this step.</param>
/// <param name="m">Passed to the tree builder for this step.</param>
/// <exception cref="ArgumentException">
/// <paramref name="state"/> is null or is not a state created by CreateGbmState.
/// </exception>
public static void MakeStep(IGbmState state, int maxSize, double nu, double r, double m) {
  var gbmState = state as GbmState;
  // the (message, paramName) overload is used so that ParamName is set on the exception
  if (gbmState == null) throw new ArgumentException("The state must be a non-null instance created by CreateGbmState.", "state");

  var problemData = gbmState.problemData;
  var lossFunction = gbmState.lossFunction;
  var yPred = gbmState.pred;
  var yPredTest = gbmState.predTest;
  var treeBuilder = gbmState.treeBuilder;
  var y = gbmState.y;
  var activeIdx = gbmState.activeIdx;
  var pseudoRes = gbmState.pseudoRes;
  var trainingRows = gbmState.trainingRows;
  var testRows = gbmState.testRows;

  // copy output of gradient function to preallocated rim array (pseudoresidual per row and model)
  int rimIdx = 0;
  foreach (var g in lossFunction.GetLossGradient(y, yPred)) {
    pseudoRes[rimIdx++] = g;
  }

  var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, yPred, maxSize, activeIdx, lossFunction, r, m);

  // update predictions for training set
  int i = 0;
  foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, trainingRows)) {
    yPred[i] += nu * pred;
    i++;
  }
  // update predictions for validation set
  i = 0;
  foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, testRows)) {
    yPredTest[i] += nu * pred;
    i++;
  }

  gbmState.AddModel(tree, nu);
}


199  #endregion


200  }


201  }

