Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/01/15 18:30:56 (9 years ago)
Author:
gkronber
Message:

#2261: implemented prototype view for gradient boosted trees

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12349 r12372  
    3333    private readonly double[] outx;
    3434    private readonly int[] outSortedIdx;
     35
     36    private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes
     37    private int curTreeNodeIdx; // the index where the next tree node is stored
     38
    3539    private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; //TODO
    3640
     
    128132        }
    129133      }
     134
     135      // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has at most 2^d - 1 nodes; fewer are used when splitting stops early)
     136      int numNodes = (int)Math.Pow(2, maxDepth) - 1;
     137      //this.tree = new RegressionTreeModel.TreeNode[numNodes];
     138      this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray();
     139      this.curTreeNodeIdx = 0;
     140
    130141      // start and end idx are inclusive
    131       var tree = CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
     142      CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
    132143      return new RegressionTreeModel(tree);
    133144    }
    134145
    135146    // startIdx and endIdx are inclusive
    136     private RegressionTreeModel.TreeNode CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
     147    private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
    137148      Contract.Assert(endIdx - startIdx >= 0);
    138149      Contract.Assert(startIdx >= 0);
    139150      Contract.Assert(endIdx < internalIdx.Length);
    140151
    141       RegressionTreeModel.TreeNode t;
    142152      // TODO: stop when y is constant
    143153      // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth
    144154      if (maxDepth <= 1 || endIdx - startIdx == 0) {
    145         // max depth reached or only one element         
    146         t = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
    147         return t;
     155        // max depth reached or only one element   
     156        tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     157        tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
     158        curTreeNodeIdx++;
    148159      } else {
    149160        int i, j;
     
    154165        // if bestVariableName is NO_VARIABLE then no split was possible anymore
    155166        if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    156           return new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
     167          // no split was possible: store a node without a split variable (NO_VARIABLE) holding the line-search value
     168          tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     169          tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
     170          curTreeNodeIdx++;
    157171        } else {
    158172
     
    214228          Debug.Assert(j <= endIdx);
    215229
    216           t = new RegressionTreeModel.TreeNode(bestVariableName,
    217             threshold,
    218             CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch),
    219             CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch));
    220 
    221           return t;
     230          var parentIdx = curTreeNodeIdx;
     231          tree[parentIdx].varName = bestVariableName;
     232          tree[parentIdx].val = threshold;
     233          curTreeNodeIdx++;
     234
     235          // create left subtree
     236          tree[parentIdx].leftIdx = curTreeNodeIdx;
     237          CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch);
     238
     239          // create right subtree
     240          tree[parentIdx].rightIdx = curTreeNodeIdx;
     241          CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch);
    222242        }
    223243      }
     
    272292    // assumption is that the Average(y) = 0
    273293    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
     294      if (string.IsNullOrEmpty(bestVar)) return;
    274295      // update variable relevance
    275296      double err = sumY * sumY / rows;
Note: See TracChangeset for help on using the changeset viewer.