source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12372

Last change on this file since 12372 was 12372, checked in by gkronber, 5 years ago

#2261: implemented prototype view for gradient boosted trees

File size: 2.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Problems.DataAnalysis;
8
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a single decision tree. The tree is stored as a flat array of
  /// <see cref="TreeNode"/>s whose root node is at index 0.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// One node of the flat tree representation. A split node names the variable to test and holds
    /// the threshold in <see cref="val"/>; a terminal (leaf) node has <see cref="varName"/> equal to
    /// <see cref="NO_VARIABLE"/> and holds the predicted value in <see cref="val"/>.
    /// </summary>
    /// <remarks>
    /// Trees are represented as a flat array of these nodes. This is a class rather than a struct
    /// because object-graph traversal has problems if it is defined as a struct. TODO investigate...
    /// </remarks>
    [StorableClass]
    public class TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a terminal (leaf) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      // Name of the variable used for splitting, or NO_VARIABLE if this is a terminal node.
      [Storable]
      public string varName;
      // Split threshold for split nodes; predicted value for terminal nodes.
      [Storable]
      public double val;
      // Index (into the tree array) of the child taken when the variable value is <= val; -1 by default.
      [Storable]
      public int leftIdx;
      // Index (into the tree array) of the child taken otherwise (value > val, or NaN); -1 by default.
      [Storable]
      public int rightIdx;

      public TreeNode() {
        varName = NO_VARIABLE;
        leftIdx = -1;
        rightIdx = -1;
      }

      [StorableConstructor]
      private TreeNode(bool deserializing) { }
    }

    /// <summary>
    /// The tree nodes; the root is at index 0. Must not be modified after construction because
    /// clones share this array (see the cloning constructor).
    /// </summary>
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    // cloning ctor
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      // Shallow clone: the clone shares the node array (and the nodes) with the original.
      // The field is readonly, but the nodes' fields are publicly mutable, so this is only safe
      // as long as nobody modifies the tree after construction.
      this.tree = original.tree;
    }

    /// <summary>
    /// Creates a model from a flat node array whose root node is at index 0.
    /// </summary>
    /// <param name="tree">The tree nodes. The array is stored by reference, not copied.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      // Fail fast here instead of with a NullReferenceException later during prediction.
      if (tree == null) throw new ArgumentNullException("tree");
      this.tree = tree;
    }

    // Descends from node nodeIdx to a terminal node for the given dataset row and returns that node's value.
    // Implemented as a loop rather than recursion so very deep (e.g. degenerate) trees cannot overflow the stack.
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
      var node = t[nodeIdx];
      while (node.varName != TreeNode.NO_VARIABLE) {
        // NaN <= threshold evaluates to false, so NaN values follow the right branch.
        nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
        node = t[nodeIdx];
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns the tree's prediction for each of the given rows of <paramref name="ds"/>.
    /// The sequence is evaluated lazily (LINQ deferred execution).
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Creates a regression solution for this model using a new <see cref="RegressionProblemData"/>
    /// instance constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.