Changeset 13895 for trunk/sources/HeuristicLab.Algorithms.DataAnalysis
- Timestamp:
- 06/15/16 10:02:15 (9 years ago)
- Location:
- trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r13065 r13895 180 180 181 181 182 // processes potential splits from the queue as long as splits are leftand the maximum size of the tree is not reached182 // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached 183 183 private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) { 184 184 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop … … 204 204 205 205 // overwrite existing leaf node with an internal node 206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx );206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1)); 207 207 } 208 208 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r13030 r13895 40 40 public readonly static string NO_VARIABLE = null; 41 41 42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1 )42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0) 43 43 : this() { 44 44 VarName = varName; … … 46 46 LeftIdx = leftIdx; 47 47 RightIdx = rightIdx; 48 } 49 50 public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node 51 public double Val { get; private set; } // threshold 52 public int LeftIdx { get; private set; } 53 public int RightIdx { get; private set; } 48 WeightLeft = weightLeft; 49 } 50 51 public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node 52 public double Val { get; internal set; } // threshold 53 public int LeftIdx { get; internal set; } 54 public int RightIdx { get; internal set; } 55 public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree 56 54 57 55 58 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here … … 64 67 LeftIdx.Equals(other.LeftIdx) && 65 68 RightIdx.Equals(other.RightIdx) && 69 WeightLeft.Equals(other.WeightLeft) && 66 70 EqualStrings(VarName, other.VarName); 67 71 } else { … … 79 83 private TreeNode[] tree; 80 84 81 [Storable] 85 #region old storable format 86 // remove with HL 3.4 87 [Storable(AllowOneWay = true)] 82 88 // to prevent storing the references to data caches in nodes 83 // TODO seeminglyit is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)89 // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) 84 90 private Tuple<string, double, int, int>[] SerializedTree { 85 get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 86 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); } 87 } 91 // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 92 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models 93 } 94 #endregion 95 #region new storable format 96 [Storable] 97 private string[] SerializedTreeVarNames { 98 get { return tree.Select(t => t.VarName).ToArray(); } 99 set { 100 if (tree == null) tree = new TreeNode[value.Length]; 101 for (int i = 0; i < value.Length; i++) { 102 tree[i].VarName = value[i]; 103 } 104 } 105 } 106 [Storable] 107 private double[] SerializedTreeValues { 108 get { return tree.Select(t => t.Val).ToArray(); } 109 set { 110 if (tree == null) tree = new TreeNode[value.Length]; 111 for (int i = 0; i < value.Length; i++) { 112 tree[i].Val = value[i]; 113 } 114 } 115 } 116 [Storable] 117 private int[] SerializedTreeLeftIdx { 118 get { return tree.Select(t => t.LeftIdx).ToArray(); } 119 set { 120 if (tree == null) tree = new TreeNode[value.Length]; 121 for (int i = 0; i < value.Length; i++) { 122 tree[i].LeftIdx = value[i]; 123 } 124 } 125 } 126 [Storable] 127 private int[] SerializedTreeRightIdx { 128 get { return tree.Select(t => t.RightIdx).ToArray(); } 129 set { 130 if (tree == null) tree = new TreeNode[value.Length]; 131 for (int i = 0; i < value.Length; i++) { 132 tree[i].RightIdx = value[i]; 133 } 134 } 135 } 136 [Storable] 137 private double[] SerializedTreeWeightLeft { 138 get { return tree.Select(t => t.WeightLeft).ToArray(); } 139 set { 140 if (tree == null) tree = new TreeNode[value.Length]; 141 for (int i = 0; i < value.Length; i++) { 142 tree[i].WeightLeft = value[i]; 143 } 144 } 145 } 146 #endregion 147 148 149 88 150 89 151 [StorableConstructor] … … 108 170 if (node.VarName == TreeNode.NO_VARIABLE) 109 171 return node.Val; 110 111 if (columnCache[nodeIdx][row] <= node.Val) 172 if (columnCache[nodeIdx] == null) { 173 if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab."); 174 // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees) 175 return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) + 176 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row); 177 } else if (columnCache[nodeIdx][row] <= node.Val) 112 178 nodeIdx = node.LeftIdx; 113 179 else … … 127 193 for (int i = 0; i < tree.Length; i++) { 128 194 if (tree[i].VarName != TreeNode.NO_VARIABLE) { 129 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 195 // tree models also support calculating estimations if not all variables used for training are available in the dataset 196 if (ds.ColumnNames.Contains(tree[i].VarName)) 197 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 130 198 } 131 199 } … … 148 216 } else { 149 217 return 150 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))151 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));218 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft)) 219 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft)); 152 220 } 153 221 }
Note: See TracChangeset
for help on using the changeset viewer.