Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/06/15 18:33:24 (9 years ago)
Author:
gkronber
Message:

#2261: replace recursion by a stack to prepare for unbalanced tree expansion

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12597 r12619  
    2424using System.Collections.Generic;
    2525using System.Diagnostics;
    26 using System.Diagnostics.Contracts;
    2726using System.Linq;
    2827using HeuristicLab.Core;
     
    6463    private int curTreeNodeIdx; // the index where the next tree node is stored
    6564
    66     private readonly IList<RegressionTreeModel.TreeNode> nodeQueue;
     65    private class Partition {
     66      public int TreeNodeIdx { get; set; }
     67      public int Depth { get; set; }
     68      public int StartIdx { get; set; }
     69      public int EndIndex { get; set; }
     70      public bool Left { get; set; }
     71    }
     72    private readonly IList<Partition> queue;
    6773
    6874    // prepare and allocate buffer variables in ctor
     
    8894      outx = new double[rows];
    8995      outSortedIdx = new int[rows];
    90       nodeQueue = new List<RegressionTreeModel.TreeNode>();
     96      queue = new List<Partition>();
    9197
    9298      x = new double[nCols][];
     
    120126
    121127      var model = CreateRegressionTreeForGradientBoosting(y, maxDepth, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);
     128
    122129      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
    123130    }
     
    125132    // specific interface that allows specifying the target labels and the training rows, which is necessary for gradient boosted trees
    126133    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
    127       Contract.Assert(maxDepth > 0);
    128       Contract.Assert(r > 0);
    129       Contract.Assert(r <= 1.0);
    130       Contract.Assert(y.Count() == this.y.Length);
    131       Contract.Assert(m > 0);
    132       Contract.Assert(m <= 1.0);
     134      Debug.Assert(maxDepth > 0);
     135      Debug.Assert(r > 0);
     136      Debug.Assert(r <= 1.0);
     137      Debug.Assert(y.Count() == this.y.Length);
     138      Debug.Assert(m > 0);
     139      Debug.Assert(m <= 1.0);
    133140
    134141      this.y = y; // y is changed in gradient boosting
     
    172179
    173180      // start and end idx are inclusive
    174       CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
     181      queue.Add(new Partition() { TreeNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 });
     182      CreateRegressionTreeForIdx(lineSearch);
     183
    175184      return new RegressionTreeModel(tree);
    176185    }
    177186
    178     // recursive routine for building the tree for the row idx stored in internalIdx between startIdx and endIdx
     187    private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) {
     188      while (queue.Any()) {
     189        var f = queue[0]; // actually a stack
     190        queue.RemoveAt(0);
     191
     192        var depth = f.Depth;
     193        var startIdx = f.StartIdx;
     194        var endIdx = f.EndIndex;
     195
     196        Debug.Assert(endIdx - startIdx >= 0);
     197        Debug.Assert(startIdx >= 0);
     198        Debug.Assert(endIdx < internalIdx.Length);
     199
     200        double threshold;
     201        string bestVariableName;
     202
     203        // stop when only one row is left or no split is possible
     204        if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) {
     205          CreateLeafNode(startIdx, endIdx, lineSearch);
     206          if (f.TreeNodeIdx >= 0) if (f.Left) {
     207              tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx;
     208            } else {
     209              tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx;
     210            }
     211          curTreeNodeIdx++;
     212        } else {
     213          int splitIdx;
     214          CreateInternalNode(f.StartIdx, f.EndIndex, bestVariableName, threshold, out splitIdx);
     215
     216          // connect to parent tree
     217          if (f.TreeNodeIdx >= 0) if (f.Left) {
     218              tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx;
     219            } else {
     220              tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx;
     221            }
     222
     223          Debug.Assert(splitIdx + 1 <= endIdx);
     224          Debug.Assert(startIdx <= splitIdx);
     225
     226          queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx });
     227          queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization)
     228          curTreeNodeIdx++;
     229
     230        }
     231      }
     232    }
     233
     234
     235    private void CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
     236      // max depth reached or only one element   
     237      tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     238      tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
     239    }
     240
     241    // routine for building the tree for the row idx stored in internalIdx between startIdx and endIdx
    179242    // the lineSearch function calculates the optimal prediction value for tree leaf nodes
    180243    // (in the case of squared errors it is the average of target values for the rows represented by the node)
    181244    // startIdx and endIdx are inclusive
    182     private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
    183       Contract.Assert(endIdx - startIdx >= 0);
    184       Contract.Assert(startIdx >= 0);
    185       Contract.Assert(endIdx < internalIdx.Length);
    186 
    187       // TODO: stop when y is constant
    188       // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth
    189       if (maxDepth <= 1 || endIdx - startIdx == 0) {
    190         // max depth reached or only one element   
    191         tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
    192         tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
    193         curTreeNodeIdx++;
    194       } else {
    195         int i, j;
    196         double threshold;
    197         string bestVariableName;
    198         FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName);
    199 
    200         // if bestVariableName is NO_VARIABLE then no split was possible anymore
    201         if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    202           // max depth reached or only one element   
    203           tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
    204           tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
    205           curTreeNodeIdx++;
    206         } else {
    207 
    208 
    209           int bestVarIdx = varName2Index[bestVariableName];
    210           // split - two pass
    211 
    212           // store which index goes where
    213           for (int k = startIdx; k <= endIdx; k++) {
    214             if (x[bestVarIdx][internalIdx[k]] <= threshold)
    215               which[internalIdx[k]] = -1; // left partition
    216             else
    217               which[internalIdx[k]] = 1; // right partition
     245    private void CreateInternalNode(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) {
     246      int bestVarIdx = varName2Index[splittingVar];
     247      // split - two pass
     248
     249      // store which index goes where
     250      for (int k = startIdx; k <= endIdx; k++) {
     251        if (x[bestVarIdx][internalIdx[k]] <= threshold)
     252          which[internalIdx[k]] = -1; // left partition
     253        else
     254          which[internalIdx[k]] = 1; // right partition
     255      }
     256
     257      // partition sortedIdx for each variable
     258      int i;
     259      int j;
     260      for (int col = 0; col < nCols; col++) {
     261        i = 0;
     262        j = 0;
     263        int k;
     264        for (k = startIdx; k <= endIdx; k++) {
     265          Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1);
     266
     267          if (which[sortedIdx[col][k]] < 0) {
     268            leftTmp[i++] = sortedIdx[col][k];
     269          } else {
     270            rightTmp[j++] = sortedIdx[col][k];
    218271          }
    219 
    220           // partition sortedIdx for each variable
    221           for (int col = 0; col < nCols; col++) {
    222             i = 0;
    223             j = 0;
    224             int k;
    225             for (k = startIdx; k <= endIdx; k++) {
    226               Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1);
    227 
    228               if (which[sortedIdx[col][k]] < 0) {
    229                 leftTmp[i++] = sortedIdx[col][k];
    230               } else {
    231                 rightTmp[j++] = sortedIdx[col][k];
    232               }
    233             }
    234             Debug.Assert(i > 0); // at least on element in the left partition
    235             Debug.Assert(j > 0); // at least one element in the right partition
    236             Debug.Assert(i + j == endIdx - startIdx + 1);
    237             k = startIdx;
    238             for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l];
    239             for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l];
    240           }
    241 
    242           // partition row indices
    243           i = startIdx;
    244           j = endIdx;
    245           while (i <= j) {
    246             Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1);
    247             Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1);
    248             if (which[internalIdx[i]] < 0) i++;
    249             else if (which[internalIdx[j]] > 0) j--;
    250             else {
    251               Debug.Assert(which[internalIdx[i]] > 0);
    252               Debug.Assert(which[internalIdx[j]] < 0);
    253               // swap
    254               int tmp = internalIdx[i];
    255               internalIdx[i] = internalIdx[j];
    256               internalIdx[j] = tmp;
    257               i++;
    258               j--;
    259             }
    260           }
    261           Debug.Assert(j < i);
    262           Debug.Assert(i >= startIdx);
    263           Debug.Assert(j <= endIdx);
    264 
    265           var parentIdx = curTreeNodeIdx;
    266           tree[parentIdx].varName = bestVariableName;
    267           tree[parentIdx].val = threshold;
    268           curTreeNodeIdx++;
    269 
    270           // create left subtree
    271           tree[parentIdx].leftIdx = curTreeNodeIdx;
    272           CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch);
    273 
    274           // create right subtree
    275           tree[parentIdx].rightIdx = curTreeNodeIdx;
    276           CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch);
    277         }
    278       }
    279     }
    280 
    281     private void FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) {
    282       Contract.Assert(startIdx < endIdx + 1); // at least 2 elements
     272        }
     273        Debug.Assert(i > 0); // at least one element in the left partition
     274        Debug.Assert(j > 0); // at least one element in the right partition
     275        Debug.Assert(i + j == endIdx - startIdx + 1);
     276        k = startIdx;
     277        for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l];
     278        for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l];
     279      }
     280
     281      // partition row indices
     282      i = startIdx;
     283      j = endIdx;
     284      while (i <= j) {
     285        Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1);
     286        Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1);
     287        if (which[internalIdx[i]] < 0) i++;
     288        else if (which[internalIdx[j]] > 0) j--;
     289        else {
     290          Debug.Assert(which[internalIdx[i]] > 0);
     291          Debug.Assert(which[internalIdx[j]] < 0);
     292          // swap
     293          int tmp = internalIdx[i];
     294          internalIdx[i] = internalIdx[j];
     295          internalIdx[j] = tmp;
     296          i++;
     297          j--;
     298        }
     299      }
     300      Debug.Assert(j + 1 == i);
     301      Debug.Assert(i <= endIdx);
     302      Debug.Assert(startIdx <= j);
     303
     304      tree[curTreeNodeIdx].varName = splittingVar;
     305      tree[curTreeNodeIdx].val = threshold;
     306      splitIdx = j;
     307    }
     308
     309    private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) {
     310      Debug.Assert(startIdx < endIdx + 1); // at least 2 elements
    283311
    284312      int rows = endIdx - startIdx + 1;
    285       Contract.Assert(rows >= 2);
     313      Debug.Assert(rows >= 2);
    286314
    287315      double sumY = 0.0;
     
    292320      double bestImprovement = 0.0;
    293321      double bestThreshold = double.PositiveInfinity;
    294       bestVar = string.Empty;
     322      bestVar = RegressionTreeModel.TreeNode.NO_VARIABLE;
    295323
    296324      for (int col = 0; col < effectiveVars; col++) {
     
    314342        }
    315343      }
    316 
    317       UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows);
    318 
    319       threshold = bestThreshold;
    320     }
    321 
    322     // assumption is that the Average(y) = 0
     344      if (bestVar == RegressionTreeModel.TreeNode.NO_VARIABLE) {
     345        threshold = bestThreshold;
     346        return false;
     347      } else {
     348        UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows);
     349        threshold = bestThreshold;
     350        return true;
     351      }
     352    }
     353
     354    // TODO: assumption is that the Average(y) = 0
    323355    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
    324356      if (string.IsNullOrEmpty(bestVar)) return;
     
    345377    // if all elements of x are equal the routine fails to produce a threshold
    346378    private static void FindBestThreshold(double[] x, int[] sortedIdx, int rows, double[] y, double sumY, out double bestThreshold, out double bestImprovement) {
    347       Contract.Assert(rows >= 2);
     379      Debug.Assert(rows >= 2);
    348380
    349381      double sl = 0.0;
Note: See TracChangeset for help on using the changeset viewer.