Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12589

Last change on this file since 12589 was 12589, checked in by gkronber, 9 years ago

#2261 adapted interface to use IDataset instead of Dataset and added logistic regression loss function

File size: 2.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Problems.DataAnalysis;
8
9namespace GradientBoostedTrees {
// TODO: Implement a view for this
/// <summary>
/// Regression model backed by a single binary decision tree. The tree is stored as a flat
/// array of <see cref="TreeNode"/> values; prediction starts at index 0 and follows the
/// child indices stored in each inner node until a leaf node is reached.
/// </summary>
[StorableClass]
[Item("RegressionTreeModel", "Represents a decision tree for regression.")]
public class RegressionTreeModel : NamedItem, IRegressionModel {

  // trees are represented as a flat array
  // TreeNode is deliberately not marked [StorableClass]/[Storable]: object-graph-traversal
  // has problems if this is defined as a struct. TODO investigate...
  /// <summary>
  /// One node of the flattened tree. A node whose <see cref="varName"/> equals
  /// <see cref="NO_VARIABLE"/> is a leaf and <see cref="val"/> is its prediction.
  /// Otherwise it is a split: rows whose value of <see cref="varName"/> is
  /// <c>&lt;= val</c> continue at <see cref="leftIdx"/>, all other rows at <see cref="rightIdx"/>.
  /// </summary>
  public struct TreeNode : IEquatable<TreeNode> {
    public static readonly string NO_VARIABLE = string.Empty;
    public string varName; // name of the splitting variable, or NO_VARIABLE for a leaf node
    public double val; // split threshold for inner nodes, predicted value for leaf nodes
    public int leftIdx; // index of the child for rows with value <= val
    public int rightIdx; // index of the child for rows with value > val (or NaN)

    /// <summary>
    /// Field-wise equality. It matches the semantics of the default ValueType.Equals
    /// (ordinal string comparison, double.Equals) without its reflection overhead.
    /// </summary>
    public bool Equals(TreeNode other) {
      return string.Equals(varName, other.varName, StringComparison.Ordinal)
        && val.Equals(other.val)
        && leftIdx == other.leftIdx
        && rightIdx == other.rightIdx;
    }

    public override bool Equals(object obj) {
      return obj is TreeNode && Equals((TreeNode)obj);
    }

    public override int GetHashCode() {
      // unchecked: the product of two large child indices can overflow; that must never
      // throw, even if the assembly is compiled with overflow checking enabled
      return unchecked(leftIdx * rightIdx) ^ val.GetHashCode();
    }
  }

  /// <summary>The flattened tree; index 0 is the root node.</summary>
  [Storable]
  public readonly TreeNode[] tree;

  [StorableConstructor]
  private RegressionTreeModel(bool deserializing) : base(deserializing) { }

  /// <summary>Cloning constructor.</summary>
  /// <param name="original">The model to clone.</param>
  /// <param name="cloner">The cloner tracking already-cloned objects.</param>
  public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    : base(original, cloner) {
    // shallow clone: the node array is shared with the original, so it must never be
    // modified after construction
    this.tree = original.tree;
  }

  /// <summary>Creates a model for the given flattened tree.</summary>
  /// <param name="tree">The tree nodes; index 0 must be the root.</param>
  /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
  public RegressionTreeModel(TreeNode[] tree)
    : base("RegressionTreeModel", "Represents a decision tree for regression.") {
    if (tree == null) throw new ArgumentNullException("tree");
    this.tree = tree;
  }

  /// <summary>
  /// Walks the tree from <paramref name="nodeIdx"/> down to a leaf for a single row and
  /// returns the leaf's value. A loop is used instead of recursion so that deep trees
  /// do not grow the call stack.
  /// </summary>
  private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
    var node = t[nodeIdx];
    while (node.varName != TreeNode.NO_VARIABLE) {
      // a NaN value fails the <= test and therefore goes right
      nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
      node = t[nodeIdx];
    }
    return node.val;
  }

  public override IDeepCloneable Clone(Cloner cloner) {
    return new RegressionTreeModel(this, cloner);
  }

  /// <summary>
  /// Returns the tree's prediction for each of <paramref name="rows"/> in <paramref name="ds"/>.
  /// The sequence is evaluated lazily, i.e. each time it is enumerated.
  /// </summary>
  public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
    return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
  }

  /// <summary>
  /// Wraps this model in a <see cref="RegressionSolution"/> together with a new
  /// <see cref="RegressionProblemData"/> constructed from <paramref name="problemData"/>.
  /// </summary>
  public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
    return new RegressionSolution(this, new RegressionProblemData(problemData));
  }
}
80
81}
Note: See TracBrowser for help on using the repository browser.