Changeset 12619
- Timestamp:
- 07/06/15 18:33:24 (9 years ago)
- Location:
- branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12611 r12619 242 242 243 243 var classificationSolution = new DiscriminantFunctionClassificationSolution(model, classificationProblemData); 244 Results.Add(new Result("Solution (classification)", classificationSolution));244 Results.Add(new Result("Solution", classificationSolution)); 245 245 } else { 246 246 // otherwise we produce a regression solution -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12597 r12619 24 24 using System.Collections.Generic; 25 25 using System.Diagnostics; 26 using System.Diagnostics.Contracts;27 26 using System.Linq; 28 27 using HeuristicLab.Core; … … 64 63 private int curTreeNodeIdx; // the index where the next tree node is stored 65 64 66 private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; 65 private class Partition { 66 public int TreeNodeIdx { get; set; } 67 public int Depth { get; set; } 68 public int StartIdx { get; set; } 69 public int EndIndex { get; set; } 70 public bool Left { get; set; } 71 } 72 private readonly IList<Partition> queue; 67 73 68 74 // prepare and allocate buffer variables in ctor … … 88 94 outx = new double[rows]; 89 95 outSortedIdx = new int[rows]; 90 nodeQueue = new List<RegressionTreeModel.TreeNode>();96 queue = new List<Partition>(); 91 97 92 98 x = new double[nCols][]; … … 120 126 121 127 var model = CreateRegressionTreeForGradientBoosting(y, maxDepth, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m); 128 122 129 return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 }); 123 130 } … … 125 132 // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees 126 133 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxDepth, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) { 127 Contract.Assert(maxDepth > 0);128 Contract.Assert(r > 0);129 Contract.Assert(r <= 1.0);130 Contract.Assert(y.Count() == this.y.Length);131 Contract.Assert(m > 0);132 Contract.Assert(m <= 1.0);134 Debug.Assert(maxDepth > 0); 135 Debug.Assert(r > 0); 136 Debug.Assert(r <= 1.0); 137 Debug.Assert(y.Count() == this.y.Length); 138 Debug.Assert(m > 0); 139 Debug.Assert(m <= 1.0); 133 140 134 141 this.y = y; // y is changed in gradient boosting … … 172 179 173 180 // start and end idx are inclusive 174 CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch); 181 queue.Add(new Partition() { TreeNodeIdx = -1, Depth = maxDepth, StartIdx = 0, EndIndex = effectiveRows - 1 }); 182 CreateRegressionTreeForIdx(lineSearch); 183 175 184 return new RegressionTreeModel(tree); 176 185 } 177 186 178 // recursive routine for building the tree for the row idx stored in internalIdx between startIdx and endIdx 187 private void CreateRegressionTreeForIdx(LineSearchFunc lineSearch) { 188 while (queue.Any()) { 189 var f = queue[0]; // actually a stack 190 queue.RemoveAt(0); 191 192 var depth = f.Depth; 193 var startIdx = f.StartIdx; 194 var endIdx = f.EndIndex; 195 196 Debug.Assert(endIdx - startIdx >= 0); 197 Debug.Assert(startIdx >= 0); 198 Debug.Assert(endIdx < internalIdx.Length); 199 200 double threshold; 201 string bestVariableName; 202 203 // stop when only one row is left or no split is possible 204 if (depth <= 1 || endIdx - startIdx == 0 || !FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName)) { 205 CreateLeafNode(startIdx, endIdx, lineSearch); 206 if (f.TreeNodeIdx >= 0) if (f.Left) { 207 tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx; 208 } else { 209 tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx; 210 } 211 curTreeNodeIdx++; 212 } else { 213 int splitIdx; 214 CreateInternalNode(f.StartIdx, f.EndIndex, bestVariableName, threshold, out splitIdx); 215 216 // connect to parent tree 217 if (f.TreeNodeIdx >= 0) if (f.Left) { 218 tree[f.TreeNodeIdx].leftIdx = curTreeNodeIdx; 219 } else { 220 tree[f.TreeNodeIdx].rightIdx = curTreeNodeIdx; 221 } 222 223 Debug.Assert(splitIdx + 1 <= endIdx); 224 Debug.Assert(startIdx <= splitIdx); 225 226 queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = false, Depth = depth - 1, StartIdx = splitIdx + 1, EndIndex = endIdx }); 227 queue.Insert(0, new Partition() { TreeNodeIdx = curTreeNodeIdx, Left = true, Depth = depth - 1, StartIdx = startIdx, EndIndex = splitIdx }); // left part before right part (stack organization) 228 curTreeNodeIdx++; 229 230 } 231 } 232 } 233 234 235 private void CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) { 236 // max depth reached or only one element 237 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 238 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 239 } 240 241 // routine for building the tree for the row idx stored in internalIdx between startIdx and endIdx 179 242 // the lineSearch function calculates the optimal prediction value for tree leaf nodes 180 243 // (in the case of squared errors it is the average of target values for the rows represented by the node) 181 244 // startIdx and endIdx are inclusive 182 private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) { 183 Contract.Assert(endIdx - startIdx >= 0); 184 Contract.Assert(startIdx >= 0); 185 Contract.Assert(endIdx < internalIdx.Length); 186 187 // TODO: stop when y is constant 188 // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth 189 if (maxDepth <= 1 || endIdx - startIdx == 0) { 190 // max depth reached or only one element 191 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 192 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 193 curTreeNodeIdx++; 194 } else { 195 int i, j; 196 double threshold; 197 string bestVariableName; 198 FindBestVariableAndThreshold(startIdx, endIdx, out threshold, out bestVariableName); 199 200 // if bestVariableName is NO_VARIABLE then no split was possible anymore 201 if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) { 202 // max depth reached or only one element 203 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 204 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 205 curTreeNodeIdx++; 206 } else { 207 208 209 int bestVarIdx = varName2Index[bestVariableName]; 210 // split - two pass 211 212 // store which index goes where 213 for (int k = startIdx; k <= endIdx; k++) { 214 if (x[bestVarIdx][internalIdx[k]] <= threshold) 215 which[internalIdx[k]] = -1; // left partition 216 else 217 which[internalIdx[k]] = 1; // right partition 245 private void CreateInternalNode(int startIdx, int endIdx, string splittingVar, double threshold, out int splitIdx) { 246 int bestVarIdx = varName2Index[splittingVar]; 247 // split - two pass 248 249 // store which index goes where 250 for (int k = startIdx; k <= endIdx; k++) { 251 if (x[bestVarIdx][internalIdx[k]] <= threshold) 252 which[internalIdx[k]] = -1; // left partition 253 else 254 which[internalIdx[k]] = 1; // right partition 255 } 256 257 // partition sortedIdx for each variable 258 int i; 259 int j; 260 for (int col = 0; col < nCols; col++) { 261 i = 0; 262 j = 0; 263 int k; 264 for (k = startIdx; k <= endIdx; k++) { 265 Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1); 266 267 if (which[sortedIdx[col][k]] < 0) { 268 leftTmp[i++] = sortedIdx[col][k]; 269 } else { 270 rightTmp[j++] = sortedIdx[col][k]; 218 271 } 219 220 // partition sortedIdx for each variable 221 for (int col = 0; col < nCols; col++) { 222 i = 0; 223 j = 0; 224 int k; 225 for (k = startIdx; k <= endIdx; k++) { 226 Debug.Assert(Math.Abs(which[sortedIdx[col][k]]) == 1); 227 228 if (which[sortedIdx[col][k]] < 0) { 229 leftTmp[i++] = sortedIdx[col][k]; 230 } else { 231 rightTmp[j++] = sortedIdx[col][k]; 232 } 233 } 234 Debug.Assert(i > 0); // at least on element in the left partition 235 Debug.Assert(j > 0); // at least one element in the right partition 236 Debug.Assert(i + j == endIdx - startIdx + 1); 237 k = startIdx; 238 for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l]; 239 for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l]; 240 } 241 242 // partition row indices 243 i = startIdx; 244 j = endIdx; 245 while (i <= j) { 246 Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1); 247 Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1); 248 if (which[internalIdx[i]] < 0) i++; 249 else if (which[internalIdx[j]] > 0) j--; 250 else { 251 Debug.Assert(which[internalIdx[i]] > 0); 252 Debug.Assert(which[internalIdx[j]] < 0); 253 // swap 254 int tmp = internalIdx[i]; 255 internalIdx[i] = internalIdx[j]; 256 internalIdx[j] = tmp; 257 i++; 258 j--; 259 } 260 } 261 Debug.Assert(j < i); 262 Debug.Assert(i >= startIdx); 263 Debug.Assert(j <= endIdx); 264 265 var parentIdx = curTreeNodeIdx; 266 tree[parentIdx].varName = bestVariableName; 267 tree[parentIdx].val = threshold; 268 curTreeNodeIdx++; 269 270 // create left subtree 271 tree[parentIdx].leftIdx = curTreeNodeIdx; 272 CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch); 273 274 // create right subtree 275 tree[parentIdx].rightIdx = curTreeNodeIdx; 276 CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch); 277 } 278 } 279 } 280 281 private void FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) { 282 Contract.Assert(startIdx < endIdx + 1); // at least 2 elements 272 } 273 Debug.Assert(i > 0); // at least on element in the left partition 274 Debug.Assert(j > 0); // at least one element in the right partition 275 Debug.Assert(i + j == endIdx - startIdx + 1); 276 k = startIdx; 277 for (int l = 0; l < i; l++) sortedIdx[col][k++] = leftTmp[l]; 278 for (int l = 0; l < j; l++) sortedIdx[col][k++] = rightTmp[l]; 279 } 280 281 // partition row indices 282 i = startIdx; 283 j = endIdx; 284 while (i <= j) { 285 Debug.Assert(Math.Abs(which[internalIdx[i]]) == 1); 286 Debug.Assert(Math.Abs(which[internalIdx[j]]) == 1); 287 if (which[internalIdx[i]] < 0) i++; 288 else if (which[internalIdx[j]] > 0) j--; 289 else { 290 Debug.Assert(which[internalIdx[i]] > 0); 291 Debug.Assert(which[internalIdx[j]] < 0); 292 // swap 293 int tmp = internalIdx[i]; 294 internalIdx[i] = internalIdx[j]; 295 internalIdx[j] = tmp; 296 i++; 297 j--; 298 } 299 } 300 Debug.Assert(j + 1 == i); 301 Debug.Assert(i <= endIdx); 302 Debug.Assert(startIdx <= j); 303 304 tree[curTreeNodeIdx].varName = splittingVar; 305 tree[curTreeNodeIdx].val = threshold; 306 splitIdx = j; 307 } 308 309 private bool FindBestVariableAndThreshold(int startIdx, int endIdx, out double threshold, out string bestVar) { 310 Debug.Assert(startIdx < endIdx + 1); // at least 2 elements 283 311 284 312 int rows = endIdx - startIdx + 1; 285 Contract.Assert(rows >= 2);313 Debug.Assert(rows >= 2); 286 314 287 315 double sumY = 0.0; … … 292 320 double bestImprovement = 0.0; 293 321 double bestThreshold = double.PositiveInfinity; 294 bestVar = string.Empty;322 bestVar = RegressionTreeModel.TreeNode.NO_VARIABLE; 295 323 296 324 for (int col = 0; col < effectiveVars; col++) { … … 314 342 } 315 343 } 316 317 UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows); 318 319 threshold = bestThreshold; 320 } 321 322 // assumption is that the Average(y) = 0 344 if (bestVar == RegressionTreeModel.TreeNode.NO_VARIABLE) { 345 threshold = bestThreshold; 346 return false; 347 } else { 348 UpdateVariableRelevance(bestVar, sumY, bestImprovement, rows); 349 threshold = bestThreshold; 350 return true; 351 } 352 } 353 354 // TODO: assumption is that the Average(y) = 0 323 355 private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) { 324 356 if (string.IsNullOrEmpty(bestVar)) return; … … 345 377 // if all elements of x are equal the routing fails to produce a threshold 346 378 private static void FindBestThreshold(double[] x, int[] sortedIdx, int rows, double[] y, double sumY, out double bestThreshold, out double bestImprovement) { 347 Contract.Assert(rows >= 2);379 Debug.Assert(rows >= 2); 348 380 349 381 double sl = 0.0;
Note: See TracChangeset
for help on using the changeset viewer.