Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs @ 12589

Last change on this file since 12589 was 12589, checked in by gkronber, 10 years ago

#2261 adapted interface to use IDataset instead of Dataset and added logistic regression loss function

File size: 2.7 KB
RevLine 
[12372]1using System;
2using System.Collections.Generic;
[12332]3using System.Linq;
4using HeuristicLab.Common;
5using HeuristicLab.Core;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7using HeuristicLab.Problems.DataAnalysis;
8
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model consisting of a single binary decision tree.
  /// The tree is stored as a flat array of <see cref="TreeNode"/> values; the root is the
  /// element at index 0 and child links are indices into the same array.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    // trees are represented as a flat array
    // object-graph-traversal has problems if this is defined as a struct. TODO investigate...
    // (the Storable attributes and the constructors below are intentionally left disabled until then)
    //[StorableClass]
    public struct TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a terminal (leaf) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;
      //[Storable]
      public string varName; // name of the variable for splitting, or NO_VARIABLE if this is a terminal node
      //[Storable]
      public double val; // split threshold for internal nodes; returned as the prediction for terminal nodes
      //[Storable]
      public int leftIdx; // array index of the child taken when the variable value is <= val
      //[Storable]
      public int rightIdx; // array index of the child taken otherwise

      //public TreeNode() {
      //  varName = NO_VARIABLE;
      //  leftIdx = -1;
      //  rightIdx = -1;
      //}
      //[StorableConstructor]
      //private TreeNode(bool deserializing) { }

      /// <summary>
      /// Hash code derived from the two child indices and the threshold.
      /// <see cref="varName"/> does not contribute to the hash.
      /// </summary>
      public override int GetHashCode() {
        return (leftIdx * rightIdx) ^ val.GetHashCode();
      }
    }

    /// <summary>Flat node array of the tree; element 0 is the root.</summary>
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    // cloning ctor
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree; // shallow clone, tree must be readonly
    }

    /// <summary>
    /// Creates a model for the given node array. The array is stored by reference (it is not copied).
    /// </summary>
    /// <param name="tree">Flat node array with the root at index 0.</param>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      this.tree = tree;
    }

    /// <summary>
    /// Walks the tree from <paramref name="nodeIdx"/> down to a terminal node and returns that node's value.
    /// At every internal node the left child is taken when the row's value of the split variable is
    /// less than or equal to the node's threshold, otherwise the right child is taken.
    /// </summary>
    /// <param name="t">Flat node array.</param>
    /// <param name="nodeIdx">Index of the node at which the walk starts.</param>
    /// <param name="ds">Dataset the variable values are read from.</param>
    /// <param name="row">Row of the dataset to predict for.</param>
    /// <returns>The <c>val</c> of the terminal node that was reached.</returns>
    /// <exception cref="InvalidOperationException">
    /// More internal nodes were visited than the array can hold without repetition,
    /// i.e. the child indices form a cycle.
    /// </exception>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
      // Iterative descent: the stack depth no longer depends on the tree depth, and a cyclic
      // (malformed) tree is reported as an exception instead of overflowing the stack.
      var internalNodesVisited = 0;
      while (true) {
        var node = t[nodeIdx];
        if (node.varName == TreeNode.NO_VARIABLE)
          return node.val;

        var nextIdx = ds.GetDoubleValue(node.varName, row) <= node.val
          ? node.leftIdx
          : node.rightIdx;

        // A path through a well-formed tree passes at most t.Length - 1 internal nodes
        // before it reaches a terminal node.
        internalNodesVisited++;
        if (internalNodesVisited >= t.Length)
          throw new InvalidOperationException("The regression tree is malformed: no terminal node is reachable (cyclic child indices).");

        nodeIdx = nextIdx;
      }
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns the tree's prediction for each of the given rows, in the order of <paramref name="rows"/>.
    /// The result is a lazily evaluated LINQ sequence; predictions are computed when it is enumerated.
    /// </summary>
    /// <param name="ds">Dataset the variable values are read from.</param>
    /// <param name="rows">Row indices to predict for.</param>
    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> together with a new
    /// <see cref="RegressionProblemData"/> constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
Note: See TracBrowser for help on using the repository browser.