Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/01/15 18:30:56 (10 years ago)
Author:
gkronber
Message:

#2261: implemented prototype view for gradient boosted trees

Location:
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs

    r12332 r12372  
    88
    99namespace GradientBoostedTrees {
    10   [Item("GradientBoostedTreesSolution", "")]
    1110  [StorableClass]
     11  [Item("Gradient boosted tree model", "")]
    1212  public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {
    1313
    1414    [Storable]
    1515    private readonly IList<IRegressionModel> models;
     16    public IEnumerable<IRegressionModel> Models { get { return models; } }
     17
    1618    [Storable]
    1719    private readonly IList<double> weights;
     20    public IEnumerable<double> Weights { get { return weights; } }
    1821
    1922    [StorableConstructor]
     
    2528    }
    2629    public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)
    27       : base() {
     30      : base("Gradient boosted tree model", string.Empty) {
    2831      this.models = new List<IRegressionModel>(models);
    2932      this.weights = new List<double>(weights);
  • branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12349 r12372  
    3333    private readonly double[] outx;
    3434    private readonly int[] outSortedIdx;
     35
     36    private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes
     37    private int curTreeNodeIdx; // the index where the next tree node is stored
     38
    3539    private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; //TODO
    3640
     
    128132        }
    129133      }
     134
     135      // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes)     
     136      int numNodes = (int)Math.Pow(2, maxDepth) - 1;
     137      //this.tree = new RegressionTreeModel.TreeNode[numNodes];
     138      this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray();
     139      this.curTreeNodeIdx = 0;
     140
    130141      // start and end idx are inclusive
    131       var tree = CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
     142      CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);
    132143      return new RegressionTreeModel(tree);
    133144    }
    134145
    135146    // startIdx and endIdx are inclusive
    136     private RegressionTreeModel.TreeNode CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
     147    private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {
    137148      Contract.Assert(endIdx - startIdx >= 0);
    138149      Contract.Assert(startIdx >= 0);
    139150      Contract.Assert(endIdx < internalIdx.Length);
    140151
    141       RegressionTreeModel.TreeNode t;
    142152      // TODO: stop when y is constant
    143153      // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth
    144154      if (maxDepth <= 1 || endIdx - startIdx == 0) {
    145         // max depth reached or only one element         
    146         t = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
    147         return t;
     155        // max depth reached or only one element   
     156        tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     157        tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
     158        curTreeNodeIdx++;
    148159      } else {
    149160        int i, j;
     
    154165        // if bestVariableName is NO_VARIABLE then no split was possible anymore
    155166        if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) {
    156           return new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx));
     167          // max depth reached or only one element   
     168          tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE;
     169          tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx);
     170          curTreeNodeIdx++;
    157171        } else {
    158172
     
    214228          Debug.Assert(j <= endIdx);
    215229
    216           t = new RegressionTreeModel.TreeNode(bestVariableName,
    217             threshold,
    218             CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch),
    219             CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch));
    220 
    221           return t;
     230          var parentIdx = curTreeNodeIdx;
     231          tree[parentIdx].varName = bestVariableName;
     232          tree[parentIdx].val = threshold;
     233          curTreeNodeIdx++;
     234
     235          // create left subtree
     236          tree[parentIdx].leftIdx = curTreeNodeIdx;
     237          CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch);
     238
     239          // create right subtree
     240          tree[parentIdx].rightIdx = curTreeNodeIdx;
     241          CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch);
    222242        }
    223243      }
     
    272292    // assumption is that the Average(y) = 0
    273293    private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) {
     294      if (string.IsNullOrEmpty(bestVar)) return;
    274295      // update variable relevance
    275296      double err = sumY * sumY / rows;
  • branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12349 r12372  
    1 using System.Collections.Generic;
     1using System;
     2using System.Collections.Generic;
    23using System.Linq;
    34using HeuristicLab.Common;
     
    1213  public class RegressionTreeModel : NamedItem, IRegressionModel {
    1314
     15    // trees are represented as a flat array
     16    // object-graph-travesal has problems if this is defined as a struct. TODO investigate...
    1417    [StorableClass]
    1518    public class TreeNode {
    1619      public readonly static string NO_VARIABLE = string.Empty;
    1720      [Storable]
    18       public readonly string varName; // name of the variable for splitting or -1 if terminal node
     21      public string varName; // name of the variable for splitting or -1 if terminal node
    1922      [Storable]
    20       public readonly double val; // threshold
     23      public double val; // threshold
    2124      [Storable]
    22       public readonly TreeNode left;
     25      public int leftIdx;
    2326      [Storable]
    24       public readonly TreeNode right;
     27      public int rightIdx;
    2528
     29      public TreeNode() {
     30        varName = NO_VARIABLE;
     31        leftIdx = -1;
     32        rightIdx = -1;
     33      }
    2634      [StorableConstructor]
    2735      private TreeNode(bool deserializing) { }
    28 
    29       public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
    30         this.varName = varName;
    31         this.val = value;
    32         this.left = left;
    33         this.right = right;
    34       }
    3536    }
    3637
    3738    [Storable]
    38     public readonly TreeNode tree;
     39    public readonly TreeNode[] tree;
    3940
    4041    [StorableConstructor]
     
    4344    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    4445      : base(original, cloner) {
    45       this.tree = original.tree;
     46      this.tree = original.tree; // shallow clone, tree must be readonly
    4647    }
    4748
    48     public RegressionTreeModel(TreeNode tree)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52 
     49    public RegressionTreeModel(TreeNode[] tree)
     50      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
    5351      this.tree = tree;
    5452    }
    5553
    56     private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
    57       if (t.varName == TreeNode.NO_VARIABLE)
    58         return t.val;
    59       else if (ds.GetDoubleValue(t.varName, row) <= t.val)
    60         return GetPredictionForRow(t.left, ds, row);
     54    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
     55      var node = t[nodeIdx];
     56      if (node.varName == TreeNode.NO_VARIABLE)
     57        return node.val;
     58      else if (ds.GetDoubleValue(node.varName, row) <= node.val)
     59        return GetPredictionForRow(t, node.leftIdx, ds, row);
    6160      else
    62         return GetPredictionForRow(t.right, ds, row);
     61        return GetPredictionForRow(t, node.rightIdx, ds, row);
    6362    }
    6463
     
    6867
    6968    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
    70       return rows.Select(r => GetPredictionForRow(tree, ds, r));
     69      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    7170    }
    7271
Note: See TracChangeset for help on using the changeset viewer.