[12372] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[12332] | 3 | using System.Linq;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using HeuristicLab.Core;
|
---|
| 6 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
| 7 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 8 |
|
---|
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a single binary decision tree.
  /// The tree is stored as a flat array of <see cref="TreeNode"/> objects; the root is the element at index 0
  /// and child links are indices into the same array.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    // trees are represented as a flat array
    // object-graph-traversal has problems if this is defined as a struct. TODO investigate...
    /// <summary>
    /// A single node of the flat tree array: either an internal split node or a terminal (leaf) node.
    /// </summary>
    [StorableClass]
    public class TreeNode {
      /// <summary>Sentinel stored in <see cref="varName"/> to mark a terminal node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      /// <summary>Name of the variable used for splitting, or <see cref="NO_VARIABLE"/> if this is a terminal node.</summary>
      [Storable]
      public string varName;

      /// <summary>
      /// Split threshold for an internal node (rows with a value &lt;= threshold go left);
      /// for a terminal node this is the value returned as the prediction.
      /// </summary>
      [Storable]
      public double val;

      /// <summary>Index (into the tree array) of the node visited when the variable value is &lt;= <see cref="val"/>.</summary>
      [Storable]
      public int leftIdx;

      /// <summary>Index (into the tree array) of the node visited when the variable value is not &lt;= <see cref="val"/>.</summary>
      [Storable]
      public int rightIdx;

      /// <summary>
      /// Creates a terminal node without children (both child indices are -1); <see cref="val"/> keeps its default of 0.
      /// </summary>
      public TreeNode() {
        varName = NO_VARIABLE;
        leftIdx = -1;
        rightIdx = -1;
      }

      // constructor used by the persistence framework; all fields are restored from the [Storable] members
      [StorableConstructor]
      private TreeNode(bool deserializing) { }
    }

    /// <summary>Flat node array of the tree; element 0 is the root.</summary>
    [Storable]
    public readonly TreeNode[] tree;

    // constructor used by the persistence framework
    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Cloning constructor. The node array is NOT copied: the clone references the same array as
    /// <paramref name="original"/>, so the array and its nodes must be treated as read-only after construction.
    /// </summary>
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree; // shallow clone, tree must be readonly
    }

    /// <summary>
    /// Creates a model for the given flat tree representation.
    /// </summary>
    /// <param name="tree">Node array of the tree; the root must be stored at index 0.</param>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      // fail here instead of with a NullReferenceException when the predictions are enumerated later
      if (tree == null) throw new ArgumentNullException("tree");
      this.tree = tree;
    }

    /// <summary>
    /// Walks the tree from <paramref name="nodeIdx"/> down to a terminal node and returns that node's value.
    /// </summary>
    /// <param name="t">Flat node array.</param>
    /// <param name="nodeIdx">Index of the node at which the descent starts.</param>
    /// <param name="ds">Dataset supplying the variable values.</param>
    /// <param name="row">Row of <paramref name="ds"/> to predict.</param>
    /// <returns>The <c>val</c> of the terminal node that is reached.</returns>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) {
      var node = t[nodeIdx];
      // iterative descent (no recursion, so the stack does not grow with the depth of the tree)
      while (node.varName != TreeNode.NO_VARIABLE) {
        // same test as before: "<=" goes left, everything else (including NaN) goes right
        var nextIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
        node = t[nextIdx];
      }
      return node.val;
    }

    /// <summary>Creates a clone of this model via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns the tree's prediction for each of the given rows, in the order of <paramref name="rows"/>.
    /// The result is a deferred LINQ query: predictions are computed when the sequence is enumerated.
    /// </summary>
    /// <param name="ds">Dataset supplying the variable values.</param>
    /// <param name="rows">Row indices to predict.</param>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Creates a regression solution for this model, using a new <c>RegressionProblemData</c>
    /// constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
|
---|