Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/15/16 10:02:15 (9 years ago)
Author:
gkronber
Message:

#2612: Extended GBT to support the calculation of partial dependence (as described in Friedman's paper "Greedy Function Approximation: A Gradient Boosting Machine"), changed the persistence format of regression tree models, and added two unit tests.

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r13065 r13895  
    180180
    181181
    182     // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
     182    // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached
    183183    private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) {
    184184      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
     
    204204
    205205        // overwrite existing leaf node with an internal node
    206         tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx);
     206        tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1));
    207207      }
    208208    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r13030 r13895  
    4040      public readonly static string NO_VARIABLE = null;
    4141
    42       public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     42      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
    4343        : this() {
    4444        VarName = varName;
     
    4646        LeftIdx = leftIdx;
    4747        RightIdx = rightIdx;
    48       }
    49 
    50       public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    51       public double Val { get; private set; } // threshold
    52       public int LeftIdx { get; private set; }
    53       public int RightIdx { get; private set; }
     48        WeightLeft = weightLeft;
     49      }
     50
     51      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     52      public double Val { get; internal set; } // threshold
     53      public int LeftIdx { get; internal set; }
     54      public int RightIdx { get; internal set; }
     55      public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree)
     56
    5457
    5558      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
    6467            LeftIdx.Equals(other.LeftIdx) &&
    6568            RightIdx.Equals(other.RightIdx) &&
     69            WeightLeft.Equals(other.WeightLeft) &&
    6670            EqualStrings(VarName, other.VarName);
    6771        } else {
     
    7983    private TreeNode[] tree;
    8084
    81     [Storable]
     85    #region old storable format
     86    // remove with HL 3.4
     87    [Storable(AllowOneWay = true)]
    8288    // to prevent storing the references to data caches in nodes
    83     // TODO seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
     89    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    8490    private Tuple<string, double, int, int>[] SerializedTree {
    85       get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
    86       set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
    87     }
     91      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
     92      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
     93    }
     94    #endregion
     95    #region new storable format
     96    [Storable]
     97    private string[] SerializedTreeVarNames {
     98      get { return tree.Select(t => t.VarName).ToArray(); }
     99      set {
     100        if (tree == null) tree = new TreeNode[value.Length];
     101        for (int i = 0; i < value.Length; i++) {
     102          tree[i].VarName = value[i];
     103        }
     104      }
     105    }
     106    [Storable]
     107    private double[] SerializedTreeValues {
     108      get { return tree.Select(t => t.Val).ToArray(); }
     109      set {
     110        if (tree == null) tree = new TreeNode[value.Length];
     111        for (int i = 0; i < value.Length; i++) {
     112          tree[i].Val = value[i];
     113        }
     114      }
     115    }
     116    [Storable]
     117    private int[] SerializedTreeLeftIdx {
     118      get { return tree.Select(t => t.LeftIdx).ToArray(); }
     119      set {
     120        if (tree == null) tree = new TreeNode[value.Length];
     121        for (int i = 0; i < value.Length; i++) {
     122          tree[i].LeftIdx = value[i];
     123        }
     124      }
     125    }
     126    [Storable]
     127    private int[] SerializedTreeRightIdx {
     128      get { return tree.Select(t => t.RightIdx).ToArray(); }
     129      set {
     130        if (tree == null) tree = new TreeNode[value.Length];
     131        for (int i = 0; i < value.Length; i++) {
     132          tree[i].RightIdx = value[i];
     133        }
     134      }
     135    }
     136    [Storable]
     137    private double[] SerializedTreeWeightLeft {
     138      get { return tree.Select(t => t.WeightLeft).ToArray(); }
     139      set {
     140        if (tree == null) tree = new TreeNode[value.Length];
     141        for (int i = 0; i < value.Length; i++) {
     142          tree[i].WeightLeft = value[i];
     143        }
     144      }
     145    }
     146    #endregion
     147
     148
     149
    88150
    89151    [StorableConstructor]
     
    108170        if (node.VarName == TreeNode.NO_VARIABLE)
    109171          return node.Val;
    110 
    111         if (columnCache[nodeIdx][row] <= node.Val)
     172        if (columnCache[nodeIdx] == null) {
     173          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
     174          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
     175          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
     176                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
     177        } else if (columnCache[nodeIdx][row] <= node.Val)
    112178          nodeIdx = node.LeftIdx;
    113179        else
     
    127193      for (int i = 0; i < tree.Length; i++) {
    128194        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
    129           columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
     195          // tree models also support calculating estimations if not all variables used for training are available in the dataset
     196          if (ds.ColumnNames.Contains(tree[i].VarName))
     197            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
    130198        }
    131199      }
     
    148216      } else {
    149217        return
    150           TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
    151         + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
     218          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
     219        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2}  >  {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
    152220      }
    153221    }
Note: See TracChangeset for help on using the changeset viewer.