using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace GradientBoostedTrees {
  /// <summary>
  /// Regression model backed by a binary decision tree of <see cref="TreeNode"/>s.
  /// Internal nodes split on a single variable against a threshold; leaf nodes hold the prediction.
  /// </summary>
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  // TODO: Implement a view for this
  public class RegressionTreeModel : NamedItem, IRegressionModel {

    /// <summary>
    /// Immutable node of a regression tree.
    /// If <see cref="varName"/> equals <see cref="NO_VARIABLE"/> the node is a leaf and
    /// <see cref="val"/> is its predicted value. Otherwise rows whose value of
    /// <see cref="varName"/> is &lt;= <see cref="val"/> descend into <see cref="left"/>,
    /// and all other rows (including NaN values, since NaN comparisons are false) into <see cref="right"/>.
    /// </summary>
    [StorableClass]
    public class TreeNode {
      /// <summary>Sentinel value of <see cref="varName"/> that marks a leaf (terminal) node.</summary>
      public static readonly string NO_VARIABLE = string.Empty;

      // name of the variable used for splitting, or NO_VARIABLE for a leaf node
      [Storable]
      public readonly string varName;
      // split threshold for internal nodes; predicted value for leaf nodes
      [Storable]
      public readonly double val;
      // subtree for rows with value <= val (null for leaves)
      [Storable]
      public readonly TreeNode left;
      // subtree for rows with value > val or NaN (null for leaves)
      [Storable]
      public readonly TreeNode right;

      /// <summary>Creates a tree node.</summary>
      /// <param name="varName">Split variable, or <see cref="NO_VARIABLE"/> for a leaf.</param>
      /// <param name="value">Split threshold (internal node) or prediction (leaf).</param>
      /// <param name="left">Subtree for values &lt;= <paramref name="value"/>; null for leaves.</param>
      /// <param name="right">Subtree for all other values; null for leaves.</param>
      public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {
        this.varName = varName;
        this.val = value;
        this.left = left;
        this.right = right;
      }
    }

    // root of the decision tree
    [Storable]
    public readonly TreeNode tree;

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    /// <summary>Cloning constructor.</summary>
    /// <remarks>
    /// The tree is shared rather than deep-cloned: every field of <see cref="TreeNode"/> is
    /// readonly, so sharing nodes between clones cannot cause aliasing bugs.
    /// </remarks>
    public RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      this.tree = original.tree;
    }

    /// <summary>Creates a model for the given tree, using the item name and description as defaults.</summary>
    /// <param name="tree">Root node of the regression tree.</param>
    public RegressionTreeModel(TreeNode tree)
      : base() {
      this.name = ItemName;
      this.description = ItemDescription;
      this.tree = tree;
    }

    /// <summary>
    /// Walks from <paramref name="t"/> down to a leaf for the given row and returns the leaf's value.
    /// This is iterative rather than recursive, so very deep trees cannot exhaust the call stack.
    /// </summary>
    private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) {
      var node = t;
      while (node.varName != TreeNode.NO_VARIABLE) {
        // NaN <= threshold is false, so missing/NaN values go to the right subtree
        node = ds.GetDoubleValue(node.varName, row) <= node.val ? node.left : node.right;
      }
      return node.val;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Returns one prediction per entry of <paramref name="rows"/>, in the same order.
    /// Evaluation is deferred (LINQ <c>Select</c>), so predictions are computed on enumeration.
    /// </summary>
    public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) {
      return rows.Select(r => GetPredictionForRow(tree, ds, r));
    }

    /// <summary>
    /// Wraps this model in a <see cref="RegressionSolution"/> together with a new
    /// <see cref="RegressionProblemData"/> constructed from <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }
  }
}