Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12372

Last change on this file since 12372 was 12372, checked in by gkronber, 9 years ago

#2261: implemented prototype view for gradient boosted trees

File size: 2.6 KB
RevLine 
[12372]1using System;
2using System.Collections.Generic;
[12332]3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Problems.DataAnalysis;
8
namespace GradientBoostedTrees {
  /// <summary>
  /// A regression tree stored as a flat array of <see cref="TreeNode"/> objects.
  /// The root is the node at index 0. Internal nodes reference their children by array index.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  // NOTE(review): this TODO may be stale (#2261 added a prototype view for gradient boosted trees) — confirm.
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    // Trees are represented as a flat array of nodes.
    // Object-graph traversal has problems if this is defined as a struct. TODO: investigate...
    /// <summary>
    /// A single node of the flattened tree.
    /// A terminal (leaf) node has <see cref="varName"/> == <see cref="NO_VARIABLE"/> and holds its
    /// prediction in <see cref="val"/>. An internal node splits on <see cref="varName"/> at
    /// threshold <see cref="val"/>: rows with value &lt;= val go to <see cref="leftIdx"/>,
    /// all other rows go to <see cref="rightIdx"/>.
    /// </summary>
    [StorableClass]
    public class TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a terminal (leaf) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      // name of the variable used for splitting, or NO_VARIABLE if this is a terminal node
      [Storable]
      public string varName;
      // internal node: split threshold; terminal node: the predicted value
      [Storable]
      public double val;
      // index of the left child in the tree array (-1 if none)
      [Storable]
      public int leftIdx;
      // index of the right child in the tree array (-1 if none)
      [Storable]
      public int rightIdx;

      /// <summary>
      /// Creates a terminal node without children (val keeps its default of 0).
      /// </summary>
      public TreeNode() {
        varName = NO_VARIABLE;
        leftIdx = -1;
        rightIdx = -1;
      }

      [StorableConstructor]
      private TreeNode(bool deserializing) { }
    }

    // Flat node array; the root is at index 0.
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor.</summary>
    /// <remarks>
    /// This is a shallow clone: the node array and its TreeNode objects are shared with
    /// <paramref name="original"/>. That is only safe as long as the tree is never mutated
    /// after construction.
    /// </remarks>
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree; // shallow clone, tree must be readonly
    }

    /// <summary>Creates a model from a flattened tree whose root is at index 0.</summary>
    /// <param name="tree">The node array; it must contain at least the root node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="tree"/> is empty.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      // Fail here rather than with an index error on the first prediction.
      if (tree == null) throw new ArgumentNullException("tree");
      if (tree.Length == 0) throw new ArgumentException("The tree must contain at least one (root) node.", "tree");
      this.tree = tree;
    }

    // Whether the node is a leaf. Null is treated like NO_VARIABLE so that a node whose
    // name was not restored is never used as a split variable.
    private static bool IsTerminal(TreeNode node) {
      return string.IsNullOrEmpty(node.varName);
    }

    /// <summary>
    /// Walks the tree from <paramref name="nodeIdx"/> down to a leaf for the given row and
    /// returns that leaf's value.
    /// </summary>
    /// <exception cref="InvalidOperationException">The child indices form a cycle.</exception>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
      var node = t[nodeIdx];
      // Iterative instead of recursive. In a well-formed tree, a root-to-leaf path visits
      // each node at most once, so more steps than nodes means the indices form a cycle.
      // A cycle would otherwise overflow the stack (recursive) or loop forever.
      int steps = 0;
      while (!IsTerminal(node)) {
        if (++steps > t.Length)
          throw new InvalidOperationException("The regression tree contains a cycle.");
        // NaN values compare false and therefore go right, as in the original recursion.
        nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
        node = t[nodeIdx];
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns the tree's prediction for each row index in <paramref name="rows"/>.
    /// The sequence is evaluated lazily (deferred LINQ Select).
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Wraps this model in a RegressionSolution. The solution uses a new RegressionProblemData
    /// constructed from <paramref name="problemData"/> (presumably a copy — confirm).
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.