source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12332

Last change on this file since 12332 was 12332, checked in by gkronber, 5 years ago

#2261: initial import of gradient boosted trees for regression

File size: 2.4 KB
Line 
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
7
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a binary decision tree. Each inner node compares one input variable
  /// against a threshold; each terminal node holds the constant value that is predicted for a row.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// Immutable node of a regression tree. A node whose <see cref="varName"/> equals
    /// <see cref="NO_VARIABLE"/> is a terminal node and <see cref="val"/> is its prediction.
    /// Otherwise, a row whose value of <see cref="varName"/> is less than or equal to <see cref="val"/>
    /// is routed to <see cref="left"/>; every other row (including NaN values, because comparisons
    /// with NaN are false) is routed to <see cref="right"/>.
    /// </summary>
    [StorableClass]
    public class TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a terminal node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;
      [Storable]
      public readonly string varName; // name of the splitting variable, or NO_VARIABLE for a terminal node
      [Storable]
      public readonly double val; // split threshold for inner nodes, predicted value for terminal nodes
      [Storable]
      public readonly TreeNode left; // subtree for rows with value <= val
      [Storable]
      public readonly TreeNode right; // subtree for all other rows

      // The persistence framework needs either a parameterless constructor or a [StorableConstructor]
      // taking a single bool to create the instance before restoring the [Storable] fields.
      [StorableConstructor]
      private TreeNode(bool deserializing) { }

      /// <summary>Creates a tree node.</summary>
      /// <param name="varName">Splitting variable, or <see cref="NO_VARIABLE"/> for a terminal node.</param>
      /// <param name="value">Split threshold (inner node) or predicted value (terminal node).</param>
      /// <param name="left">Subtree for rows with value &lt;= <paramref name="value"/>.</param>
      /// <param name="right">Subtree for all other rows.</param>
      public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
        this.varName = varName;
        this.val = value;
        this.left = left;
        this.right = right;
      }
    }

    // Root of the regression tree. All TreeNode fields are readonly, so the tree is never modified after construction.
    [Storable]
    public readonly TreeNode tree;

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    // cloning ctor: the tree consists only of readonly nodes, so clones can safely share it instead of deep-copying
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree;
    }

    /// <summary>Creates a model that predicts using the given tree.</summary>
    /// <param name="tree">Root node of the regression tree.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode tree)
      : base() {
      // fail fast here instead of with a NullReferenceException on the first prediction
      if (tree == null) throw new ArgumentNullException("tree");
      this.name = ItemName;
      this.description = ItemDescription;

      this.tree = tree;
    }

    // Descends from t to a terminal node and returns its value. The loop is iterative so that very deep
    // (e.g. degenerate, list-like) trees cannot overflow the call stack.
    private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
      var node = t;
      while (node.varName != TreeNode.NO_VARIABLE) {
        node = ds.GetDoubleValue(node.varName, row) <= node.val ? node.left : node.right;
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>Lazily evaluates the tree for each of the given rows of <paramref name="ds"/>.</summary>
    /// <param name="ds">Dataset providing the input variable values.</param>
    /// <param name="rows">Row indices to predict.</param>
    /// <returns>One prediction per row, in the order of <paramref name="rows"/>.</returns>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, ds, r));
    }

    /// <summary>Wraps this model in a regression solution over a copy of <paramref name="problemData"/>.</summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.