Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/09/15 18:46:20 (9 years ago)
Author:
gkronber
Message:

#2261: improved performance of evaluation for regression tree models

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs

    r12663 r12699  
     25  25 using System.Globalization;
     26  26 using System.Linq;
     27     using System.Text;
     28  27 using HeuristicLab.Common;
     29  28 using HeuristicLab.Core;
     
     36  35   public sealed class RegressionTreeModel : NamedItem, IRegressionModel {
     37  36
     38         // trees are represented as a flat array
         37     // trees are represented as a flat array   
     39  38     internal struct TreeNode {
     40           public readonly static string NO_VARIABLE = string.Empty;
         39       public readonly static string NO_VARIABLE = null;
     41  40
     42  41       public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1)
     
     52  51       public int LeftIdx { get; private set; }
     53  52       public int RightIdx { get; private set; }
         53
         54       internal IList<double> Data { get; set; } // only necessary to improve efficiency of evaluation
     54  55
     55  56       // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here
     
     76  77     }
     77  78
         79     // not storable!
         80     private TreeNode[] tree;
         81
     78  82     [Storable]
     79         private readonly TreeNode[] tree;
         83     // to prevent storing the references to data caches in nodes
         84     private Tuple<string, double, int, int>[] SerializedTree {
         85       get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); }
         86       set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); }
         87     }
     80  88
     81  89     [StorableConstructor]
     
     84  92     private RegressionTreeModel(RegressionTreeModel original, Cloner cloner)
     85  93       : base(original, cloner) {
     86           this.tree = original.tree; // shallow clone, tree must be readonly
         94       if (original.tree != null) {
         95         this.tree = new TreeNode[original.tree.Length];
         96         Array.Copy(original.tree, this.tree, this.tree.Length);
         97       }
     87  98     }
     88  99
     
     92 103     }
     93 104
     94         private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) {
     95           var node = t[nodeIdx];
     96           if (node.VarName == TreeNode.NO_VARIABLE)
     97             return node.Val;
     98           // TODO: many calls to GetDoubleValue are slow because of the dictionary lookup in Dataset (see ticket #2417)
     99           else if (ds.GetDoubleValue(node.VarName, row) <= node.Val)
    100             return GetPredictionForRow(t, node.LeftIdx, ds, row);
    101           else
    102             return GetPredictionForRow(t, node.RightIdx, ds, row);
        105     private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, int row) {
        106       while (nodeIdx != -1) {
        107         var node = t[nodeIdx];
        108         if (node.VarName == TreeNode.NO_VARIABLE)
        109           return node.Val;
        110
        111         if (node.Data[row] <= node.Val)
        112           nodeIdx = node.LeftIdx;
        113         else
        114           nodeIdx = node.RightIdx;
        115       }
        116       throw new InvalidOperationException("Invalid tree in RegressionTreeModel");
    103 117     }
    104 118
     
    108 122
    109 123     public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {
    110           return rows.Select(r => GetPredictionForRow(tree, 0, ds, r));
        124       // lookup columns for variableNames in one pass over the tree to speed up evaluation later on
        125       for (int i = 0; i < tree.Length; i++) {
        126         if (tree[i].VarName != TreeNode.NO_VARIABLE) {
        127           tree[i].Data = ds.GetReadOnlyDoubleValues(tree[i].VarName);
        128         }
        129       }
        130       return rows.Select(r => GetPredictionForRow(tree, 0, r));
    111 131     }
    112 132
     
    130 150       }
    131 151     }
    132
    133 152   }
    134
    135 153 }
Note: See TracChangeset for help on using the changeset viewer.