Changeset 13895 for trunk/sources
- Timestamp:
- 06/15/16 10:02:15 (9 years ago)
- Location:
- trunk/sources
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r13065 r13895 180 180 181 181 182 // processes potential splits from the queue as long as splits are leftand the maximum size of the tree is not reached182 // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached 183 183 private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) { 184 184 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop … … 204 204 205 205 // overwrite existing leaf node with an internal node 206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx );206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1)); 207 207 } 208 208 } -
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r13030 r13895 40 40 public readonly static string NO_VARIABLE = null; 41 41 42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1 )42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0) 43 43 : this() { 44 44 VarName = varName; … … 46 46 LeftIdx = leftIdx; 47 47 RightIdx = rightIdx; 48 } 49 50 public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node 51 public double Val { get; private set; } // threshold 52 public int LeftIdx { get; private set; } 53 public int RightIdx { get; private set; } 48 WeightLeft = weightLeft; 49 } 50 51 public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node 52 public double Val { get; internal set; } // threshold 53 public int LeftIdx { get; internal set; } 54 public int RightIdx { get; internal set; } 55 public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree 56 54 57 55 58 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here … … 64 67 LeftIdx.Equals(other.LeftIdx) && 65 68 RightIdx.Equals(other.RightIdx) && 69 WeightLeft.Equals(other.WeightLeft) && 66 70 EqualStrings(VarName, other.VarName); 67 71 } else { … … 79 83 private TreeNode[] tree; 80 84 81 [Storable] 85 #region old storable format 86 // remove with HL 3.4 87 [Storable(AllowOneWay = true)] 82 88 // to prevent storing the references to data caches in nodes 83 // TODO seeminglyit is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)89 // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) 84 90 private Tuple<string, double, int, int>[] SerializedTree { 85 get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 86 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); } 87 } 91 // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 92 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models 93 } 94 #endregion 95 #region new storable format 96 [Storable] 97 private string[] SerializedTreeVarNames { 98 get { return tree.Select(t => t.VarName).ToArray(); } 99 set { 100 if (tree == null) tree = new TreeNode[value.Length]; 101 for (int i = 0; i < value.Length; i++) { 102 tree[i].VarName = value[i]; 103 } 104 } 105 } 106 [Storable] 107 private double[] SerializedTreeValues { 108 get { return tree.Select(t => t.Val).ToArray(); } 109 set { 110 if (tree == null) tree = new TreeNode[value.Length]; 111 for (int i = 0; i < value.Length; i++) { 112 tree[i].Val = value[i]; 113 } 114 } 115 } 116 [Storable] 117 private int[] SerializedTreeLeftIdx { 118 get { return tree.Select(t => t.LeftIdx).ToArray(); } 119 set { 120 if (tree == null) tree = new TreeNode[value.Length]; 121 for (int i = 0; i < value.Length; i++) { 122 tree[i].LeftIdx = value[i]; 123 } 124 } 125 } 126 [Storable] 127 private int[] SerializedTreeRightIdx { 128 get { return tree.Select(t => t.RightIdx).ToArray(); } 129 set { 130 if (tree == null) tree = new TreeNode[value.Length]; 131 for (int i = 0; i < value.Length; i++) { 132 tree[i].RightIdx = value[i]; 133 } 134 } 135 } 136 [Storable] 137 private double[] SerializedTreeWeightLeft { 138 get { return tree.Select(t => t.WeightLeft).ToArray(); } 139 set { 140 if (tree == null) tree = new TreeNode[value.Length]; 141 for (int i = 0; i < value.Length; i++) { 142 tree[i].WeightLeft = value[i]; 143 } 144 } 145 } 146 #endregion 147 148 149 88 150 89 151 [StorableConstructor] … … 108 170 if (node.VarName == TreeNode.NO_VARIABLE) 109 171 return node.Val; 110 111 if (columnCache[nodeIdx][row] <= node.Val) 172 if (columnCache[nodeIdx] == null) { 173 if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab."); 174 // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees) 175 return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) + 176 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row); 177 } else if (columnCache[nodeIdx][row] <= node.Val) 112 178 nodeIdx = node.LeftIdx; 113 179 else … … 127 193 for (int i = 0; i < tree.Length; i++) { 128 194 if (tree[i].VarName != TreeNode.NO_VARIABLE) { 129 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 195 // tree models also support calculating estimations if not all variables used for training are available in the dataset 196 if (ds.ColumnNames.Contains(tree[i].VarName)) 197 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 130 198 } 131 199 } … … 148 216 } else { 149 217 return 150 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val))151 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val));218 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft)) 219 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft)); 152 220 } 153 221 } -
trunk/sources/HeuristicLab.Tests/HeuristicLab.Algorithms.DataAnalysis-3.4/GradientBoostingTest.cs
r13157 r13895 1 1 using System; 2 using System.Collections; 3 using System.IO; 2 4 using System.Linq; 3 5 using System.Threading; … … 160 162 // x2 > 1.5 AND x1 > 1.5 -> 3.0 161 163 BuildTree(xy, allVariables, 10); 164 } 165 } 166 167 [TestMethod] 168 [TestCategory("Algorithms.DataAnalysis")] 169 [TestProperty("Time", "short")] 170 public void TestDecisionTreePartialDependence() { 171 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 172 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 173 var regProblem = new RegressionProblem(); 174 regProblem.Load(provider.LoadData(instance)); 175 var problemData = regProblem.ProblemData; 176 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 10, r: 0.5, m: 1, nu: 0.02); 177 for (int i = 0; i < 1000; i++) 178 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 179 180 181 var mostImportantVar = state.GetVariableRelevance().OrderByDescending(kvp => kvp.Value).First(); 182 Console.WriteLine("var: {0} relevance: {1}", mostImportantVar.Key, mostImportantVar.Value); 183 var model = ((IGradientBoostedTreesModel)state.GetModel()); 184 var treeM = model.Models.Skip(1).First(); 185 Console.WriteLine(treeM.ToString()); 186 Console.WriteLine(); 187 188 var mostImportantVarValues = problemData.Dataset.GetDoubleValues(mostImportantVar.Key).OrderBy(x => x).ToArray(); 189 var ds = new ModifiableDataset(new string[] { mostImportantVar.Key }, 190 new IList[] { mostImportantVarValues.ToList<double>() }); 191 192 var estValues = model.GetEstimatedValues(ds, Enumerable.Range(0, mostImportantVarValues.Length)).ToArray(); 193 194 for (int i = 0; i < mostImportantVarValues.Length; i += 10) { 195 Console.WriteLine("{0,-5:N3} {1,-5:N3}", mostImportantVarValues[i], estValues[i]); 196 } 197 } 198 199 [TestMethod] 200 [TestCategory("Algorithms.DataAnalysis")] 201 [TestProperty("Time", "short")] 202 public void TestDecisionTreePersistence() { 203 var provider = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionRealWorldInstanceProvider(); 204 var instance = provider.GetDataDescriptors().Single(x => x.Name.Contains("Tower")); 205 var regProblem = new RegressionProblem(); 206 regProblem.Load(provider.LoadData(instance)); 207 var problemData = regProblem.ProblemData; 208 var state = GradientBoostedTreesAlgorithmStatic.CreateGbmState(problemData, new SquaredErrorLoss(), randSeed: 31415, maxSize: 100, r: 0.5, m: 1, nu: 1); 209 GradientBoostedTreesAlgorithmStatic.MakeStep(state); 210 211 var model = ((IGradientBoostedTreesModel)state.GetModel()); 212 var treeM = model.Models.Skip(1).First(); 213 var origStr = treeM.ToString(); 214 using (var memStream = new MemoryStream()) { 215 Persistence.Default.Xml.XmlGenerator.Serialize(treeM, memStream); 216 var buf = memStream.GetBuffer(); 217 using (var restoreStream = new MemoryStream(buf)) { 218 var restoredTree = Persistence.Default.Xml.XmlParser.Deserialize(restoreStream); 219 var restoredStr = restoredTree.ToString(); 220 Assert.AreEqual(origStr, restoredStr); 221 } 162 222 } 163 223 }
Note: See TracChangeset
for help on using the changeset viewer.