source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs @ 12620

Last change on this file since 12620 was 12620, checked in by gkronber, 4 years ago

#2261: corrected check if a split is useful, added a unit test class and added an elaborate comment on split quality calculation

File size: 18.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 * and the BEACON Center for the Study of Evolution in Action.
5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21#endregion
22
23using System;
24using System.Collections;
25using System.Collections.Generic;
26using System.Diagnostics;
27using System.Linq;
28using HeuristicLab.Core;
29using HeuristicLab.Problems.DataAnalysis;
30
namespace HeuristicLab.Algorithms.DataAnalysis {
  // This class implements a greedy decision tree learner which selects splits with the maximum reduction in sum of squared errors.
  // The tree builder also tracks variable relevance metrics based on the splits and the improvement achieved by each split.
  // The implementation is tuned for gradient boosting where multiple trees have to be calculated for the same training data,
  // each time with a different target vector. Vectors of row indices that allow iterating the input variables in sorted order
  // are pre-calculated so that optimal thresholds for splits can be calculated in O(n) for each input variable.
  // After each split the row indices are partitioned into a left and a right part.
  //
  // NOTE: all row indices used by this class (idx, internalIdx, sortedIdx, ...) are positions within the training
  // partition (0 .. #trainingRows - 1), i.e. indices into x and y, and NOT row numbers of the underlying dataset.
  public class RegressionTreeBuilder {
    private readonly IRandom random;
    private readonly IRegressionProblemData problemData;

    private readonly int nCols;
    private readonly double[][] x; // all training data (original order from problemData), x is constant
    private double[] y; // training labels (original order from problemData), y can be changed

    private readonly Dictionary<string, double> sumImprovements; // for variable relevance calculation

    private readonly string[] allowedVariables; // all variables in shuffled order
    private readonly Dictionary<string, int> varName2Index; // maps the variable names to column indexes of x
    private int effectiveVars; // number of variables that are used from allowedVariables

    private int effectiveRows; // number of rows that are used from the (shuffled) idx vector
    private readonly int[][] sortedIdxAll; // for each column: all training rows sorted by the value of that column
    private readonly int[][] sortedIdx; // random selection from sortedIdxAll (for r < 1.0)

    // helper arrays which are allocated to maximal necessary size only once in the ctor
    private readonly int[] internalIdx, which, leftTmp, rightTmp;
    private readonly double[] outx;
    private readonly int[] outSortedIdx;

    private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes
    private int curTreeNodeIdx; // the index where the next tree node is stored

    // a range of internalIdx (StartIdx and EndIndex are inclusive) that still has to be turned into a tree node
    private class Partition {
      public int ParentNodeIdx { get; set; }
      public int Depth { get; set; }
      public int StartIdx { get; set; }
      public int EndIndex { get; set; }
      public bool Left { get; set; }
    }
    // partitions that still have to be processed; FIFO order, i.e. the tree nodes are created breadth-first
    private readonly Queue<Partition> queue;

    // prepare and allocate buffer variables in ctor
    public RegressionTreeBuilder(IRegressionProblemData problemData, IRandom random) {
      this.problemData = problemData;
      this.random = random;

      // materialize the training rows once instead of enumerating TrainingIndices repeatedly
      var trainingRows = problemData.TrainingIndices.ToArray();
      var rows = trainingRows.Length;

      allowedVariables = problemData.AllowedInputVariables.ToArray();
      this.nCols = allowedVariables.Length;

      varName2Index = new Dictionary<string, int>(nCols);
      for (int i = 0; i < nCols; i++) varName2Index.Add(allowedVariables[i], i);

      sortedIdxAll = new int[nCols][];
      sortedIdx = new int[nCols][];
      sumImprovements = new Dictionary<string, double>();
      internalIdx = new int[rows];
      which = new int[rows];
      leftTmp = new int[rows];
      rightTmp = new int[rows];
      outx = new double[rows];
      outSortedIdx = new int[rows];
      queue = new Queue<Partition>();

      x = new double[nCols][];
      y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, trainingRows).ToArray();

      // column col of x belongs to allowedVariables[col] (= varName2Index mapping); allowedVariables is only shuffled later
      for (int col = 0; col < nCols; col++) {
        var column = problemData.Dataset.GetDoubleValues(allowedVariables[col], trainingRows).ToArray();
        x[col] = column;
        sortedIdxAll[col] = Enumerable.Range(0, rows).OrderBy(r => column[r]).ToArray();
        sortedIdx[col] = new int[rows];
      }
    }

    // simple API produces a single regression tree optimizing sum of squared errors
    // this can be used if only a simple regression tree should be produced
    // for a set of trees use the method CreateRegressionTreeForGradientBoosting below
    //
    // r and m work in the same way as for alglib random forest
    // r is fraction of rows to use for training
    // m is fraction of variables to use for training
    public IRegressionModel CreateRegressionTree(int maxDepth, double r = 0.5, double m = 0.5) {
      // always start from the original target values: this.y may have been modified or replaced by earlier calls,
      // so centering it again would lose the mean (the constant part of the model)
      var yOrig = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
      var yAvg = yOrig.Average();
      var yCentered = yOrig.Select(v => v - yAvg).ToArray();

      var seLoss = new SquaredErrorLoss();
      var zeros = Enumerable.Repeat(0.0, yCentered.Length);
      var ones = Enumerable.Repeat(1.0, yCentered.Length);

      // the builder expects positions within the training partition (0..n-1), not dataset row numbers
      var idx = Enumerable.Range(0, yCentered.Length).ToArray();
      var model = CreateRegressionTreeForGradientBoosting(yCentered, maxDepth, idx, seLoss.GetLineSearchFunc(yCentered, zeros, ones), r, m);

      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
    }

    // specific interface that allows to specify the target labels and the training rows which is necessary for gradient boosted trees
    // y: target values for all training rows (stored by reference in this.y)
    // idx: positions (0..#trainingRows-1) of the rows that may be used; the array is shuffled in place
    // lineSearch: calculates the prediction value of a leaf node for a range of internalIdx
    // r: fraction of the rows in idx to use, m: fraction of the variables to use
    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
      Debug.Assert(maxDepth > 0);
      Debug.Assert(r > 0);
      Debug.Assert(r <= 1.0);
      Debug.Assert(y.Length == this.y.Length);
      Debug.Assert(m > 0);
      Debug.Assert(m <= 1.0);

      this.y = y; // y is changed in gradient boosting

      // shuffle row idx
      HeuristicLab.Random.ListExtensions.ShuffleInPlace(idx, random);

      int nRows = idx.Length;

      // shuffle variable idx
      HeuristicLab.Random.ListExtensions.ShuffleInPlace(allowedVariables, random);

      // only select a part of the rows and columns randomly
      effectiveRows = (int)Math.Ceiling(nRows * r);
      effectiveVars = (int)Math.Ceiling(nCols * m);

      // the which array is used for partitioning row idxs
      Array.Clear(which, 0, which.Length);

      // mark selected rows
      for (int row = 0; row < effectiveRows; row++) {
        which[idx[row]] = 1; // we use the which vector as a temporary variable here
        internalIdx[row] = idx[row];
      }

      // copy the selected rows into sortedIdx while keeping the sort order of each column;
      // all entries of sortedIdxAll have to be scanned because idx may contain only a subset of the training rows
      for (int col = 0; col < nCols; col++) {
        var allSorted = sortedIdxAll[col];
        var selectedSorted = sortedIdx[col];
        int i = 0;
        for (int k = 0; k < allSorted.Length; k++) {
          if (which[allSorted[k]] > 0) {
            Debug.Assert(i < effectiveRows);
            selectedSorted[i] = allSorted[k];
            i++;
          }
        }
      }

      // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)
      int numNodes = (int)Math.Pow(2, maxDepth) - 1;
      this.tree = new RegressionTreeModel.TreeNode[numNodes];
      this.curTreeNodeIdx = 0;

      // discard partitions that might be left over from an aborted previous call
      queue.Clear();

      // start and end idx are inclusive
      queue.Enqueue(new Partition() { ParentNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 });
      CreateRegressionTreeForIdx(lineSearch);

      return new RegressionTreeModel(tree);
    }

    // processes the queued partitions (breadth-first) and creates one tree node per partition
    private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) {
      while (queue.Count > 0) {
        var f = queue.Dequeue();

        var depth = f.Depth;
        var startIdx = f.StartIdx;
        var endIdx = f.EndIndex;

        Debug.Assert(endIdx - startIdx >= 0);
        Debug.Assert(startIdx >= 0);
        Debug.Assert(endIdx < internalIdx.Length);

        double threshold;
        string bestVariableName;

        // stop when only one row is left or no split is possible
        if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) {
          CreateLeafNode(startIdx, endIdx, lineSearch);
          ConnectToParent(f);
          curTreeNodeIdx++;
        } else {
          int splitIdx;
          CreateInternalNode(startIdx, endIdx, bestVariableName, threshold, out splitIdx);
          ConnectToParent(f);

          Debug.Assert(splitIdx + 1 <= endIdx);
          Debug.Assert(startIdx <= splitIdx);

          // left part is enqueued before the right part
          queue.Enqueue(new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx });
          queue.Enqueue(new Partition() { ParentNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx });
          curTreeNodeIdx++;
        }
      }
    }

    // registers the node at curTreeNodeIdx (created for partition f) as left or right child of its parent (no-op for the root)
    private void ConnectToParent(Partition f) {
      if (f.ParentNodeIdx < 0) return;
      if (f.Left) {
        tree[f.ParentNodeIdx].leftIdx = curTreeNodeIdx;
      } else {
        tree[f.ParentNodeIdx].rightIdx = curTreeNodeIdx;
      }
    }

    // creates a leaf at curTreeNodeIdx; the prediction value is calculated by lineSearch for the rows
    // internalIdx[startIdx..endIdx] (in the case of squared errors the average of the target values)
    private void CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
      // max depth reached, only one element, or no useful split
      tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
      tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
    }

    // creates an internal node at curTreeNodeIdx that splits the rows internalIdx[startIdx..endIdx]
    // into rows with splittingVar <= threshold (left) and > threshold (right).
    // internalIdx and sortedIdx are partitioned in place so that the left rows occupy [startIdx..splitIdx]
    // and the right rows [splitIdx+1..endIdx]; startIdx and endIdx are inclusive
    private void CreateInternalNode(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) {
      int bestVarIdx = varName2Index[splittingVar];
      // split - two pass

      // store which index goes where
      for (int k = startIdx; k <= endIdx; k++) {
        if (x[bestVarIdx][internalIdx[k]] <= threshold)
          which[internalIdx[k]] = -1; // left partition
        else
          which[internalIdx[k]] = 1; // right partition
      }

      // partition sortedIdx for each variable (stable, so each part stays sorted)
      int i;
      int j;
      for (int col = 0; col < nCols; col++) {
        i = 0;
        j = 0;
        int k;
        for (k = startIdx; k <= endIdx; k++) {
          Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1);

          if (which[sortedIdx[col][k]] < 0) {
            leftTmp[i++] = sortedIdx[col][k];
          } else {
            rightTmp[j++] = sortedIdx[col][k];
          }
        }
        Debug.Assert(i > 0); // at least one element in the left partition
        Debug.Assert(j > 0); // at least one element in the right partition
        Debug.Assert(i + j == endIdx - startIdx + 1);
        k = startIdx;
        for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l];
        for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l];
      }

      // partition row indices (order within each part is not preserved)
      i = startIdx;
      j = endIdx;
      while (i <= j) {
        Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1);
        Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1);
        if (which[internalIdx[i]] < 0) i++;
        else if (which[internalIdx[j]] > 0) j--;
        else {
          Debug.Assert(which[internalIdx[i]] > 0);
          Debug.Assert(which[internalIdx[j]] < 0);
          // swap
          int tmp = internalIdx[i];
          internalIdx[i] = internalIdx[j];
          internalIdx[j] = tmp;
          i++;
          j--;
        }
      }
      Debug.Assert(j + 1 == i);
      Debug.Assert(i <= endIdx);
      Debug.Assert(startIdx <= j);

      tree[curTreeNodeIdx].varName = splittingVar;
      tree[curTreeNodeIdx].val = threshold;
      splitIdx = j;
    }

    // searches the first effectiveVars variables of allowedVariables for the split of internalIdx[startIdx..endIdx]
    // with the largest reduction in squared error. Returns false (and bestVar = NO_VARIABLE) if no split improves
    // over not splitting; otherwise updates the variable relevance and returns true.
    private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) {
      Debug.Assert(startIdx < endIdx + 1); // at least 2 elements

      int rows = endIdx - startIdx + 1;
      Debug.Assert(rows >= 2);

      double sumY = 0.0;
      for (int i = startIdx; i <= endIdx; i++) {
        sumY += y[internalIdx[i]];
      }

      // quality of "no split" (see FindBestThreshold); a split must exceed this value
      double bestImprovement = 1.0 / rows * sumY * sumY;
      double bestThreshold = double.PositiveInfinity;
      bestVar = RegressionTreeModel.TreeNode.NO_VARIABLE;

      for (int col = 0; col < effectiveVars; col++) {
        // collect the (already sorted) values of the variable to prepare for threshold selection
        var curVariable = allowedVariables[col];
        var curVariableIdx = varName2Index[curVariable];
        for (int i = startIdx; i <= endIdx; i++) {
          var sortedI = sortedIdx[curVariableIdx][i];
          outSortedIdx[i - startIdx] = sortedI;
          outx[i - startIdx] = x[curVariableIdx][sortedI];
        }

        double curImprovement;
        double curThreshold;
        FindBestThreshold(outx, outSortedIdx, rows, y, sumY, out curThreshold, out curImprovement);

        if (curImprovement > bestImprovement) {
          bestImprovement = curImprovement;
          bestThreshold = curThreshold;
          bestVar = curVariable;
        }
      }
      threshold = bestThreshold;
      if (bestVar == RegressionTreeModel.TreeNode.NO_VARIABLE) {
        return false;
      }
      UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows);
      return true;
    }

    // adds the reduction in sum of squared errors achieved by a split on bestVar to the relevance of bestVar.
    //   SSE before the split = sum(y²) - sumY² / n
    //   SSE after the split  = sum(y²) - (sum(y_l)² / nl + sum(y_r)² / nr) = sum(y²) - bestImprovement
    // so the reduction is bestImprovement - sumY² / n (this does not require the mean of y to be zero)
    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
      if (string.IsNullOrEmpty(bestVar)) return;
      double delta = bestImprovement - sumY * sumY / rows; // absolute reduction in squared error
      double v;
      sumImprovements.TryGetValue(bestVar, out v); // v = 0 if the variable has not been used for a split yet
      sumImprovements[bestVar] = v + delta;
    }

    // x [0..N-1] contains rows sorted values in the range from [0..rows-1]
    // sortedIdx [0..N-1] contains the idx of the values in x in the original dataset in the range from [0..rows-1]
    // rows specifies the number of valid entries in x and sortedIdx
    // y [0..N-1] contains the target values in original sorting order
    // sumY is the sum of y over the rows in sortedIdx[0..rows-1]
    //
    // the routine returns the best threshold (usually (x[i] + x[i+1]) / 2) for i = [0 .. rows-2] by calculating the reduction in squared error
    // additionally the split quality is returned in bestImprovement
    // if no split improves over not splitting (e.g. all elements of x are equal) bestThreshold is -infinity
    // and bestImprovement is the quality of not splitting (1/rows * sumY²)
    private static void FindBestThreshold(double[] x, int[] sortedIdx, int rows, double[] y, double sumY, out double bestThreshold, out double bestImprovement) {
      Debug.Assert(rows >= 2);

      double sl = 0.0;
      double sr = sumY;
      double nl = 0.0;
      double nr = rows;

      bestImprovement = 1.0 / rows * sumY * sumY;
      bestThreshold = double.NegativeInfinity;
      // for all thresholds
      // if we have n rows there are n-1 possible splits
      for (int i = 0; i < rows - 1; i++) {
        sl += y[sortedIdx[i]];
        sr -= y[sortedIdx[i]];

        nl++;
        nr--;
        Debug.Assert(nl > 0);
        Debug.Assert(nr > 0);

        if (x[i] < x[i + 1]) { // don't try to split when two elements are equal

          // goal is to find the split leading to minimal total variance of left and right parts
          // without partitioning the variance is var(y) = E(y²) - E(y)²
          //    = 1/n * sum(y²) - (1/n * sum(y))²
          //      -------------
          // if we split into right and left part the overall variance is the weighted combination nl/n * var(y_l) + nr/n * var(y_r)
          //    = nl/n * (1/nl * sum(y_l²) - (1/nl * sum(y_l))²) + nr/n * (1/nr * sum(y_r²) - (1/nr * sum(y_r))²)
          //    = 1/n * sum(y_l²) - 1/nl * 1/n * sum(y_l)² + 1/n * sum(y_r²) - 1/nr * 1/n * sum(y_r)²
          //    = 1/n * (sum(y_l²) + sum(y_r²)) - 1/n * (sum(y_l)² / nl + sum(y_r)² / nr)
          //    = 1/n * sum(y²) - 1/n * (sum(y_l)² / nl + sum(y_r)² / nr)
          //      -------------
          //       not changed by split (and the same for total variance without partitioning)
          //
          //   therefore we need to find the maximum value (sum(y_l)² / nl + sum(y_r)² / nr) (ignoring the factor 1/n)
          //   and this value must be larger than 1/n * sum(y)² to be an improvement over no split

          double curQuality = sl * sl / nl + sr * sr / nr;

          if (curQuality > bestImprovement) {
            bestThreshold = SplitThreshold(x[i], x[i + 1]);
            bestImprovement = curQuality;
          }
        }
      }

      // if all elements were the same then no split can be found
    }

    // returns a threshold t with lo <= t < hi (requires lo < hi) so that the split "value <= t" always
    // puts lo into the left and hi into the right partition.
    // (lo + hi) / 2 is used whenever possible; for neighbouring doubles it may round up to hi, and for very
    // large values it may overflow to +/-infinity, in which case lo itself is used as threshold.
    private static double SplitThreshold(double lo, double hi) {
      double t = (lo + hi) / 2.0;
      return (t >= lo && t < hi) ? t : lo;
    }


    // returns the accumulated reduction in squared error per variable over all trees built so far,
    // scaled so that the most important variable has relevance = 100, ordered by decreasing relevance.
    // Returns an empty sequence if no split has been made yet.
    public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() {
      if (sumImprovements.Count == 0) return Enumerable.Empty<KeyValuePair<string, double>>();
      double scaling = 100 / sumImprovements.Max(t => t.Value);
      return
        sumImprovements
        .Select(t => new KeyValuePair<string, double>(t.Key, t.Value * scaling))
        .OrderByDescending(t => t.Value);
    }
  }
}
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
Note: See TracBrowser for help on using the repository browser.