#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Globalization;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis {
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  public sealed class RegressionTreeModel : RegressionModel {
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); }
    }

    // trees are represented as a flat array
    internal struct TreeNode {
      public readonly static string NO_VARIABLE = null;

      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
        : this() {
        VarName = varName;
        Val = val;
        LeftIdx = leftIdx;
        RightIdx = rightIdx;
        WeightLeft = weightLeft;
      }

      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
      public double Val { get; internal set; } // threshold
      public int LeftIdx { get; internal set; }
      public int RightIdx { get; internal set; }
      // for partial dependence plots (a value in range [0..1] describes the fraction of training samples for the left sub-tree)
      public double WeightLeft { get; internal set; }

      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
      public override int GetHashCode() {
        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
      }

      // necessary because of GetHashCode override
      public override bool Equals(object obj) {
        if (obj is TreeNode) {
          var other = (TreeNode)obj;
          return Val.Equals(other.Val) &&
            LeftIdx.Equals(other.LeftIdx) &&
            RightIdx.Equals(other.RightIdx) &&
            WeightLeft.Equals(other.WeightLeft) &&
            EqualStrings(VarName, other.VarName);
        } else {
          return false;
        }
      }

      private bool EqualStrings(string a, string b) {
        return (a == null && b == null) ||
               (a != null && b != null && a.Equals(b));
      }
    }

    // not storable!
    private TreeNode[] tree;

    #region old storable format
    // remove with HL 3.4
    [Storable(AllowOneWay = true)]
    // to prevent storing the references to data caches in nodes
    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    private Tuple<string, double, int, int>[] SerializedTree {
      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
      // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); }
    }
    #endregion

    #region new storable format
    [Storable]
    private string[] SerializedTreeVarNames {
      get { return tree.Select(t => t.VarName).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].VarName = value[i];
        }
      }
    }

    [Storable]
    private double[] SerializedTreeValues {
      get { return tree.Select(t => t.Val).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].Val = value[i];
        }
      }
    }

    [Storable]
    private int[] SerializedTreeLeftIdx {
      get { return tree.Select(t => t.LeftIdx).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].LeftIdx = value[i];
        }
      }
    }

    [Storable]
    private int[] SerializedTreeRightIdx {
      get { return tree.Select(t => t.RightIdx).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].RightIdx = value[i];
        }
      }
    }

    [Storable]
    private double[] SerializedTreeWeightLeft {
      get { return tree.Select(t => t.WeightLeft).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].WeightLeft = value[i];
        }
      }
    }
    #endregion

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }

    // cloning ctor
    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      if (original.tree != null) {
        this.tree = new TreeNode[original.tree.Length];
        Array.Copy(original.tree, this.tree, this.tree.Length);
      }
    }

    internal RegressionTreeModel(TreeNode[] tree, string targetVariable)
      : base(targetVariable, "RegressionTreeModel", "Represents a decision tree for regression.") {
      this.tree = tree;
    }

    private static double GetPredictionForRow(TreeNode[] t, ReadOnlyCollection<double>[] columnCache, int nodeIdx, int row) {
      while (nodeIdx != -1) {
        var node = t[nodeIdx];
        if (node.VarName == TreeNode.NO_VARIABLE)
          return node.Val;
        if (columnCache[nodeIdx] == null || double.IsNaN(columnCache[nodeIdx][row])) {
          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
        } else if (columnCache[nodeIdx][row] <= node.Val)
          nodeIdx = node.LeftIdx;
        else
          nodeIdx = node.RightIdx;
      }
      throw new InvalidOperationException("Invalid tree in RegressionTreeModel");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
      ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length];
      for (int i = 0; i < tree.Length; i++) {
        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
          // tree models also support calculating estimations if not all variables used for training are available in the dataset
          if (ds.ColumnNames.Contains(tree[i].VarName))
            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
        }
      }
      return rows.Select(r => GetPredictionForRow(tree, columnCache, 0, r));
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }

    // mainly for debugging
    public override string ToString() {
      return TreeToString(0, "");
    }

    /// <summary>
    /// Transforms the tree model to a symbolic regression solution
    /// </summary>
    /// <param name="problemData">The problem data for the new solution</param>
    /// <returns>A new symbolic regression solution which matches the tree model</returns>
    public ISymbolicRegressionSolution CreateSymbolicRegressionSolution(IRegressionProblemData problemData) {
      var rootSy = new ProgramRootSymbol();
      var startSy = new StartSymbol();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };
      var constSy = new Constant();

      var startNode = startSy.CreateTreeNode();
      startNode.AddSubtree(CreateSymbolicRegressionTreeRecursive(tree, 0, varCondSy, constSy));
      var rootNode = rootSy.CreateTreeNode();
      rootNode.AddSubtree(startNode);
      var model = new SymbolicRegressionModel(TargetVariable, new SymbolicExpressionTree(rootNode), new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      return model.CreateRegressionSolution(problemData);
    }

    // recursively translates the flat tree array into a symbolic expression tree:
    // terminal nodes become constants, split nodes become variable-condition nodes
    private ISymbolicExpressionTreeNode CreateSymbolicRegressionTreeRecursive(TreeNode[] treeNodes, int nodeIdx, VariableCondition varCondSy, Constant constSy) {
      var curNode = treeNodes[nodeIdx];
      if (curNode.VarName == TreeNode.NO_VARIABLE) {
        var node = (ConstantTreeNode)constSy.CreateTreeNode();
        node.Value = curNode.Val;
        return node;
      } else {
        var node = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        node.VariableName = curNode.VarName;
        node.Threshold = curNode.Val;

        var left = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.LeftIdx, varCondSy, constSy);
        var right = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.RightIdx, varCondSy, constSy);
        node.AddSubtree(left);
        node.AddSubtree(right);
        return node;
      }
    }

    // produces one line per leaf, listing the conjunction of split conditions along the path
    // together with the fraction of training samples (WeightLeft) that followed each branch
    private string TreeToString(int idx, string part) {
      var n = tree[idx];
      if (n.VarName == TreeNode.NO_VARIABLE) {
        return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine);
      } else {
        return
          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
          + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
      }
    }
  }
}