Changeset 12632 for branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
- Timestamp:
- 07/07/15 11:57:37 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12623 r12632 54 54 private readonly int[][] sortedIdx; // random selection from sortedIdxAll (for r < 1.0) 55 55 56 private int calls = 0;57 58 56 // helper arrays which are allocated to maximal necessary size only once in the ctor 59 57 private readonly int[] internalIdx, which, leftTmp, rightTmp; … … 64 62 private int curTreeNodeIdx; // the index where the next tree node is stored 65 63 66 private class Partition { 67 public int ParentNodeIdx { get; set; } 68 public int Depth { get; set; } 69 public int StartIdx { get; set; } 70 public int EndIndex { get; set; } 71 public bool Left { get; set; } 72 } 73 private readonly SortedList<double, Partition> queue; 64 // This class represents information about potential splits. 65 // For each node generated the best splitting variable and threshold as well as 66 // the improvement from the split are stored in a priority queue 67 private class PartitionSplits { 68 public int ParentNodeIdx { get; set; } // the idx of the leaf node representing this partition 69 public int StartIdx { get; set; } // the start idx of the partition 70 public int EndIndex { get; set; } // the end idx of the partition 71 public string SplittingVariable { get; set; } // the best splitting variable 72 public double SplittingThreshold { get; set; } // the best threshold 73 public double SplittingImprovement { get; set; } // the improvement of the split (for priority queue) 74 } 75 76 // this list hold partitions with the information about the best split (organized as a sorted queue) 77 private readonly IList<PartitionSplits> queue; 74 78 75 79 // prepare and allocate buffer variables in ctor … … 95 99 outx = new double[rows]; 96 100 outSortedIdx = new int[rows]; 97 queue = new SortedList<double, Partition>();101 queue = new List<PartitionSplits>(100); 98 102 99 103 x = new double[nCols][]; … … 117 121 // r is fraction of rows to use for training 118 122 // m is fraction of variables to use for training 119 public IRegressionModel CreateRegressionTree(int max Depth, double r = 0.5, double m = 0.5) {123 public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) { 120 124 // subtract mean of y first 121 125 var yAvg = y.Average(); … … 126 130 var ones = Enumerable.Repeat(1.0, y.Length); 127 131 128 var model = CreateRegressionTreeForGradientBoosting(y, max Depth, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);132 var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m); 129 133 130 134 return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 }); … … 132 136 133 137 // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees 134 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int max Depth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {135 Debug.Assert(max Depth> 0);138 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxSize, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) { 139 Debug.Assert(maxSize > 0); 136 140 Debug.Assert(r > 0); 137 141 Debug.Assert(r <= 1.0); … … 174 178 } 175 179 176 // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes) 177 int numNodes = (int)Math.Pow(2, maxDepth) - 1; 178 this.tree = new RegressionTreeModel.TreeNode[numNodes]; 180 this.tree = new RegressionTreeModel.TreeNode[maxSize]; 181 this.queue.Clear(); 179 182 this.curTreeNodeIdx = 0; 180 183 184 // start out with only one leaf node (constant prediction) 185 // and calculate the best split for this root node and enqueue it into a queue sorted by improvement throught the split 181 186 // start and end idx are inclusive 182 queue.Add(calls++, new Partition() { ParentNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 }); 183 CreateRegressionTreeForIdx(lineSearch); 184 185 return new RegressionTreeModel(tree); 186 } 187 188 private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) { 189 while (queue.Any()) { 190 var f = queue.First().Value; // actually a stack 191 queue.RemoveAt(0); 192 193 var depth = f.Depth; 187 CreateLeafNode(0, effectiveRows - 1, lineSearch); 188 189 // process the priority queue to complete the tree 190 CreateRegressionTreeFromQueue(maxSize, lineSearch); 191 192 return new RegressionTreeModel(tree.ToArray()); 193 } 194 195 196 // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached 197 private void CreateRegressionTreeFromQueue(int maxNodes, LineSearchFunc lineSearch) { 198 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop 199 var f = queue[queue.Count - 1]; // last element has the largest improvement 200 queue.RemoveAt(queue.Count - 1); 201 194 202 var startIdx = f.StartIdx; 195 203 var endIdx = f.EndIndex; … … 199 207 Debug.Assert(endIdx < internalIdx.Length); 200 208 201 double threshold; 202 string bestVariableName; 203 204 // stop when only one row is left or no split is possible 205 if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) { 206 CreateLeafNode(startIdx, endIdx, lineSearch); 207 if (f.ParentNodeIdx >= 0) if (f.Left) { 208 tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx; 209 } else { 210 tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx; 211 } 212 curTreeNodeIdx++; 213 } else { 214 int splitIdx; 215 CreateInternalNode(f.StartIdx, f.EndIndex, bestVariableName, threshold, out splitIdx); 216 217 // connect to parent tree 218 if (f.ParentNodeIdx >= 0) if (f.Left) { 219 tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx; 220 } else { 221 tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx; 222 } 223 224 Debug.Assert(splitIdx + 1 <= endIdx); 225 Debug.Assert(startIdx <= splitIdx); 226 227 queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization) 228 queue.Add(calls++, new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx }); 229 curTreeNodeIdx++; 230 231 } 232 } 233 } 234 235 236 private void CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) { 237 // max depth reached or only one element 238 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 239 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 240 } 241 242 // routine for building the tree for the partition of rows stored in internalIdx between startIdx and endIdx 243 // the lineSearch function calculates the optimal prediction value for tree leaf nodes 244 // (in the case of squared errors it is the average of target values for the rows represented by the node) 209 // transform the leaf node into an internal node 210 tree[f.ParentNodeIdx].VarName = f.SplittingVariable; 211 tree[f.ParentNodeIdx].Val = f.SplittingThreshold; 212 213 // split partition into left and right 214 int splitIdx; 215 SplitPartition(f.StartIdx, f.EndIndex, f.SplittingVariable, f.SplittingThreshold, out splitIdx); 216 Debug.Assert(splitIdx + 1 <= endIdx); 217 Debug.Assert(startIdx <= splitIdx); 218 219 // create two leaf nodes (and enqueue best splits for both) 220 tree[f.ParentNodeIdx].LeftIdx = CreateLeafNode(startIdx, splitIdx, lineSearch); 221 tree[f.ParentNodeIdx].RightIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch); 222 } 223 } 224 225 226 // returns the index of the newly created tree node 227 private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) { 228 tree[curTreeNodeIdx].VarName = RegressionTreeModel.TreeNode.NO_VARIABLE; 229 tree[curTreeNodeIdx].Val = lineSearch(internalIdx, startIdx, endIdx); 230 231 EnqueuePartitionSplit(curTreeNodeIdx, startIdx, endIdx); 232 curTreeNodeIdx++; 233 return curTreeNodeIdx - 1; 234 } 235 236 237 // calculates the optimal split for the partition [startIdx .. endIdx] (inclusive) 238 // which is represented by the leaf node with the specified nodeIdx 239 private void EnqueuePartitionSplit(int nodeIdx, int startIdx, int endIdx) { 240 double threshold, improvement; 241 string bestVariableName; 242 // only enqueue a new split if there are at least 2 rows left and a split is possible 243 if (startIdx < endIdx && 244 FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName, out improvement)) { 245 var split = new PartitionSplits() { 246 ParentNodeIdx = nodeIdx, 247 StartIdx = startIdx, 248 EndIndex = endIdx, 249 SplittingThreshold = threshold, 250 SplittingVariable = bestVariableName 251 }; 252 InsertSortedQueue(split); 253 } 254 } 255 256 257 // routine for splitting a partition of rows stored in internalIdx between startIdx and endIdx into 258 // a left partition and a right partition using the given splittingVariable and threshold 259 // the splitIdx is the last index of the left partition 260 // splitIdx + 1 is the first index of the right partition 245 261 // startIdx and endIdx are inclusive 246 private void CreateInternalNode(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) {262 private void SplitPartition(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) { 247 263 int bestVarIdx = varName2Index[splittingVar]; 248 264 // split - two pass … … 303 319 Debug.Assert(startIdx <= j); 304 320 305 tree[curTreeNodeIdx].varName = splittingVar;306 tree[curTreeNodeIdx].val = threshold;307 321 splitIdx = j; 308 322 } 309 323 310 private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar ) {324 private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar, out double improvement) { 311 325 Debug.Assert(startIdx < endIdx + 1); // at least 2 elements 312 326 … … 345 359 } 346 360 if (bestVar == RegressionTreeModel.TreeNode.NO_VARIABLE) { 347 threshold = bestThreshold; 361 // not successfull 362 threshold = double.PositiveInfinity; 363 improvement = double.NegativeInfinity; 348 364 return false; 349 365 } else { 350 366 UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows); 367 improvement = bestImprovement; 351 368 threshold = bestThreshold; 352 369 return true; 353 370 } 354 }355 356 private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {357 if (string.IsNullOrEmpty(bestVar)) return;358 // update variable relevance359 double baseLine = 1.0 / rows * sumY * sumY; // if best improvement is equal to baseline then the split had no effect360 361 double delta = (bestImprovement - baseLine);362 double v;363 if (!sumImprovements.TryGetValue(bestVar, out v)) {364 sumImprovements[bestVar] = delta;365 }366 sumImprovements[bestVar] = v + delta;367 371 } 368 372 … … 429 433 430 434 435 private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) { 436 if (string.IsNullOrEmpty(bestVar)) return; 437 // update variable relevance 438 double baseLine = 1.0 / rows * sumY * sumY; // if best improvement is equal to baseline then the split had no effect 439 440 double delta = (bestImprovement - baseLine); 441 double v; 442 if (!sumImprovements.TryGetValue(bestVar, out v)) { 443 sumImprovements[bestVar] = delta; 444 } 445 sumImprovements[bestVar] = v + delta; 446 } 447 431 448 public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() { 432 449 // values are scaled: the most important variable has relevance = 100 … … 437 454 .OrderByDescending(t => t.Value); 438 455 } 456 457 458 // insert a new parition split (find insertion point and start at first element of the queue) 459 // elements are removed from the queue at the last position 460 // O(n), splits could be organized as a heap to improve runtime (see alglib tsort) 461 private void InsertSortedQueue(PartitionSplits split) { 462 // find insertion position 463 int i = 0; 464 while (i < queue.Count && queue[i].SplittingImprovement < split.SplittingImprovement) { i++; } 465 466 queue.Insert(i, split); 467 } 439 468 } 440 469 }
Note: See TracChangeset
for help on using the changeset viewer.