using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace GradientBoostedTrees {
  /// <summary>
  /// A regression model backed by a single binary decision tree.
  /// The tree is stored as a flat array of <see cref="TreeNode"/>s whose root is at index 0.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    // Trees are represented as a flat array of nodes.
    // TreeNode itself is intentionally NOT marked [StorableClass]/[Storable]:
    // object-graph traversal had problems when this struct was annotated for persistence. TODO investigate...

    /// <summary>
    /// One node of the flat tree array.
    /// An internal node splits on <see cref="varName"/> at threshold <see cref="val"/>.
    /// A leaf has <see cref="varName"/> equal to <see cref="NO_VARIABLE"/> and stores its prediction in <see cref="val"/>.
    /// </summary>
    public struct TreeNode {
      /// <summary>Value of <see cref="varName"/> that marks a leaf (terminal) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      /// <summary>Name of the variable used for splitting, or <see cref="NO_VARIABLE"/> for a leaf.</summary>
      public string varName;
      /// <summary>Split threshold for internal nodes; predicted value for leaves.</summary>
      public double val;
      /// <summary>Index of the child visited when the variable value is &lt;= <see cref="val"/>.</summary>
      public int leftIdx;
      /// <summary>Index of the child visited otherwise (value &gt; <see cref="val"/>, or NaN).</summary>
      public int rightIdx;

      /// <summary>
      /// Field-wise equality over all four fields.
      /// This has the same semantics as the default <see cref="ValueType.Equals(object)"/>, without its reflection cost.
      /// </summary>
      public override bool Equals(object obj) {
        if (!(obj is TreeNode)) return false;
        var other = (TreeNode)obj;
        return string.Equals(varName, other.varName)
               && val.Equals(other.val)
               && leftIdx == other.leftIdx
               && rightIdx == other.rightIdx;
      }

      /// <summary>
      /// Hash over the child indices and the value.
      /// Nodes that are equal under <see cref="Equals(object)"/> always produce the same hash.
      /// </summary>
      public override int GetHashCode() {
        return (leftIdx * rightIdx) ^ val.GetHashCode();
      }
    }

    /// <summary>
    /// Flat node array; index 0 is the root.
    /// The array is shared (not copied) between clones, so it must not be mutated after construction.
    /// </summary>
    [Storable]
    public readonly TreeNode[] tree;

    [StorableConstructor]
    private RegressionTreeModel(bool deserializing) : base(deserializing) { }

    // cloning ctor
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree; // shallow clone, tree must be readonly
    }

    /// <summary>Creates a model over the given flat node array (root at index 0).</summary>
    /// <param name="tree">The node array; it is stored by reference, not copied.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> is null.</exception>
    public RegressionTreeModel(TreeNode[] tree)
      : base("RegressionTreeModel", "Represents a decision tree for regression.") {
      if (tree == null) throw new ArgumentNullException("tree");
      this.tree = tree;
    }

    /// <summary>
    /// Walks the tree from <paramref name="nodeIdx"/> down to a leaf for the given row and returns the leaf value.
    /// </summary>
    /// <remarks>
    /// Iterative rather than recursive, so that deep trees cannot exhaust the call stack.
    /// A value &lt;= threshold goes left; anything else (including NaN) goes right.
    /// </remarks>
    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
      var node = t[nodeIdx];
      while (node.varName != TreeNode.NO_VARIABLE) {
        nodeIdx = ds.GetDoubleValue(node.varName, row) <= node.val ? node.leftIdx : node.rightIdx;
        node = t[nodeIdx];
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Lazily evaluates the tree for each requested row of <paramref name="ds"/>.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/>.
    /// The solution is built over a copy of <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }
}