Changeset 12372 for branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
- Timestamp:
- 05/01/15 18:30:56 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12349 r12372 33 33 private readonly double[] outx; 34 34 private readonly int[] outSortedIdx; 35 36 private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes 37 private int curTreeNodeIdx; // the index where the next tree node is stored 38 35 39 private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; //TODO 36 40 … … 128 132 } 129 133 } 134 135 // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes) 136 int numNodes = (int)Math.Pow(2, maxDepth) - 1; 137 //this.tree = new RegressionTreeModel.TreeNode[numNodes]; 138 this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray(); 139 this.curTreeNodeIdx = 0; 140 130 141 // start and end idx are inclusive 131 var tree =CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);142 CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch); 132 143 return new RegressionTreeModel(tree); 133 144 } 134 145 135 146 // startIdx and endIdx are inclusive 136 private RegressionTreeModel.TreeNodeCreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {147 private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) { 137 148 Contract.Assert(endIdx - startIdx >= 0); 138 149 Contract.Assert(startIdx >= 0); 139 150 Contract.Assert(endIdx < internalIdx.Length); 140 151 141 RegressionTreeModel.TreeNode t;142 152 // TODO: stop when y is constant 143 153 // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth 144 154 if (maxDepth <= 1 || endIdx - startIdx == 0) { 145 // max depth reached or only one element 146 t = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx)); 147 return t; 155 // max depth reached or only one element 156 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 157 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 158 curTreeNodeIdx++; 148 159 } else { 149 160 int i, j; … … 154 165 // if bestVariableName is NO_VARIABLE then no split was possible anymore 155 166 if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) { 156 return new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx)); 167 // max depth reached or only one element 168 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 169 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 170 curTreeNodeIdx++; 157 171 } else { 158 172 … … 214 228 Debug.Assert(j <= endIdx); 215 229 216 t = new RegressionTreeModel.TreeNode(bestVariableName, 217 threshold, 218 CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch), 219 CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch)); 220 221 return t; 230 var parentIdx = curTreeNodeIdx; 231 tree[parentIdx].varName = bestVariableName; 232 tree[parentIdx].val = threshold; 233 curTreeNodeIdx++; 234 235 // create left subtree 236 tree[parentIdx].leftIdx = curTreeNodeIdx; 237 CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch); 238 239 // create right subtree 240 tree[parentIdx].rightIdx = curTreeNodeIdx; 241 CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch); 222 242 } 223 243 } … … 272 292 // assumption is that the Average(y) = 0 273 293 private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) { 294 if (string.IsNullOrEmpty(bestVar)) return; 274 295 // update variable relevance 275 296 double err = sumY * sumY / rows;
Note: See TracChangeset
for help on using the changeset viewer.