Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12699


Ignore:
Timestamp:
07/09/15 18:46:20 (9 years ago)
Author:
gkronber
Message:

#2261: improved performance of evaluation for regression tree models

Location:
branches/GBT-trunkintegration
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12698 r12699  
    170170
    171171      int i = 0;
    172       // TODO: slow because of multiple calls to GetDoubleValue for each row index
    173172      foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) {
    174173        yPred[i] = yPred[i] + nu * pred;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12663 r12699  
    2525using System.Globalization;
    2626using System.Linq;
    27 using System.Text;
    2827using HeuristicLab.Common;
    2928using HeuristicLab.Core;
     
    3635  public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
    3736
    38     // trees are represented as a flat array
     37    // trees are represented as a flat array   
    3938    internal struct TreeNode {
    40       public readonly static string NO_VARIABLE = string.Empty;
     39      public readonly static string NO_VARIABLE = null;
    4140
    4241      public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     
    5251      public int LeftIdx { get; private set; }
    5352      public int RightIdx { get; private set; }
     53
     54      internal IList<double> Data { get; set; } // only necessary to improve efficiency of evaluation
    5455
    5556      // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
    7677    }
    7778
     79    // not storable!
     80    private TreeNode[] tree;
     81
    7882    [Storable]
    79     private readonly TreeNode[] tree;
     83    // to prevent storing the references to data caches in nodes
     84    private Tuple<string, double, int, int>[] SerializedTree {
     85      get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
     86      set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
     87    }
    8088
    8189    [StorableConstructor]
     
    8492    private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
    8593      : base(original, cloner) {
    86       this.tree = original.tree; // shallow clone, tree must be readonly
     94      if (original.tree != null) {
     95        this.tree = new TreeNode[original.tree.Length];
     96        Array.Copy(original.tree, this.tree, this.tree.Length);
     97      }
    8798    }
    8899
     
    92103    }
    93104
    94     private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
    95       var node = t[nodeIdx];
    96       if (node.VarName == TreeNode.NO_VARIABLE)
    97         return node.Val;
    98       // TODO: many calls to GetDoubleValue are slow because of the dictionary lookup in Dataset (see ticket #2417)
    99       else if (ds.GetDoubleValue(node.VarName, row) <= node.Val)
    100         return GetPredictionForRow(t, node.LeftIdx, ds, row);
    101       else
    102         return GetPredictionForRow(t, node.RightIdx, ds, row);
     105    private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, int row) {
     106      while (nodeIdx != -1) {
     107        var node = t[nodeIdx];
     108        if (node.VarName == TreeNode.NO_VARIABLE)
     109          return node.Val;
     110
     111        if (node.Data[row] <= node.Val)
     112          nodeIdx = node.LeftIdx;
     113        else
     114          nodeIdx = node.RightIdx;
     115      }
     116      throw new InvalidOperationException("Invalid tree in RegressionTreeModel");
    103117    }
    104118
     
    108122
    109123    public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
    110       return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
     124      // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
     125      for (int i = 0; i < tree.Length; i++) {
     126        if (tree[i].VarName != TreeNode.NO_VARIABLE) {
     127          tree[i].Data = ds.GetReadOnlyDoubleValues(tree[i].VarName);
     128        }
     129      }
     130      return rows.Select(r => GetPredictionForRow(tree, 0, r));
    111131    }
    112132
     
    130150      }
    131151    }
    132 
    133152  }
    134 
    135153}
  • branches/GBT-trunkintegration/Tests/Test.cs

    r12661 r12699  
    183183      gbt.Iterations = 5000;
    184184      gbt.MaxSize = 20;
     185      gbt.CreateSolution = false;
    185186      #endregion
    186187
    187188      RunAlgorithm(gbt);
    188189
     190      Console.WriteLine(gbt.ExecutionTime);
    189191      Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
    190192      Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
     
    209211      gbt.Nu = 0.02;
    210212      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute"));
     213      gbt.CreateSolution = false;
    211214      #endregion
    212215
    213216      RunAlgorithm(gbt);
    214217
     218      Console.WriteLine(gbt.ExecutionTime);
    215219      Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
    216220      Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
     
    235239      gbt.Nu = 0.005;
    236240      gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative"));
     241      gbt.CreateSolution = false;
    237242      #endregion
    238243
    239244      RunAlgorithm(gbt);
    240245
     246      Console.WriteLine(gbt.ExecutionTime);
    241247      Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6);
    242248      Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
Note: See TracChangeset for help on using the changeset viewer.