1  #region License Information


/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using System.Collections.ObjectModel;


26  using System.Globalization;


27  using System.Linq;


28  using HeuristicLab.Common;


29  using HeuristicLab.Core;


30  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


31  using HeuristicLab.Problems.DataAnalysis;


32  using HeuristicLab.Problems.DataAnalysis.Symbolic;


33  using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;


34  using HEAL.Attic;


35 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Regression model consisting of a single binary decision tree. The tree is kept as a flat
  /// array of <see cref="TreeNode"/>s in which child links are array indices; evaluation starts at index 0.
  /// </summary>
  [StorableType("C383410E8707486F98F61DFB708B09B5")]
  [Item("RegressionTreeModel", "Represents a decision tree for regression.")]
  public sealed class RegressionTreeModel : RegressionModel {
    /// <summary>
    /// Split-variable names of all inner nodes, in node order. A variable that is used by
    /// several nodes is returned once per node (no de-duplication is performed).
    /// </summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); }
    }

    // trees are represented as a flat array
    internal struct TreeNode : IEquatable<TreeNode> {
      // value of VarName that marks a terminal (leaf) node
      public static readonly string NO_VARIABLE = null;

      /// <summary>
      /// Creates a node. The defaults describe a node without children (-1 is the "no node" index)
      /// and without training-weight information (-1.0 is outside the valid range [0..1]).
      /// </summary>
      /// <param name="varName">name of the variable for splitting or NO_VARIABLE for a terminal node</param>
      /// <param name="val">split threshold for inner nodes; returned as the prediction for terminal nodes</param>
      /// <param name="leftIdx">array index of the left child (taken when value &lt;= threshold), or -1</param>
      /// <param name="rightIdx">array index of the right child (taken when value &gt; threshold), or -1</param>
      /// <param name="weightLeft">fraction of training samples in the left subtree, or -1.0 if unknown</param>
      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
        : this() {
        VarName = varName;
        Val = val;
        LeftIdx = leftIdx;
        RightIdx = rightIdx;
        WeightLeft = weightLeft;
      }

      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
      public double Val { get; internal set; } // threshold
      public int LeftIdx { get; internal set; }
      public int RightIdx { get; internal set; }
      public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left subtree)

      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
      public override int GetHashCode() {
        return LeftIdx ^ RightIdx ^ Val.GetHashCode();
      }

      // necessary because of GetHashCode override
      public override bool Equals(object obj) {
        return obj is TreeNode && Equals((TreeNode)obj);
      }

      /// <summary>
      /// Field-wise equality. Doubles are compared with <c>double.Equals</c> (so NaN equals NaN);
      /// variable names are compared ordinally and two null names are considered equal.
      /// </summary>
      public bool Equals(TreeNode other) {
        return Val.Equals(other.Val) &&
               LeftIdx == other.LeftIdx &&
               RightIdx == other.RightIdx &&
               WeightLeft.Equals(other.WeightLeft) &&
               string.Equals(VarName, other.VarName);
      }
    }

    // not storable!
    private TreeNode[] tree;

    #region old storable format
    // remove with HL 3.4
    [Storable(AllowOneWay = true)]
    // to prevent storing the references to data caches in nodes
    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    // write-only on purpose: the old format is only read, never written
    private Tuple<string, double, int, int>[] SerializedTree {
      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
    }
    #endregion

    #region new storable format
    // Each setter below allocates the node array if it does not exist yet, so the
    // setters do not rely on being called in a particular order.
    private void EnsureTreeAllocated(int length) {
      if (tree == null) tree = new TreeNode[length];
    }

    [Storable]
    private string[] SerializedTreeVarNames {
      get { return tree.Select(t => t.VarName).ToArray(); }
      set {
        EnsureTreeAllocated(value.Length);
        for (int i = 0; i < value.Length; i++) {
          tree[i].VarName = value[i];
        }
      }
    }
    [Storable]
    private double[] SerializedTreeValues {
      get { return tree.Select(t => t.Val).ToArray(); }
      set {
        EnsureTreeAllocated(value.Length);
        for (int i = 0; i < value.Length; i++) {
          tree[i].Val = value[i];
        }
      }
    }
    [Storable]
    private int[] SerializedTreeLeftIdx {
      get { return tree.Select(t => t.LeftIdx).ToArray(); }
      set {
        EnsureTreeAllocated(value.Length);
        for (int i = 0; i < value.Length; i++) {
          tree[i].LeftIdx = value[i];
        }
      }
    }
    [Storable]
    private int[] SerializedTreeRightIdx {
      get { return tree.Select(t => t.RightIdx).ToArray(); }
      set {
        EnsureTreeAllocated(value.Length);
        for (int i = 0; i < value.Length; i++) {
          tree[i].RightIdx = value[i];
        }
      }
    }
    [Storable]
    private double[] SerializedTreeWeightLeft {
      get { return tree.Select(t => t.WeightLeft).ToArray(); }
      set {
        EnsureTreeAllocated(value.Length);
        for (int i = 0; i < value.Length; i++) {
          tree[i].WeightLeft = value[i];
        }
      }
    }
    #endregion

    [StorableConstructor]
    private RegressionTreeModel(StorableConstructorFlag _) : base(_) { }

    // cloning ctor: the node array is copied so that the clone does not share it with the original
    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
      : base(original, cloner) {
      if (original.tree != null) {
        this.tree = new TreeNode[original.tree.Length];
        Array.Copy(original.tree, this.tree, this.tree.Length);
      }
    }

    /// <summary>
    /// Creates a model for the given node array. The array is stored as passed in (it is not copied).
    /// </summary>
    /// <param name="tree">flat node array; evaluation starts at index 0</param>
    /// <param name="targetVariable">name of the target variable, forwarded to the base class</param>
    internal RegressionTreeModel(TreeNode[] tree, string targetVariable)
      : base(targetVariable, "RegressionTreeModel", "Represents a decision tree for regression.") {
      this.tree = tree;
    }

    /// <summary>
    /// Evaluates the tree for one row, starting at <paramref name="nodeIdx"/>.
    /// </summary>
    /// <param name="t">flat node array</param>
    /// <param name="columnCache">per node: the data column of the node's split variable, or null if the column is not available</param>
    /// <param name="nodeIdx">index of the node at which to start</param>
    /// <param name="row">row index into the cached columns</param>
    /// <returns>
    /// The value of the terminal node that is reached. If the split variable of a node is not available
    /// or its value is NaN, the weighted average of both subtrees is returned instead.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// A child index of -1 is followed, or an averaged prediction is required but the node carries
    /// no weight (-1.0), which is the case for models loaded from the old storable format.
    /// </exception>
    private static double GetPredictionForRow(TreeNode[] t, ReadOnlyCollection<double>[] columnCache, int nodeIdx, int row) {
      while (nodeIdx != -1) {
        var node = t[nodeIdx];
        if (node.VarName == TreeNode.NO_VARIABLE)
          return node.Val;
        if (columnCache[nodeIdx] == null || double.IsNaN(columnCache[nodeIdx][row])) {
          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
          // weighted average for partial dependence plot (recursive here because we need to calculate both subtrees)
          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
        } else if (columnCache[nodeIdx][row] <= node.Val)
          nodeIdx = node.LeftIdx;
        else
          nodeIdx = node.RightIdx;
      }
      // an inner node pointed to "no node" (-1)
      throw new InvalidOperationException("Invalid tree in RegressionTreeModel");
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RegressionTreeModel(this, cloner);
    }

    /// <summary>
    /// Calculates the tree output for the given rows of a dataset.
    /// </summary>
    /// <param name="ds">dataset to read the variable values from; variables used by the tree may be missing</param>
    /// <param name="rows">row indices to evaluate</param>
    /// <returns>A lazily evaluated sequence with one estimated value per row.</returns>
    public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
      ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length];
      // collect the available names once instead of scanning ds.ColumnNames for every node
      var availableColumns = new HashSet<string>(ds.ColumnNames);

      for (int i = 0; i < tree.Length; i++) {
        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
          // tree models also support calculating estimations if not all variables used for training are available in the dataset
          if (availableColumns.Contains(tree[i].VarName))
            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
        }
      }
      return rows.Select(r => GetPredictionForRow(tree, columnCache, 0, r));
    }

    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RegressionSolution(this, new RegressionProblemData(problemData));
    }

    // mainly for debugging
    /// <summary>
    /// Textual form of the tree with one line per terminal node; an empty string if the model has no nodes.
    /// </summary>
    public override string ToString() {
      if (tree == null || tree.Length == 0) return string.Empty;
      return TreeToString(0, "");
    }

    /// <summary>
    /// Transforms the tree model to a symbolic regression solution
    /// </summary>
    /// <param name="problemData">problem data passed on to the created solution</param>
    /// <returns>A new symbolic regression solution which matches the tree model</returns>
    public ISymbolicRegressionSolution CreateSymbolicRegressionSolution(IRegressionProblemData problemData) {
      return CreateSymbolicRegressionModel().CreateRegressionSolution(problemData);
    }

    /// <summary>
    /// Transforms the tree model to a symbolic regression model
    /// </summary>
    /// <returns>A new symbolic regression model which matches the tree model</returns>
    public SymbolicRegressionModel CreateSymbolicRegressionModel() {
      var rootSy = new ProgramRootSymbol();
      var startSy = new StartSymbol();
      var varCondSy = new VariableCondition() { IgnoreSlope = true };
      var constSy = new Constant();

      var startNode = startSy.CreateTreeNode();
      startNode.AddSubtree(CreateSymbolicRegressionTreeRecursive(tree, 0, varCondSy, constSy));
      var rootNode = rootSy.CreateTreeNode();
      rootNode.AddSubtree(startNode);
      return new SymbolicRegressionModel(TargetVariable, new SymbolicExpressionTree(rootNode), new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    }

    // Terminal nodes become constant nodes holding the node value; inner nodes become variable-condition
    // nodes (variable name + threshold) with the left subtree added before the right subtree.
    private ISymbolicExpressionTreeNode CreateSymbolicRegressionTreeRecursive(TreeNode[] treeNodes, int nodeIdx, VariableCondition varCondSy, Constant constSy) {
      var curNode = treeNodes[nodeIdx];
      if (curNode.VarName == TreeNode.NO_VARIABLE) {
        var node = (ConstantTreeNode)constSy.CreateTreeNode();
        node.Value = curNode.Val;
        return node;
      } else {
        var node = (VariableConditionTreeNode)varCondSy.CreateTreeNode();
        node.VariableName = curNode.VarName;
        node.Threshold = curNode.Val;

        var left = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.LeftIdx, varCondSy, constSy);
        var right = CreateSymbolicRegressionTreeRecursive(treeNodes, curNode.RightIdx, varCondSy, constSy);
        node.AddSubtree(left);
        node.AddSubtree(right);
        return node;
      }
    }

    // Emits one line per terminal node below idx: "<conditions joined by ' and '> -> <value>".
    // part holds the conditions collected on the path so far; the number in parentheses after a
    // condition is the weight of that branch (WeightLeft for "<=", 1 - WeightLeft for ">").
    private string TreeToString(int idx, string part) {
      var n = tree[idx];
      if (n.VarName == TreeNode.NO_VARIABLE) {
        return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine);
      } else {
        var separator = string.IsNullOrEmpty(part) ? "" : " and ";
        return
          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, separator, n.VarName, n.Val, n.WeightLeft))
        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3})", part, separator, n.VarName, n.Val, 1.0 - n.WeightLeft));
      }
    }
  }
}

