Changeset 12590 for branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
- Timestamp:
- 07/04/15 16:03:36 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12372 r12590 1 using System; 1 #region License Information 2 /* HeuristicLab 3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 * and the BEACON Center for the Study of Evolution in Action. 5 * 6 * This file is part of HeuristicLab. 7 * 8 * HeuristicLab is free software: you can redistribute it and/or modify 9 * it under the terms of the GNU General Public License as published by 10 * the Free Software Foundation, either version 3 of the License, or 11 * (at your option) any later version. 12 * 13 * HeuristicLab is distributed in the hope that it will be useful, 14 * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 * GNU General Public License for more details. 17 * 18 * You should have received a copy of the GNU General Public License 19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>. 20 */ 21 #endregion 22 23 using System; 2 24 using System.Collections.Generic; 3 25 using System.Diagnostics; 4 26 using System.Diagnostics.Contracts; 5 27 using System.Linq; 6 using HeuristicLab.Common;7 28 using HeuristicLab.Core; 8 29 using HeuristicLab.Problems.DataAnalysis; 9 30 10 namespace GradientBoostedTrees { 31 namespace HeuristicLab.Algorithms.DataAnalysis { 32 // This class implements a greedy decision tree learner which selects splits with the maximum reduction in sum of squared errors. 33 // The tree builder also tracks variable relevance metrics based on the splits and improvement after the split. 34 // The implementation is tuned for gradient boosting where multiple trees have to be calculated for the same training data 35 // each time with a different target vector. Vectors of idx to allow iteration of input variables in sorted order are 36 // pre-calculated so that optimal thresholds for splits can be calculated in O(n) for each input variable. 37 // After each split the row idx are partitioned in a left and a right part. 
11 38 public class RegressionTreeBuilder { 12 39 private readonly IRandom random; … … 14 41 15 42 private readonly int nCols; 16 private readonly double[][] x; // all training data (original order from problemData) 17 private double[] y; // training labels (original order from problemData) 43 private readonly double[][] x; // all training data (original order from problemData), x is constant 44 private double[] y; // training labels (original order from problemData), y can be changed 18 45 19 46 private Dictionary<string, double> sumImprovements; // for variable relevance calculation … … 76 103 } 77 104 105 // simple API produces a single regression tree optimizing sum of squared errors 106 // this can be used if only a simple regression tree should be produced 107 // for a set of trees use the method CreateRegressionTreeForGradientBoosting below 108 // 78 109 // r and m work in the same way as for alglib random forest 79 110 // r is fraction of rows to use for training … … 92 123 } 93 124 94 // specific interface that allows to specify the target labels and the training rows which is necessary when this functionality is called by the gradient boosting routine125 // specific interface that allows to specify the target labels and the training rows which is necessary for gradient boosted trees 95 126 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) { 96 127 Contract.Assert(maxDepth > 0); … … 111 142 HeuristicLab.Random.ListExtensions.ShuffleInPlace(allowedVariables, random); 112 143 144 // only select a part of the rows and columns randomly 113 145 effectiveRows = (int)Math.Ceiling(nRows * r); 114 146 effectiveVars = (int)Math.Ceiling(nCols * m); 115 147 148 // the which array is used for partitioning row idxs 116 149 Array.Clear(which, 0, which.Length); 117 150 118 151 // mark selected rows 119 152 for (int row = 0; row < effectiveRows; row++) { 120 
which[idx[row]] = 1; 153 which[idx[row]] = 1; // we use the which vector as a temporary variable here 121 154 internalIdx[row] = idx[row]; 122 155 } … … 126 159 for (int row = 0; row < nRows; row++) { 127 160 if (which[sortedIdxAll[col][row]] > 0) { 128 Trace.Assert(i < effectiveRows);161 Debug.Assert(i < effectiveRows); 129 162 sortedIdx[col][i] = sortedIdxAll[col][row]; 130 163 i++; … … 135 168 // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes) 136 169 int numNodes = (int)Math.Pow(2, maxDepth) - 1; 137 //this.tree = new RegressionTreeModel.TreeNode[numNodes]; 138 this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray(); 170 this.tree = new RegressionTreeModel.TreeNode[numNodes]; 139 171 this.curTreeNodeIdx = 0; 140 172 … … 144 176 } 145 177 178 // recursive routine for building the tree for the row idx stored in internalIdx between startIdx and endIdx 179 // the lineSearch function calculates the optimal prediction value for tree leaf nodes 180 // (in the case of squared errors it is the average of target values for the rows represented by the node) 146 181 // startIdx and endIdx are inclusive 147 182 private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) { … … 214 249 else if (which[internalIdx[j]] > 0) j--; 215 250 else { 216 Trace.Assert(which[internalIdx[i]] > 0);217 Trace.Assert(which[internalIdx[j]] < 0);251 Debug.Assert(which[internalIdx[i]] > 0); 252 Debug.Assert(which[internalIdx[j]] < 0); 218 253 // swap 219 254 int tmp = internalIdx[i]; … … 283 318 284 319 threshold = bestThreshold; 285 286 // Contract.Assert(bestImprovement > 0);287 // Contract.Assert(bestImprovement < double.PositiveInfinity);288 // Contract.Assert(bestVar != string.Empty);289 // Contract.Assert(allowedVariables.Contains(bestVar));290 320 } 291 321 … … 351 381 352 382 public IEnumerable<KeyValuePair<string, double>> 
GetVariableRelevance() { 383 // values are scaled: the most important variable has relevance = 100 353 384 double scaling = 100 / sumImprovements.Max(t => t.Value); 354 385 return
Note: See TracChangeset
for help on using the changeset viewer.