Free cookie consent management tool by TermsFeed Policy Generator

Changeset 13895 for trunk/sources


Ignore:
Timestamp:
06/15/16 10:02:15 (9 years ago)
Author:
gkronber
Message:

#2612: extended GBT to support calculation of partial dependence (as described in the greedy function approximation paper), changed persistence of regression tree models and added two unit tests.

Location:
trunk/sources
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r13065 r13895  
    180180
    181181
    182     // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
     182    // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached
    183183    private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) {
    184184      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
     
    204204
    205205        // overwrite existing leaf node with an internal node
    206         tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx);
     206        tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1));
    207207      }
    208208    }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r13030 r13895  
    4040      public readonly static string NO_VARIABLE = null;
    4141
    42       public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     42      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0)
    4343        : this() {
    4444        VarName = varName;
     
    4646        LeftIdx = leftIdx;
    4747        RightIdx = rightIdx;
    48       }
    49 
    50       public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node
    51       public double Val { get; private set; } // threshold
    52       public int LeftIdx { get; private set; }
    53       public int RightIdx { get; private set; }
     48        WeightLeft = weightLeft;
     49      }
     50
     51      public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node
     52      public double Val { get; internal set; } // threshold
     53      public int LeftIdx { get; internal set; }
     54      public int RightIdx { get; internal set; }
     55      public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree
     56
    5457
    5558      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
    6467            LeftIdx.Equals(other.LeftIdx) &&
    6568            RightIdx.Equals(other.RightIdx) &&
     69            WeightLeft.Equals(other.WeightLeft) &&
    6670            EqualStrings(VarName, other.VarName);
    6771        } else {
     
    7983    private TreeNode[] tree;
    8084
    81     [Storable]
     85    #region old storable format
     86    // remove with HL 3.4
     87    [Storable(AllowOneWay = true)]
    8288    // to prevent storing the references to data caches in nodes
    83     // TODO seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
     89    // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)
    8490    private Tuple<string, double, int, int>[] SerializedTree {
    85       get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
    86       set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
    87     }
     91      // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
     92      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models
     93    }
     94    #endregion
     95    #region new storable format
     96    [Storable]
     97    private string[] SerializedTreeVarNames {
     98      get { return tree.Select(t => t.VarName).ToArray(); }
     99      set {
     100        if (tree == null) tree = new TreeNode[value.Length];
     101        for (int i = 0; i < value.Length; i++) {
     102          tree[i].VarName = value[i];
     103        }
     104      }
     105    }
     106    [Storable]
     107    private double[] SerializedTreeValues {
     108      get { return tree.Select(t => t.Val).ToArray(); }
     109      set {
     110        if (tree == null) tree = new TreeNode[value.Length];
     111        for (int i = 0; i < value.Length; i++) {
     112          tree[i].Val = value[i];
     113        }
     114      }
     115    }
     116    [Storable]
     117    private int[] SerializedTreeLeftIdx {
     118      get { return tree.Select(t => t.LeftIdx).ToArray(); }
     119      set {
     120        if (tree == null) tree = new TreeNode[value.Length];
     121        for (int i = 0; i < value.Length; i++) {
     122          tree[i].LeftIdx = value[i];
     123        }
     124      }
     125    }
     126    [Storable]
     127    private int[] SerializedTreeRightIdx {
     128      get { return tree.Select(t => t.RightIdx).ToArray(); }
     129      set {
     130        if (tree == null) tree = new TreeNode[value.Length];
     131        for (int i = 0; i < value.Length; i++) {
     132          tree[i].RightIdx = value[i];
     133        }
     134      }
     135    }
     136    [Storable]
     137    private double[] SerializedTreeWeightLeft {
     138      get { return tree.Select(t => t.WeightLeft).ToArray(); }
     139      set {
     140        if (tree == null) tree = new TreeNode[value.Length];
     141        for (int i = 0; i < value.Length; i++) {
     142          tree[i].WeightLeft = value[i];
     143        }
     144      }
     145    }
     146    #endregion
     147
     148
     149
    88150
    89151    [StorableConstructor]
     
    108170        if (node.VarName == TreeNode.NO_VARIABLE)
    109171          return node.Val;
    110 
    111         if (columnCache[nodeIdx][row] <= node.Val)
     172        if (columnCache[nodeIdx] == null) {
     173          if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab.");
     174          // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees)
     175          return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) +
     176                 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row);
     177        } else if (columnCache[nodeIdx][row] <= node.Val)
    112178          nodeIdx = node.LeftIdx;
    113179        else
     
    127193      for (int i = 0; i < tree.Length; i++) {
    128194        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
    129           columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
     195          // tree models also support calculating estimations if not all variables used for training are available in the dataset
     196          if (ds.ColumnNames.Contains(tree[i].VarName))
     197            columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName);
    130198        }
    131199      }
     
    148216      } else {
    149217        return
    150           TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))
    151         + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} >  {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));
     218          TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft))
     219        + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2}  >  {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft));
    152220      }
    153221    }
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs

    r13157 r13895  
    11using System;
     2using System.Collections;
     3using System.IO;
    24using System.Linq;
    35using System.Threading;
     
    160162        // x2 >  1.5 AND x1 >  1.5 ->  3.0
    161163        BuildTree(xy, allVariables, 10);
     164      }
     165    }
     166
     167    [TestMethod]
     168    [TestCategory("Algorithms.DataAnalysis")]
     169    [TestProperty("Time", "short")]
     170    public void TestDecisionTreePartialDependence() {
     171      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
     172      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
     173      var regProblem = new RegressionProblem();
     174      regProblem.Load(provider.LoadData(instance));
     175      var problemData = regProblem.ProblemData;
     176      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02);
     177      for (int i = 0; i < 1000; i++)
     178        GradientBoostedTreesAlgorithmStatic.MakeStep(state);
     179
     180
     181      var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First();
     182      Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value);
     183      var model = ((IGradientBoostedTreesModel)state.GetModel());
     184      var treeM = model.Models.Skip(1).First();
     185      Console.WriteLine(treeM.ToString());
     186      Console.WriteLine();
     187
     188      var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray();
     189      var ds = new ModifiableDataset(new string[] { mostImportantVar.Key },
     190        new IList[] { mostImportantVarValues.ToList<double>() });
     191
     192      var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray();
     193
     194      for (int i = 0; i < mostImportantVarValues.Length; i += 10) {
     195        Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]);
     196      }
     197    }
     198
     199    [TestMethod]
     200    [TestCategory("Algorithms.DataAnalysis")]
     201    [TestProperty("Time", "short")]
     202    public void TestDecisionTreePersistence() {
     203      var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider();
     204      var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower"));
     205      var regProblem = new RegressionProblem();
     206      regProblem.Load(provider.LoadData(instance));
     207      var problemData = regProblem.ProblemData;
     208      var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1);
     209      GradientBoostedTreesAlgorithmStatic.MakeStep(state);
     210
     211      var model = ((IGradientBoostedTreesModel)state.GetModel());
     212      var treeM = model.Models.Skip(1).First();
     213      var origStr = treeM.ToString();
     214      using (var memStream = new MemoryStream()) {
     215        Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream);
     216        var buf = memStream.GetBuffer();
     217        using (var restoreStream = new MemoryStream(buf)) {
     218          var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream);
     219          var restoredStr = restoredTree.ToString();
     220          Assert.AreEqual(origStr, restoredStr);
     221        }
    162222      }
    163223    }
Note: See TracChangeset for help on using the changeset viewer.