Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12349

Last changed in revision 12349, checked in by gkronber, 9 years ago

#2261: added serialization constructor for RegressionTreeModel

File size: 2.5 KB
Line 
1using System.Collections.Generic;
2using System.Linq;
3using HeuristicLab.Common;
4using HeuristicLab.Core;
5using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
6using HeuristicLab.Problems.DataAnalysis;
7
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a binary decision tree. Inner nodes compare one input
  /// variable of a row against a threshold; leaf nodes hold the predicted value.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// Immutable node of a regression tree. A node whose <see cref="varName"/> equals
    /// <see cref="NO_VARIABLE"/> is a leaf and <see cref="val"/> is its prediction;
    /// any other node is a split and <see cref="val"/> is its threshold.
    /// </summary>
    [StorableClass]
    public class TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a leaf (terminal) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      // name of the variable used for splitting, or NO_VARIABLE for a leaf node
      [Storable]
      public readonly string varName;
      // split threshold for inner nodes; predicted value for leaf nodes
      [Storable]
      public readonly double val;
      // subtree taken when the row's value of varName is <= val
      [Storable]
      public readonly TreeNode left;
      // subtree taken otherwise (value > val, or a comparison that is false, e.g. NaN)
      [Storable]
      public readonly TreeNode right;

      [StorableConstructor]
      private TreeNode(bool deserializing) { }

      /// <summary>Creates a tree node.</summary>
      /// <param name="varName">Split variable, or <see cref="NO_VARIABLE"/> for a leaf.</param>
      /// <param name="value">Split threshold (inner node) or prediction (leaf).</param>
      /// <param name="left">Subtree for values &lt;= <paramref name="value"/>; null for leaves.</param>
      /// <param name="right">Subtree for all other values; null for leaves.</param>
      public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
        this.varName = varName;
        this.val = value;
        this.left = left;
        this.right = right;
      }
    }

    // Root of the decision tree.
    [Storable]
    public readonly TreeNode tree;

    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor.</summary>
    /// <remarks>
    /// The tree is shared rather than deep-copied; TreeNode only has readonly fields,
    /// so sharing it between clones is safe.
    /// </remarks>
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree;
    }

    /// <summary>Creates a model for the given tree, using the item's default name and description.</summary>
    /// <param name="tree">Root node of the regression tree.</param>
    public RegressionTreeModel(TreeNode tree)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;

      this.tree = tree;
    }

    /// <summary>
    /// Walks from <paramref name="t"/> down to a leaf for the given row and returns the leaf value.
    /// </summary>
    /// <remarks>
    /// Iterative rather than recursive, so that very deep (degenerate) trees cannot cause
    /// a StackOverflowException, which the runtime does not allow to be caught.
    /// </remarks>
    private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
      var node = t;
      while (node.varName != TreeNode.NO_VARIABLE) {
        // same routing as before: "<=" goes left, everything else (including NaN) goes right
        node = ds.GetDoubleValue(node.varName, row) <= node.val ? node.left : node.right;
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>Returns the tree's prediction for each of the given rows of <paramref name="ds"/>.</summary>
    /// <remarks>The sequence is evaluated lazily; each enumeration re-evaluates the tree.</remarks>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, ds, r));
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> together with a new
    /// RegressionProblemData constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.