1  using System;


2  using System.Collections.Generic;


3  using System.Diagnostics;


4  using System.Diagnostics.Contracts;


5  using System.Linq;


6  using HeuristicLab.Common;


7  using HeuristicLab.Core;


8  using HeuristicLab.Problems.DataAnalysis;


9 


namespace GradientBoostedTrees {
  // Builds a single regression tree on a (sub-sampled) training partition, for use as a
  // base learner in gradient boosting. The split criterion is the squared-error reduction
  // from Friedman, "Greedy Function Approximation: A Gradient Boosting Machine" (2001).
  public class RegressionTreeBuilder {
    private readonly IRandom random;
    private readonly IRegressionProblemData problemData;

    private readonly int nCols;
    private readonly double[][] x; // all training data (original order from problemData)
    private double[] y; // training labels (original order from problemData)

    private Dictionary<string, double> sumImprovements; // for variable relevance calculation

    private readonly string[] allowedVariables; // all variables in shuffled order
    private Dictionary<string, int> varName2Index; // maps the variable names to column indexes
    private int effectiveVars; // number of variables that are used from allowedVariables

    private int effectiveRows; // number of rows that are used from the training partition
    private readonly int[][] sortedIdxAll;
    private readonly int[][] sortedIdx; // random selection from sortedIdxAll (for r < 1.0)

    // helper arrays which are allocated to maximal necessary size only once in the ctor
    private readonly int[] internalIdx, which, leftTmp, rightTmp;
    private readonly double[] outx;
    private readonly int[] outSortedIdx;

    private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes
    private int curTreeNodeIdx; // the index where the next tree node is stored

    private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; //TODO

    // prepare and allocate buffer variables in ctor
    public RegressionTreeBuilder(IRegressionProblemData problemData, IRandom random) {
      this.problemData = problemData;
      this.random = random;

      var rows = problemData.TrainingIndices.Count();

      this.nCols = problemData.AllowedInputVariables.Count();

      allowedVariables = problemData.AllowedInputVariables.ToArray();
      varName2Index = new Dictionary<string, int>(allowedVariables.Length);
      for (int i = 0; i < allowedVariables.Length; i++) varName2Index.Add(allowedVariables[i], i);

      sortedIdxAll = new int[nCols][];
      sortedIdx = new int[nCols][];
      sumImprovements = new Dictionary<string, double>();
      internalIdx = new int[rows];
      which = new int[rows];
      leftTmp = new int[rows];
      rightTmp = new int[rows];
      outx = new double[rows];
      outSortedIdx = new int[rows];
      nodeQueue = new List<RegressionTreeModel.TreeNode>();

      x = new double[nCols][];
      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();

      int col = 0;
      foreach (var inputVariable in problemData.AllowedInputVariables) {
        x[col] = problemData.Dataset.GetDoubleValues(inputVariable, problemData.TrainingIndices).ToArray();
        // precompute the row ordering per column once; it is reused (and filtered) for every tree
        sortedIdxAll[col] = Enumerable.Range(0, rows).OrderBy(r => x[col][r]).ToArray();
        sortedIdx[col] = new int[rows];
        col++;
      }
    }

    // r and m work in the same way as for alglib random forest
    // r is fraction of rows to use for training
    // m is fraction of variables to use for training
    public IRegressionModel CreateRegressionTree(int maxDepth, double r = 0.5, double m = 0.5) {
      // subtract mean of y first; the mean is added back below via a ConstantRegressionModel
      var yAvg = y.Average();
      for (int i = 0; i < y.Length; i++) y[i] -= yAvg; // BUGFIX: was 'y[i] = yAvg' (dropped '-=')

      var seLoss = new SquaredErrorLoss();
      var zeros = Enumerable.Repeat(0.0, y.Length);
      var ones = Enumerable.Repeat(1.0, y.Length);

      var model = CreateRegressionTreeForGradientBoosting(y, maxDepth, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);
      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
    }

    // specific interface that allows to specify the target labels and the training rows which is necessary when this functionality is called by the gradient boosting routine
    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
      Contract.Assert(maxDepth > 0);
      Contract.Assert(r > 0);
      Contract.Assert(r <= 1.0);
      Contract.Assert(y.Count() == this.y.Length);
      Contract.Assert(m > 0);
      Contract.Assert(m <= 1.0);

      this.y = y; // y is changed in gradient boosting

      // shuffle row idx
      HeuristicLab.Random.ListExtensions.ShuffleInPlace(idx, random);

      int nRows = idx.Count();

      // shuffle variable idx
      HeuristicLab.Random.ListExtensions.ShuffleInPlace(allowedVariables, random);

      effectiveRows = (int)Math.Ceiling(nRows * r);
      effectiveVars = (int)Math.Ceiling(nCols * m);

      Array.Clear(which, 0, which.Length);

      // mark selected rows
      for (int row = 0; row < effectiveRows; row++) {
        which[idx[row]] = 1;
        internalIdx[row] = idx[row];
      }

      // filter the precomputed full orderings down to the selected rows,
      // preserving per-column sort order
      for (int col = 0; col < nCols; col++) {
        int i = 0;
        for (int row = 0; row < nRows; row++) {
          if (which[sortedIdxAll[col][row]] > 0) {
            Trace.Assert(i < effectiveRows);
            sortedIdx[col][i] = sortedIdxAll[col][row];
            i++;
          }
        }
      }

      // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)
      int numNodes = (int)Math.Pow(2, maxDepth) - 1; // BUGFIX: restored '- 1' (dropped '-')
      this.tree = Enumerable.Range(0, numNodes).Select(_ => new RegressionTreeModel.TreeNode()).ToArray();
      this.curTreeNodeIdx = 0;

      // start and end idx are inclusive
      CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
      return new RegressionTreeModel(tree);
    }

    // startIdx and endIdx are inclusive
    private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
      Contract.Assert(endIdx - startIdx >= 0);
      Contract.Assert(startIdx >= 0);
      Contract.Assert(endIdx < internalIdx.Length);

      // TODO: stop when y is constant
      // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth
      if (maxDepth <= 1 || endIdx - startIdx == 0) { // BUGFIX: restored '||' and '-' (both dropped)
        // max depth reached or only one element -> create a leaf
        tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
        tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
        curTreeNodeIdx++;
      } else {
        int i, j;
        double threshold;
        string bestVariableName;
        FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName);

        // if bestVariableName is NO_VARIABLE then no split was possible anymore
        if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
          // no split possible -> create a leaf
          tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
          tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
          curTreeNodeIdx++;
        } else {
          int bestVarIdx = varName2Index[bestVariableName];
          // split - two pass

          // first pass: store which index goes into which partition
          for (int k = startIdx; k <= endIdx; k++) {
            if (x[bestVarIdx][internalIdx[k]] <= threshold)
              which[internalIdx[k]] = -1; // left partition  (BUGFIX: was '= 1', dropped '-'; both branches wrote the same value)
            else
              which[internalIdx[k]] = 1; // right partition
          }

          // second pass: partition sortedIdx for each variable (stable, keeps per-column sort order)
          for (int col = 0; col < nCols; col++) {
            i = 0;
            j = 0;
            int k;
            for (k = startIdx; k <= endIdx; k++) {
              Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1);

              if (which[sortedIdx[col][k]] < 0) {
                leftTmp[i++] = sortedIdx[col][k];
              } else {
                rightTmp[j++] = sortedIdx[col][k];
              }
            }
            Debug.Assert(i > 0); // at least one element in the left partition
            Debug.Assert(j > 0); // at least one element in the right partition
            Debug.Assert(i + j == endIdx - startIdx + 1);
            k = startIdx;
            for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l];
            for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l];
          }

          // partition row indices in place (two-pointer swap)
          i = startIdx;
          j = endIdx;
          while (i <= j) {
            Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1);
            Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1);
            if (which[internalIdx[i]] < 0) i++;
            else if (which[internalIdx[j]] > 0) j--; // BUGFIX: was bare 'j;' (dropped '--')
            else {
              Trace.Assert(which[internalIdx[i]] > 0);
              Trace.Assert(which[internalIdx[j]] < 0);
              // swap
              int tmp = internalIdx[i];
              internalIdx[i] = internalIdx[j];
              internalIdx[j] = tmp;
              i++;
              j--; // BUGFIX: was bare 'j;' (dropped '--')
            }
          }
          Debug.Assert(j < i);
          Debug.Assert(i >= startIdx);
          Debug.Assert(j <= endIdx);

          var parentIdx = curTreeNodeIdx;
          tree[parentIdx].varName = bestVariableName;
          tree[parentIdx].val = threshold;
          curTreeNodeIdx++;

          // create left subtree (rows startIdx..j)
          tree[parentIdx].leftIdx = curTreeNodeIdx;
          CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch); // BUGFIX: restored '- 1'

          // create right subtree (rows i..endIdx)
          tree[parentIdx].rightIdx = curTreeNodeIdx;
          CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch); // BUGFIX: restored '- 1'
        }
      }
    }

    // searches over the effective variables for the split (variable, threshold) with the
    // largest squared-error reduction; bestVar stays string.Empty when no split is possible
    private void FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) {
      Contract.Assert(startIdx < endIdx + 1); // at least 2 elements

      int rows = endIdx - startIdx + 1;
      Contract.Assert(rows >= 2);

      double sumY = 0.0;
      for (int i = startIdx; i <= endIdx; i++) {
        sumY += y[internalIdx[i]];
      }

      double bestImprovement = 0.0;
      double bestThreshold = double.PositiveInfinity;
      bestVar = string.Empty;

      for (int col = 0; col < effectiveVars; col++) {
        // copy the presorted values for this variable into the scratch buffers
        var curVariable = allowedVariables[col];
        var curVariableIdx = varName2Index[curVariable];
        for (int i = startIdx; i <= endIdx; i++) {
          var sortedI = sortedIdx[curVariableIdx][i];
          outSortedIdx[i - startIdx] = sortedI;
          outx[i - startIdx] = x[curVariableIdx][sortedI];
        }

        double curImprovement;
        double curThreshold;
        FindBestThreshold(outx, outSortedIdx, rows, y, sumY, out curThreshold, out curImprovement);

        if (curImprovement > bestImprovement) {
          bestImprovement = curImprovement;
          bestThreshold = curThreshold;
          bestVar = allowedVariables[col];
        }
      }

      UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows);

      threshold = bestThreshold;
    }

    // assumption is that the Average(y) = 0
    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
      if (string.IsNullOrEmpty(bestVar)) return;
      // update variable relevance
      double err = sumY * sumY / rows;
      double errAfterSplit = bestImprovement;

      double delta = (errAfterSplit - err); // reduction in squared error (BUGFIX: restored '-')
      // accumulate per-variable improvement (original wrote the dictionary twice on a miss;
      // equivalent because out-v defaults to 0, but written once here)
      double v;
      if (!sumImprovements.TryGetValue(bestVar, out v)) {
        sumImprovements[bestVar] = delta;
      } else {
        sumImprovements[bestVar] = v + delta;
      }
    }

    // x [0..N-1] contains rows sorted values in the range from [0..rows-1]
    // sortedIdx [0..N-1] contains the idx of the values in x in the original dataset in the range from [0..rows-1]
    // rows specifies the number of valid entries in x and sortedIdx
    // y [0..N-1] contains the target values in original sorting order
    // sumY is y.Sum()
    //
    // the routine returns the best threshold (x[i] + x[i+1]) / 2 for i = [0 .. rows-2] by calculating the reduction in squared error
    // additionally the reduction in squared error is returned in bestImprovement
    // if all elements of x are equal the routine fails to produce a threshold
    private static void FindBestThreshold(double[] x, int[] sortedIdx, int rows, double[] y, double sumY, out double bestThreshold, out double bestImprovement) {
      Contract.Assert(rows >= 2);

      double sl = 0.0;  // sum of y in the left partition
      double sr = sumY; // sum of y in the right partition
      double nl = 0.0;  // number of elements in the left partition
      double nr = rows; // number of elements in the right partition

      bestImprovement = 0.0;
      bestThreshold = double.NegativeInfinity;
      // for all thresholds
      // if we have n rows there are n-1 possible splits
      for (int i = 0; i < rows - 1; i++) { // BUGFIX: restored 'rows - 1' (dropped '-')
        sl += y[sortedIdx[i]];
        sr -= y[sortedIdx[i]]; // BUGFIX: was 'sr = y[...]' (dropped '-=')

        nl++;
        nr--; // BUGFIX: was bare 'nr;' (dropped '--')
        Debug.Assert(nl > 0);
        Debug.Assert(nr > 0);

        if (x[i] < x[i + 1]) { // don't try to split when two elements are equal
          double curQuality = sl * sl / nl + sr * sr / nr;
          // curQuality = nl*nr / (nl+nr) * Sqr(sl / nl - sr / nr) // greedy function approximation page 12 eqn (35)

          if (curQuality > bestImprovement) {
            bestThreshold = (x[i] + x[i + 1]) / 2.0;
            bestImprovement = curQuality;
          }
        }
      }

      // if all elements were the same then no split can be found
    }

    // returns the accumulated per-variable improvements, scaled so the most relevant variable gets 100
    public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
      double scaling = 100 / sumImprovements.Max(t => t.Value);
      return
        sumImprovements
        .Select(t => new KeyValuePair<string, double>(t.Key, t.Value * scaling))
        .OrderByDescending(t => t.Value);
    }
  }
}


361 


362 


363 


364 


365 


366 


367 


368 


369 


370 


371 


372 


373 


374 


375 


376 


377 


378 


379 

