
source: branches/2701_MemPRAlgorithm/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

Last change on this file was 14345, checked in by gkronber, 8 years ago

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

File size: 11.1 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Globalization;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;

namespace HeuristicLab.Algorithms.DataAnalysis {
  [StorableClass]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  public sealed class RegressionTreeModel : RegressionModel {
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); }
    }

    // trees are represented as a flat array
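    // (the root node is stored at index 0, LeftIdx/RightIdx are indices into the same array,
    //  and a node with VarName == NO_VARIABLE is a leaf whose Val holds the constant prediction).
    // For example, the simple tree "if x <= 2.0 then 1.5 else 3.5" (variable name chosen for
    // illustration only) could be laid out as:
    //   tree[0]: VarName = "x",         Val = 2.0, LeftIdx = 1, RightIdx = 2, WeightLeft = fraction of training rows with x <= 2.0
    //   tree[1]: VarName = NO_VARIABLE, Val = 1.5
    //   tree[2]: VarName = NO_VARIABLE, Val = 3.5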
    internal struct TreeNode {
      public readonly static string NO_VARIABLE = null;

      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
        : this() {
        VarName = varName;
        Val = val;
        LeftIdx = leftIdx;
        RightIdx = rightIdx;
        WeightLeft = weightLeft;
      }

      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
      public double Val { get; internal set; } // threshold
      public int LeftIdx { get; internal set; }
      public int RightIdx { get; internal set; }
      public double WeightLeft { get; internal set; } // for partial dependence plots (a value in the range [0..1] describes the fraction of training samples routed to the left sub-tree)

      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
      public override int GetHashCode() {
        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
      }
      // necessary because of GetHashCode override
      public override bool Equals(object obj) {
        if (obj is TreeNode) {
          var other = (TreeNode)obj;
          return Val.Equals(other.Val) &&
            LeftIdx.Equals(other.LeftIdx) &&
            RightIdx.Equals(other.RightIdx) &&
            WeightLeft.Equals(other.WeightLeft) &&
            EqualStrings(VarName, other.VarName);
        } else {
          return false;
        }
      }

      private bool EqualStrings(string a, string b) {
        return (a == null && b == null) ||
               (a != null && b != null && a.Equals(b));
      }
    }

    // not storable!
    private TreeNode[] tree;

    #region old storable format
    // remove with HL 3.4
    [Storable(AllowOneWay = true)]
    // to prevent storing the references to data caches in nodes
    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    private Tuple<string, double, int, int>[] SerializedTree {
      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
    }
    #endregion
    #region new storable format
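    // Each TreeNode field is persisted as a separate parallel array; whichever property setter runs
    // first during deserialization allocates the tree array, and the remaining setters fill in their field.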
    [Storable]
    private string[] SerializedTreeVarNames {
      get { return tree.Select(t => t.VarName).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].VarName = value[i];
        }
      }
    }
    [Storable]
    private double[] SerializedTreeValues {
      get { return tree.Select(t => t.Val).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].Val = value[i];
        }
      }
    }
    [Storable]
    private int[] SerializedTreeLeftIdx {
      get { return tree.Select(t => t.LeftIdx).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].LeftIdx = value[i];
        }
      }
    }
    [Storable]
    private int[] SerializedTreeRightIdx {
      get { return tree.Select(t => t.RightIdx).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].RightIdx = value[i];
        }
      }
    }
    [Storable]
    private double[] SerializedTreeWeightLeft {
      get { return tree.Select(t => t.WeightLeft).ToArray(); }
      set {
        if (tree == null) tree = new TreeNode[value.Length];
        for (int i = 0; i < value.Length; i++) {
          tree[i].WeightLeft = value[i];
        }
      }
    }
    #endregion

    [StorableConstructor]
    private RegressionTreeModel(bool serializing) : base(serializing) { }
    // cloning ctor
    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      if (original.tree != null) {
        this.tree = new TreeNode[original.tree.Length];
        Array.Copy(original.tree, this.tree, this.tree.Length);
      }
    }

    internal RegressionTreeModel(TreeNode[] tree, string targetVariable)
      : base(targetVariable, "RegressionTreeModel", "Represents a decision tree for regression.") {
      this.tree = tree;
    }

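    // Iteratively follows the split nodes for the given row until a leaf (NO_VARIABLE) is reached.
    // If the split variable is missing from the dataset or its value is NaN, the prediction is instead
    // the WeightLeft-weighted average over both sub-trees (used for partial dependence plots);
    // this case recurses because both branches must be evaluated.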
    private static double GetPredictionForRow(TreeNode[] t, ReadOnlyCollection<double>[] columnCache, int nodeIdx, int row) {
      while (nodeIdx != -1) {
        var node = t[nodeIdx];
        if (node.VarName == TreeNode.NO_VARIABLE)
          return node.Val;
        if (columnCache[nodeIdx] == null || double.IsNaN(columnCache[nodeIdx][row])) {
          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
        } else if (columnCache[nodeIdx][row] <= node.Val)
          nodeIdx = node.LeftIdx;
        else
          nodeIdx = node.RightIdx;
      }
      throw new InvalidOperationException("Invalid tree in RegressionTreeModel");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
      ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length];

      for (int i = 0; i < tree.Length; i++) {
        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
          // tree models also support calculating estimations if not all variables used for training are available in the dataset
          if (ds.ColumnNames.Contains(tree[i].VarName))
            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
        }
      }
      return rows.Select(r => GetPredictionForRow(tree, columnCache, 0, r));
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }

    // mainly for debugging
    public override string ToString() {
      return TreeToString(0, "");
    }

    /// <summary>
    /// Transforms the tree model to a symbolic regression solution
    /// </summary>
    /// <param name="problemData">The regression problem data for which the symbolic regression solution is created</param>
    /// <returns>A new symbolic regression solution which matches the tree model</returns>
    public ISymbolicRegressionSolution CreateSymbolicRegressionSolution(IRegressionProblemData problemData) {
      var rootSy = new ProgramRootSymbol();
      var startSy = new StartSymbol();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };
      var constSy = new Constant();

      var startNode = startSy.CreateTreeNode();
      startNode.AddSubtree(CreateSymbolicRegressionTreeRecursive(tree, 0, varCondSy, constSy));
      var rootNode = rootSy.CreateTreeNode();
      rootNode.AddSubtree(startNode);
      var model = new SymbolicRegressionModel(TargetVariable, new SymbolicExpressionTree(rootNode), new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      return model.CreateRegressionSolution(problemData);
    }

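    // Recursively mirrors the flat tree as a symbolic expression tree: leaves become Constant nodes
    // holding the predicted value, split nodes become VariableCondition nodes carrying the split
    // variable and threshold, with the left and right children attached as the first and second subtree.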
    private ISymbolicExpressionTreeNode CreateSymbolicRegressionTreeRecursive(TreeNode[] treeNodes, int nodeIdx, VariableCondition varCondSy, Constant constSy) {
      var curNode = treeNodes[nodeIdx];
      if (curNode.VarName == TreeNode.NO_VARIABLE) {
        var node = (ConstantTreeNode)constSy.CreateTreeNode();
        node.Value = curNode.Val;
        return node;
      } else {
        var node = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        node.VariableName = curNode.VarName;
        node.Threshold = curNode.Val;

        var left = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.LeftIdx, varCondSy, constSy);
        var right = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.RightIdx, varCondSy, constSy);
        node.AddSubtree(left);
        node.AddSubtree(right);
        return node;
      }
    }

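    // Produces one line per root-to-leaf path, e.g. (with made-up variable names and values):
    //   x <= 2.50 (0.600) and y  >  1.00 (0.400) -> 3.50
    // where the numbers in parentheses are the fractions of training samples routed into each branch.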
    private string TreeToString(int idx, string part) {
      var n = tree[idx];
      if (n.VarName == TreeNode.NO_VARIABLE) {
        return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine);
      } else {
        return
          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2}  >  {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
      }
    }

  }
}