- Timestamp:
- 07/09/15 18:46:20 (10 years ago)
- Location:
- branches/GBT-trunkintegration
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12698 r12699 170 170 171 171 int i = 0; 172 // TODO: slow because of multiple calls to GetDoubleValue for each row index173 172 foreach (var pred in tree.GetEstimatedValues(problemData.Dataset, problemData.TrainingIndices)) { 174 173 yPred[i] = yPred[i] + nu * pred; -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r12663 r12699 25 25 using System.Globalization; 26 26 using System.Linq; 27 using System.Text;28 27 using HeuristicLab.Common; 29 28 using HeuristicLab.Core; … … 36 35 public sealed class RegressionTreeModel : NamedItem, IRegressionModel { 37 36 38 // trees are represented as a flat array 37 // trees are represented as a flat array 39 38 internal struct TreeNode { 40 public readonly static string NO_VARIABLE = string.Empty;39 public readonly static string NO_VARIABLE = null; 41 40 42 41 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1) … … 52 51 public int LeftIdx { get; private set; } 53 52 public int RightIdx { get; private set; } 53 54 internal IList<double> Data { get; set; } // only necessary to improve efficiency of evaluation 54 55 55 56 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here … … 76 77 } 77 78 79 // not storable! 80 private TreeNode[] tree; 81 78 82 [Storable] 79 private readonly TreeNode[] tree; 83 // to prevent storing the references to data caches in nodes 84 private Tuple<string, double, int, int>[] SerializedTree { 85 get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 86 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); } 87 } 80 88 81 89 [StorableConstructor] … … 84 92 private RegressionTreeModel(RegressionTreeModel original, Cloner cloner) 85 93 : base(original, cloner) { 86 this.tree = original.tree; // shallow clone, tree must be readonly 94 if (original.tree != null) { 95 this.tree = new TreeNode[original.tree.Length]; 96 Array.Copy(original.tree, this.tree, this.tree.Length); 97 } 87 98 } 88 99 … … 92 103 } 93 104 94 private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, IDataset ds, int row) { 95 var node = t[nodeIdx]; 96 if (node.VarName == TreeNode.NO_VARIABLE) 97 return node.Val; 98 // TODO: many calls to GetDoubleValue are slow because of the dictionary lookup in Dataset (see ticket #2417) 99 else if (ds.GetDoubleValue(node.VarName, row) <= node.Val) 100 return GetPredictionForRow(t, node.LeftIdx, ds, row); 101 else 102 return GetPredictionForRow(t, node.RightIdx, ds, row); 105 private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, int row) { 106 while (nodeIdx != -1) { 107 var node = t[nodeIdx]; 108 if (node.VarName == TreeNode.NO_VARIABLE) 109 return node.Val; 110 111 if (node.Data[row] <= node.Val) 112 nodeIdx = node.LeftIdx; 113 else 114 nodeIdx = node.RightIdx; 115 } 116 throw new InvalidOperationException("Invalid tree in RegressionTreeModel"); 103 117 } 104 118 … … 108 122 109 123 public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) { 110 return rows.Select(r => GetPredictionForRow(tree, 0, ds, r)); 124 // lookup columns for variableNames in one pass over the tree to speed up evaluation later on 125 for (int i = 0; i < tree.Length; i++) { 126 if (tree[i].VarName != TreeNode.NO_VARIABLE) { 127 tree[i].Data = ds.GetReadOnlyDoubleValues(tree[i].VarName); 128 } 129 } 130 return rows.Select(r => GetPredictionForRow(tree, 0, r)); 111 131 } 112 132 … … 130 150 } 131 151 } 132 133 152 } 134 135 153 } -
branches/GBT-trunkintegration/Tests/Test.cs
r12661 r12699 183 183 gbt.Iterations = 5000; 184 184 gbt.MaxSize = 20; 185 gbt.CreateSolution = false; 185 186 #endregion 186 187 187 188 RunAlgorithm(gbt); 188 189 190 Console.WriteLine(gbt.ExecutionTime); 189 191 Assert.AreEqual(267.68704241153921, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6); 190 192 Assert.AreEqual(393.84704062205469, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6); … … 209 211 gbt.Nu = 0.02; 210 212 gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Absolute")); 213 gbt.CreateSolution = false; 211 214 #endregion 212 215 213 216 RunAlgorithm(gbt); 214 217 218 Console.WriteLine(gbt.ExecutionTime); 215 219 Assert.AreEqual(10.551385044666661, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6); 216 220 Assert.AreEqual(12.918001745581172, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6); … … 235 239 gbt.Nu = 0.005; 236 240 gbt.LossFunctionParameter.Value = gbt.LossFunctionParameter.ValidValues.First(l => l.ToString().Contains("Relative")); 241 gbt.CreateSolution = false; 237 242 #endregion 238 243 239 244 RunAlgorithm(gbt); 240 245 246 Console.WriteLine(gbt.ExecutionTime); 241 247 Assert.AreEqual(0.061954221604374943, ((DoubleValue)gbt.Results["Loss (train)"].Value).Value, 1E-6); 242 248 Assert.AreEqual(0.06316303473499961, ((DoubleValue)gbt.Results["Loss (test)"].Value).Value, 1E-6);
Note: See TracChangeset
for help on using the changeset viewer.