Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
05/01/15 18:30:56 (9 years ago)
Author:
gkronber
Message:

#2261: implemented prototype view for gradient boosted trees

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12349 r12372  
    1 using System.Collections.Generic;
     1using System;
     2using System.Collections.Generic;
    23using System.Linq;
    34using HeuristicLab.Common;
     
    1213  public class RegressionTreeModel : NamedItem, IRegressionModel {
    1314
     15    // trees are represented as a flat array
     16    // object-graph-travesal has problems if this is defined as a struct. TODO investigate...
    1417    [StorableClass]
    1518    public class TreeNode {
    1619      public readonly static string NO_VARIABLE = string.Empty;
    1720      [Storable]
    18       public readonly string varName; // name of the variable for splitting or -1 if terminal node
     21      public string varName; // name of the variable for splitting or -1 if terminal node
    1922      [Storable]
    20       public readonly double val; // threshold
     23      public double val; // threshold
    2124      [Storable]
    22       public readonly TreeNode left;
     25      public int leftIdx;
    2326      [Storable]
    24       public readonly TreeNode right;
     27      public int rightIdx;
    2528
     29      public TreeNode() {
     30        varName = NO_VARIABLE;
     31        leftIdx = -1;
     32        rightIdx = -1;
     33      }
    2634      [StorableConstructor]
    2735      private TreeNode(bool deserializing) { }
    28 
    29       public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
    30         this.varName = varName;
    31         this.val = value;
    32         this.left = left;
    33         this.right = right;
    34       }
    3536    }
    3637
    3738    [Storable]
    38     public readonly TreeNode tree;
     39    public readonly TreeNode[] tree;
    3940
    4041    [StorableConstructor]
     
    4344    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    4445      : base(original, cloner) {
    45       this.tree = original.tree;
     46      this.tree = original.tree; // shallow clone, tree must be readonly
    4647    }
    4748
    48     public RegressionTreeModel(TreeNode tree)
    49       : base() {
    50       this.name = ItemName;
    51       this.description = ItemDescription;
    52 
     49    public RegressionTreeModel(TreeNode[] tree)
     50      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
    5351      this.tree = tree;
    5452    }
    5553
    56     private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
    57       if (t.varName == TreeNode.NO_VARIABLE)
    58         return t.val;
    59       else if (ds.GetDoubleValue(t.varName, row) <= t.val)
    60         return GetPredictionForRow(t.left, ds, row);
     54    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
     55      var node = t[nodeIdx];
     56      if (node.varName == TreeNode.NO_VARIABLE)
     57        return node.val;
     58      else if (ds.GetDoubleValue(node.varName, row) <= node.val)
     59        return GetPredictionForRow(t, node.leftIdx, ds, row);
    6160      else
    62         return GetPredictionForRow(t.right, ds, row);
     61        return GetPredictionForRow(t, node.rightIdx, ds, row);
    6362    }
    6463
     
    6867
    6968    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
    70       return rows.Select(r => GetPredictionForRow(tree, ds, r));
     69      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    7170    }
    7271
Note: See TracChangeset for help on using the changeset viewer.