[12372] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[12332] | 3 | using System.Linq;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using HeuristicLab.Core;
|
---|
| 6 | using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
|
---|
| 7 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 8 |
|
---|
namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a single binary decision tree. The tree is stored as a flat array of
  /// <see cref="TreeNode"/> values; evaluation always starts at index 0, so <c>tree[0]</c> is the root.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// One node of the flat tree representation. A node whose <see cref="varName"/> is
    /// <see cref="NO_VARIABLE"/> (or null, as in <c>default(TreeNode)</c>) is a leaf and <see cref="val"/>
    /// is its prediction; otherwise the node splits on <see cref="varName"/> at threshold <see cref="val"/>.
    /// </summary>
    // Object-graph traversal has problems if this is defined as a struct, so the [StorableClass] /
    // [Storable] attributes are intentionally not applied here. TODO investigate...
    public struct TreeNode : IEquatable<TreeNode> {
      /// <summary>Value of <see cref="varName"/> that marks a leaf node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      /// <summary>Name of the split variable, or <see cref="NO_VARIABLE"/> for a leaf node.</summary>
      public string varName;
      /// <summary>Split threshold for an internal node; predicted value for a leaf node.</summary>
      public double val;
      /// <summary>Index (into the tree array) of the child taken when the variable value is &lt;= <see cref="val"/>.</summary>
      public int leftIdx;
      /// <summary>Index (into the tree array) of the child taken otherwise (including when the value is NaN).</summary>
      public int rightIdx;

      /// <summary>Field-wise equality; same semantics as the default struct equality, without reflection.</summary>
      public bool Equals(TreeNode other) {
        return string.Equals(varName, other.varName)
          && val.Equals(other.val)
          && leftIdx == other.leftIdx
          && rightIdx == other.rightIdx;
      }

      public override bool Equals(object obj) {
        return obj is TreeNode && Equals((TreeNode)obj);
      }

      /// <summary>Combines all fields that participate in <see cref="Equals(TreeNode)"/>.</summary>
      public override int GetHashCode() {
        unchecked {
          int hash = 17;
          hash = hash * 31 + (varName == null ? 0 : varName.GetHashCode());
          hash = hash * 31 + val.GetHashCode();
          hash = hash * 31 + leftIdx;
          hash = hash * 31 + rightIdx;
          return hash;
        }
      }
    }

    /// <summary>Flat array of nodes; index 0 is the root.</summary>
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    // cloning ctor
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      // `readonly` only protects the array reference: TreeNode has public mutable fields, so sharing the
      // array would let changes to one model's nodes leak into its clone. Copy the (value-typed) elements.
      if (original.tree != null) {
        this.tree = new TreeNode[original.tree.Length];
        Array.Copy(original.tree, this.tree, this.tree.Length);
      }
    }

    /// <summary>Creates a model over the given flat node array (root at index 0). The array is not copied.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      if (tree == null) throw new ArgumentNullException("tree");
      this.tree = tree;
    }

    /// <summary>True if the node is a leaf (no split variable).</summary>
    private static bool IsLeaf(TreeNode node) {
      return node.varName == null || node.varName == TreeNode.NO_VARIABLE;
    }

    /// <summary>
    /// Walks from node <paramref name="nodeIdx"/> down to a leaf for the given dataset row and returns the
    /// leaf's value. Iterative (not recursive) so that deep trees cannot overflow the stack.
    /// </summary>
    /// <exception cref="InvalidOperationException">The child indices form a cycle.</exception>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
      int internalNodesVisited = 0;
      while (true) {
        var node = t[nodeIdx];
        if (IsLeaf(node))
          return node.val;

        // A valid root-to-leaf path visits each node at most once, so it contains at most t.Length - 1
        // internal nodes; reaching t.Length means the indices loop back on themselves.
        if (++internalNodesVisited >= t.Length)
          throw new InvalidOperationException("The regression tree contains a cycle; child indices must not revisit a node.");

        nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
      }
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>Lazily predicts one value per row in <paramref name="rows"/>, evaluating from the root.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="ds"/> or <paramref name="rows"/> is null.</exception>
    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      if (ds == null) throw new ArgumentNullException("ds");
      if (rows == null) throw new ArgumentNullException("rows");
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>Wraps this model in a solution over a copy of <paramref name="problemData"/>.</summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }

}
|
---|