Changeset 12372 for branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
- Timestamp:
- 05/01/15 18:30:56 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r12349 r12372 1 using System.Collections.Generic; 1 using System; 2 using System.Collections.Generic; 2 3 using System.Linq; 3 4 using HeuristicLab.Common; … … 12 13 public class RegressionTreeModel : NamedItem, IRegressionModel { 13 14 15 // trees are represented as a flat array 16 // object-graph-travesal has problems if this is defined as a struct. TODO investigate... 14 17 [StorableClass] 15 18 public class TreeNode { 16 19 public readonly static string NO_VARIABLE = string.Empty; 17 20 [Storable] 18 public readonlystring varName; // name of the variable for splitting or -1 if terminal node21 public string varName; // name of the variable for splitting or -1 if terminal node 19 22 [Storable] 20 public readonlydouble val; // threshold23 public double val; // threshold 21 24 [Storable] 22 public readonly TreeNode left;25 public int leftIdx; 23 26 [Storable] 24 public readonly TreeNode right;27 public int rightIdx; 25 28 29 public TreeNode() { 30 varName = NO_VARIABLE; 31 leftIdx = -1; 32 rightIdx = -1; 33 } 26 34 [StorableConstructor] 27 35 private TreeNode(bool deserializing) { } 28 29 public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {30 this.varName = varName;31 this.val = value;32 this.left = left;33 this.right = right;34 }35 36 } 36 37 37 38 [Storable] 38 public readonly TreeNode tree;39 public readonly TreeNode[] tree; 39 40 40 41 [StorableConstructor] … … 43 44 public RegressionTreeModel(RegressionTreeModel original, Cloner cloner) 44 45 : base(original, cloner) { 45 this.tree = original.tree; 46 this.tree = original.tree; // shallow clone, tree must be readonly 46 47 } 47 48 48 public RegressionTreeModel(TreeNode tree) 49 : base() { 50 this.name = ItemName; 51 this.description = ItemDescription; 52 49 public RegressionTreeModel(TreeNode[] tree) 50 : base("RegressionTreeModel", "Represents a decision tree for regression.") { 53 51 this.tree = tree; 54 52 } 55 53 56 private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) { 57 if (t.varName == TreeNode.NO_VARIABLE) 58 return t.val; 59 else if (ds.GetDoubleValue(t.varName, row) <= t.val) 60 return GetPredictionForRow(t.left, ds, row); 54 private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) { 55 var node = t[nodeIdx]; 56 if (node.varName == TreeNode.NO_VARIABLE) 57 return node.val; 58 else if (ds.GetDoubleValue(node.varName, row) <= node.val) 59 return GetPredictionForRow(t, node.leftIdx, ds, row); 61 60 else 62 return GetPredictionForRow(t .right, ds, row);61 return GetPredictionForRow(t, node.rightIdx, ds, row); 63 62 } 64 63 … … 68 67 69 68 public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) { 70 return rows.Select(r => GetPredictionForRow(tree, ds, r));69 return rows.Select(r => GetPredictionForRow(tree, 0, ds, r)); 71 70 } 72 71
Note: See TracChangeset
for help on using the changeset viewer.