Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/07/15 11:57:37 (9 years ago)
Author:
gkronber
Message:

#2261 implemented node expansion using a priority queue (and changed parameter MaxDepth to MaxSize). Moved unit tests to a separate project.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12623 r12632  
    5454    private readonly int[][] sortedIdx; // random selection from sortedIdxAll (for r < 1.0)
    5555
    56     private int calls = 0;
    57 
    5856    // helper arrays which are allocated to maximal necessary size only once in the ctor
    5957    private readonly int[] internalIdx, which, leftTmp, rightTmp;
     
    6462    private int curTreeNodeIdx; // the index where the next tree node is stored
    6563
    66     private class Partition {
    67       public int ParentNodeIdx { get; set; }
    68       public int Depth { get; set; }
    69       public int StartIdx { get; set; }
    70       public int EndIndex { get; set; }
    71       public bool Left { get; set; }
    72     }
    73     private readonly SortedList<double, Partition> queue;
     64    // This class represents information about potential splits.
     65    // For each node generated the best splitting variable and threshold as well as
     66    // the improvement from the split are stored in a priority queue
     67    private class PartitionSplits {
     68      public int ParentNodeIdx { get; set; } // the idx of the leaf node representing this partition
     69      public int StartIdx { get; set; } // the start idx of the partition
     70      public int EndIndex { get; set; } // the end idx of the partition
     71      public string SplittingVariable { get; set; } // the best splitting variable
     72      public double SplittingThreshold { get; set; } // the best threshold
     73      public double SplittingImprovement { get; set; } // the improvement of the split (for priority queue)
     74    }
     75
     76    // this list hold partitions with the information about the best split (organized as a sorted queue)
     77    private readonly IList<PartitionSplits> queue;
    7478
    7579    // prepare and allocate buffer variables in ctor
     
    9599      outx = new double[rows];
    96100      outSortedIdx = new int[rows];
    97       queue = new SortedList<double, Partition>();
     101      queue = new List<PartitionSplits>(100);
    98102
    99103      x = new double[nCols][];
     
    117121    // r is fraction of rows to use for training
    118122    // m is fraction of variables to use for training
    119     public IRegressionModel CreateRegressionTree(int maxDepth, double r = 0.5, double m = 0.5) {
     123    public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) {
    120124      // subtract mean of y first
    121125      var yAvg = y.Average();
     
    126130      var ones = Enumerable.Repeat(1.0, y.Length);
    127131
    128       var model = CreateRegressionTreeForGradientBoosting(y, maxDepth, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);
     132      var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);
    129133
    130134      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
     
    132136
    133137    // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees
    134     public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
    135       Debug.Assert(maxDepth > 0);
     138    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxSize, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
     139      Debug.Assert(maxSize > 0);
    136140      Debug.Assert(r > 0);
    137141      Debug.Assert(r <= 1.0);
     
    174178      }
    175179
    176       // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)     
    177       int numNodes = (int)Math.Pow(2, maxDepth) - 1;
    178       this.tree = new RegressionTreeModel.TreeNode[numNodes];
     180      this.tree = new RegressionTreeModel.TreeNode[maxSize];
     181      this.queue.Clear();
    179182      this.curTreeNodeIdx = 0;
    180183
     184      // start out with only one leaf node (constant prediction)
     185      // and calculate the best split for this root node and enqueue it into a queue sorted by improvement throught the split
    181186      // start and end idx are inclusive
    182       queue.Add(calls++, new Partition() { ParentNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 });
    183       CreateRegressionTreeForIdx(lineSearch);
    184 
    185       return new RegressionTreeModel(tree);
    186     }
    187 
    188     private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) {
    189       while (queue.Any()) {
    190         var f = queue.First().Value; // actually a stack
    191         queue.RemoveAt(0);
    192 
    193         var depth = f.Depth;
     187      CreateLeafNode(0, effectiveRows - 1, lineSearch);
     188
     189      // process the priority queue to complete the tree
     190      CreateRegressionTreeFromQueue(maxSize, lineSearch);
     191
     192      return new RegressionTreeModel(tree.ToArray());
     193    }
     194
     195
     196    // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
     197    private void CreateRegressionTreeFromQueue(int maxNodes, LineSearchFunc lineSearch) {
     198      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
     199        var f = queue[queue.Count - 1]; // last element has the largest improvement
     200        queue.RemoveAt(queue.Count - 1);
     201
    194202        var startIdx = f.StartIdx;
    195203        var endIdx = f.EndIndex;
     
    199207        Debug.Assert(endIdx < internalIdx.Length);
    200208
    201         double threshold;
    202         string bestVariableName;
    203 
    204         // stop when only one row is left or no split is possible
    205         if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) {
    206           CreateLeafNode(startIdx, endIdx, lineSearch);
    207           if (f.ParentNodeIdx >= 0) if (f.Left) {
    208               tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx;
    209             } else {
    210               tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx;
    211             }
    212           curTreeNodeIdx++;
    213         } else {
    214           int splitIdx;
    215           CreateInternalNode(f.StartIdx, f.EndIndex, bestVariableName, threshold, out splitIdx);
    216 
    217           // connect to parent tree
    218           if (f.ParentNodeIdx >= 0) if (f.Left) {
    219               tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx;
    220             } else {
    221               tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx;
    222             }
    223 
    224           Debug.Assert(splitIdx + 1 <= endIdx);
    225           Debug.Assert(startIdx <= splitIdx);
    226 
    227           queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization)
    228           queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx });
    229           curTreeNodeIdx++;
    230 
    231         }
    232       }
    233     }
    234 
    235 
    236     private void CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
    237       // max depth reached or only one element   
    238       tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
    239       tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
    240     }
    241 
    242     // routine for building the tree for the partition of rows stored in internalIdx between startIdx and endIdx
    243     // the lineSearch function calculates the optimal prediction value for tree leaf nodes
    244     // (in the case of squared errors it is the average of target values for the rows represented by the node)
     209        // transform the leaf node into an internal node
     210        tree[f.ParentNodeIdx].VarName = f.SplittingVariable;
     211        tree[f.ParentNodeIdx].Val = f.SplittingThreshold;
     212
     213        // split partition into left and right
     214        int splitIdx;
     215        SplitPartition(f.StartIdx, f.EndIndex, f.SplittingVariable, f.SplittingThreshold, out splitIdx);
     216        Debug.Assert(splitIdx + 1 <= endIdx);
     217        Debug.Assert(startIdx <= splitIdx);
     218
     219        // create two leaf nodes (and enqueue best splits for both)
     220        tree[f.ParentNodeIdx].LeftIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
     221        tree[f.ParentNodeIdx].RightIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     222      }
     223    }
     224
     225
     226    // returns the index of the newly created tree node
     227    private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
     228      tree[curTreeNodeIdx].VarName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     229      tree[curTreeNodeIdx].Val = lineSearch(internalIdx, startIdx, endIdx);
     230
     231      EnqueuePartitionSplit(curTreeNodeIdx, startIdx, endIdx);
     232      curTreeNodeIdx++;
     233      return curTreeNodeIdx - 1;
     234    }
     235
     236
     237    // calculates the optimal split for the partition [startIdx .. endIdx] (inclusive)
     238    // which is represented by the leaf node with the specified nodeIdx
     239    private void EnqueuePartitionSplit(int nodeIdx, int startIdx, int endIdx) {
     240      double threshold, improvement;
     241      string bestVariableName;
     242      // only enqueue a new split if there are at least 2 rows left and a split is possible
     243      if (startIdx < endIdx &&
     244        FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName, out improvement)) {
     245        var split = new PartitionSplits() {
     246          ParentNodeIdx = nodeIdx,
     247          StartIdx = startIdx,
     248          EndIndex = endIdx,
     249          SplittingThreshold = threshold,
     250          SplittingVariable = bestVariableName
     251        };
     252        InsertSortedQueue(split);
     253      }
     254    }
     255
     256
     257    // routine for splitting a partition of rows stored in internalIdx between startIdx and endIdx into
     258    // a left partition and a right partition using the given splittingVariable and threshold
     259    // the splitIdx is the last index of the left partition
     260    // splitIdx + 1 is the first index of the right partition
    245261    // startIdx and endIdx are inclusive
    246     private void CreateInternalNode(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) {
     262    private void SplitPartition(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) {
    247263      int bestVarIdx = varName2Index[splittingVar];
    248264      // split - two pass
     
    303319      Debug.Assert(startIdx <= j);
    304320
    305       tree[curTreeNodeIdx].varName = splittingVar;
    306       tree[curTreeNodeIdx].val = threshold;
    307321      splitIdx = j;
    308322    }
    309323
    310     private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) {
     324    private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar, out double improvement) {
    311325      Debug.Assert(startIdx < endIdx + 1); // at least 2 elements
    312326
     
    345359      }
    346360      if (bestVar == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    347         threshold = bestThreshold;
     361        // not successfull
     362        threshold = double.PositiveInfinity;
     363        improvement = double.NegativeInfinity;
    348364        return false;
    349365      } else {
    350366        UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows);
     367        improvement = bestImprovement;
    351368        threshold = bestThreshold;
    352369        return true;
    353370      }
    354     }
    355 
    356     private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
    357       if (string.IsNullOrEmpty(bestVar)) return;
    358       // update variable relevance
    359       double baseLine = 1.0 / rows * sumY * sumY; // if best improvement is equal to baseline then the split had no effect
    360 
    361       double delta = (bestImprovement - baseLine);
    362       double v;
    363       if (!sumImprovements.TryGetValue(bestVar, out v)) {
    364         sumImprovements[bestVar] = delta;
    365       }
    366       sumImprovements[bestVar] = v + delta;
    367371    }
    368372
     
    429433
    430434
     435    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
     436      if (string.IsNullOrEmpty(bestVar)) return;
     437      // update variable relevance
     438      double baseLine = 1.0 / rows * sumY * sumY; // if best improvement is equal to baseline then the split had no effect
     439
     440      double delta = (bestImprovement - baseLine);
     441      double v;
     442      if (!sumImprovements.TryGetValue(bestVar, out v)) {
     443        sumImprovements[bestVar] = delta;
     444      }
     445      sumImprovements[bestVar] = v + delta;
     446    }
     447
    431448    public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
    432449      // values are scaled: the most important variable has relevance = 100
     
    437454        .OrderByDescending(t => t.Value);
    438455    }
     456
     457
     458    // insert a new parition split (find insertion point and start at first element of the queue)
     459    // elements are removed from the queue at the last position
     460    // O(n), splits could be organized as a heap to improve runtime (see alglib tsort)
     461    private void InsertSortedQueue(PartitionSplits split) {
     462      // find insertion position
     463      int i = 0;
     464      while (i < queue.Count && queue[i].SplittingImprovement < split.SplittingImprovement) { i++; }
     465
     466      queue.Insert(i, split);
     467    }
    439468  }
    440469}
Note: See TracChangeset for help on using the changeset viewer.