Changeset 12658
- Timestamp:
- 07/07/15 17:42:24 (9 years ago)
- Location:
- branches/GBT-trunkintegration
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12632 r12658 207 207 Debug.Assert(endIdx < internalIdx.Length); 208 208 209 // transform the leaf node into an internal node210 tree[f.ParentNodeIdx].VarName = f.SplittingVariable;211 tree[f.ParentNodeIdx].Val = f.SplittingThreshold;212 213 209 // split partition into left and right 214 210 int splitIdx; … … 218 214 219 215 // create two leaf nodes (and enqueue best splits for both) 220 tree[f.ParentNodeIdx].LeftIdx = CreateLeafNode(startIdx, splitIdx, lineSearch); 221 tree[f.ParentNodeIdx].RightIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch); 216 var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lineSearch); 217 var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch); 218 219 // overwrite existing leaf node with an internal node 220 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx); 222 221 } 223 222 } … … 226 225 // returns the index of the newly created tree node 227 226 private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) { 228 tree[curTreeNodeIdx].VarName = RegressionTreeModel.TreeNode.NO_VARIABLE; 229 tree[curTreeNodeIdx].Val = lineSearch(internalIdx, startIdx, endIdx); 227 // write a leaf node 228 var val = lineSearch(internalIdx, startIdx, endIdx); 229 tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val); 230 230 231 231 EnqueuePartitionSplit(curTreeNodeIdx, startIdx, endIdx); -
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r12635 r12658 21 21 #endregion 22 22 23 using System; 23 24 using System.Collections.Generic; 25 using System.Globalization; 24 26 using System.Linq; 27 using System.Text; 25 28 using HeuristicLab.Common; 26 29 using HeuristicLab.Core; … … 31 34 [StorableClass] 32 35 [Item("RegressionTreeModel", "Represents a decision tree for regression.")] 33 public class RegressionTreeModel : NamedItem, IRegressionModel {36 public sealed class RegressionTreeModel : NamedItem, IRegressionModel { 34 37 35 38 // trees are represented as a flat array 36 publicstruct TreeNode {39 internal struct TreeNode { 37 40 public readonly static string NO_VARIABLE = string.Empty; 38 public string VarName { get; set; } // name of the variable for splitting or NO_VARIABLE if terminal node39 public double Val { get; set; } // threshold40 public int LeftIdx { get; set; }41 public int RightIdx { get; set; }42 41 42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1) 43 : this() { 44 VarName = varName; 45 Val = val; 46 LeftIdx = leftIdx; 47 RightIdx = rightIdx; 48 } 49 50 public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node 51 public double Val { get; private set; } // threshold 52 public int LeftIdx { get; private set; } 53 public int RightIdx { get; private set; } 54 55 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here 43 56 public override int GetHashCode() { 44 57 return LeftIdx ^ RightIdx ^ Val.GetHashCode(); 58 } 59 // necessary because of GetHashCode override 60 public override bool Equals(object obj) { 61 if (obj is TreeNode) { 62 var other = (TreeNode)obj; 63 return Val.Equals(other.Val) && 64 VarName.Equals(other.VarName) && 65 LeftIdx.Equals(other.LeftIdx) && 66 RightIdx.Equals(other.RightIdx); 67 } else { 68 return false; 69 } 45 70 } 46 71 } 47 72 48 73 [Storable] 49 p ublicreadonly TreeNode[] tree;74 private readonly TreeNode[] tree; 50 75 51 76 [StorableConstructor] 52 77 private RegressionTreeModel(bool serializing) : base(serializing) { } 53 78 // cloning ctor 54 p ublicRegressionTreeModel(RegressionTreeModel original, Cloner cloner)79 private RegressionTreeModel(RegressionTreeModel original, Cloner cloner) 55 80 : base(original, cloner) { 56 81 this.tree = original.tree; // shallow clone, tree must be readonly 57 82 } 58 83 59 publicRegressionTreeModel(TreeNode[] tree)84 internal RegressionTreeModel(TreeNode[] tree) 60 85 : base("RegressionTreeModel", "Represents a decision tree for regression.") { 61 86 this.tree = tree; … … 84 109 return new RegressionSolution(this, new RegressionProblemData(problemData)); 85 110 } 111 112 // mainly for debugging 113 public override string ToString() { 114 return TreeToString(0, ""); 115 } 116 117 private string TreeToString(int idx, string part) { 118 var n = tree[idx]; 119 if (n.VarName == TreeNode.NO_VARIABLE) { 120 return string.Format(CultureInfo.InvariantCulture, "{0} -> {1:F}{2}", part, n.Val, Environment.NewLine); 121 } else { 122 return 123 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)) 124 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)); 125 } 126 } 127 86 128 } 87 129 -
branches/GBT-trunkintegration/Tests/Test.cs
r12632 r12658 11 11 using Microsoft.VisualStudio.TestTools.UnitTesting; 12 12 13 namespace HeuristicLab.Algorithms.DataAnalysis .GradientBoostedTrees{13 namespace HeuristicLab.Algorithms.DataAnalysis { 14 14 [TestClass()] 15 15 public class Test { … … 21 21 var xy = new double[,] 22 22 { 23 { 1, 20, 0},24 { 1, 20, 0},25 { 2, 10, 0},26 { 2, 10, 0},27 }; 28 var allVariables = new string[] { "y", "x1", "x2" }; 29 30 // x1 <= 15 -> 231 // x1 > 15 -> 132 BuildTree(xy, allVariables, 10); 33 } 34 35 36 { 37 var xy = new double[,] 38 { 39 { 1, 20, 1},40 { 1, 20, -1},41 { 2, 10, -1},42 { 2, 10, 1},23 {-1, 20, 0}, 24 {-1, 20, 0}, 25 { 1, 10, 0}, 26 { 1, 10, 0}, 27 }; 28 var allVariables = new string[] { "y", "x1", "x2" }; 29 30 // x1 <= 15 -> 1 31 // x1 > 15 -> -1 32 BuildTree(xy, allVariables, 10); 33 } 34 35 36 { 37 var xy = new double[,] 38 { 39 {-1, 20, 1}, 40 {-1, 20, -1}, 41 { 1, 10, -1}, 42 { 1, 10, 1}, 43 43 }; 44 44 var allVariables = new string[] { "y", "x1", "x2" }; 45 45 46 46 // ignore irrelevant variables 47 // x1 <= 15 -> 248 // x1 > 15 -> 147 // x1 <= 15 -> 1 48 // x1 > 15 -> -1 49 49 BuildTree(xy, allVariables, 10); 50 50 } … … 54 54 var xy = new double[,] 55 55 { 56 { 1, 20, 1},57 { 2, 20, -1},58 { 3, 10, -1},59 { 4, 10, 1},60 }; 61 62 var allVariables = new string[] { "y", "x1", "x2" }; 63 64 // x1 <= 15 AND x2 <= 0 -> 365 // x1 <= 15 AND x2 > 0 -> 466 // x1 > 15 AND x2 <= 0 -> 267 // x1 > 15 AND x2 > 0 -> 156 {-2, 20, 1}, 57 {-1, 20, -1}, 58 { 1, 10, -1}, 59 { 2, 10, 1}, 60 }; 61 62 var allVariables = new string[] { "y", "x1", "x2" }; 63 64 // x1 <= 15 AND x2 <= 0 -> 1 65 // x1 <= 15 AND x2 > 0 -> 2 66 // x1 > 15 AND x2 <= 0 -> -1 67 // x1 > 15 AND x2 > 0 -> -2 68 68 BuildTree(xy, allVariables, 10); 69 69 } … … 73 73 var xy = new double[,] 74 74 { 75 { 0.5, 20, 1},76 { 1.5, 20, 1},77 { 1.5, 20, -1},78 { 2.5, 20, -1},79 { 2.5, 10, -1},80 { 3.5, 10, -1},81 { 3.5, 10, 1},82 { 4.5, 10, 1},83 }; 84 85 var allVariables = new string[] { "y", "x1", "x2" }; 86 87 // x1 <= 15 AND x2 <= 0 -> 388 // x1 <= 15 AND x2 > 0 -> 489 // x1 > 15 AND x2 <= 0 -> 290 // x1 > 15 AND x2 > 0 -> 175 {-2.5, 20, 1}, 76 {-1.5, 20, 1}, 77 {-1.5, 20, -1}, 78 {0.5, 20, -1}, 79 {0.5, 10, -1}, 80 {1.5, 10, -1}, 81 {1.5, 10, 1}, 82 {2.5, 10, 1}, 83 }; 84 85 var allVariables = new string[] { "y", "x1", "x2" }; 86 87 // x1 <= 15 AND x2 <= 0 -> 1 88 // x1 <= 15 AND x2 > 0 -> 2 89 // x1 > 15 AND x2 <= 0 -> -1 90 // x1 > 15 AND x2 > 0 -> -2 91 91 BuildTree(xy, allVariables, 10); 92 92 } … … 97 97 var xy = new double[,] 98 98 { 99 { 10, 1, 1},100 { 1, 1, 2},101 { 1, 2, 1},102 { 10, 2, 2},99 { 1, 1, 1}, 100 {-1, 1, 2}, 101 {-1, 2, 1}, 102 { 1, 2, 2}, 103 103 }; 104 104 … … 106 106 107 107 // split cannot be found 108 // -> 5.50108 // -> 0.0 109 109 BuildTree(xy, allVariables, 3); 110 110 } … … 113 113 var xy = new double[,] 114 114 { 115 { 10, 1, 1},116 { 1, 1, 2},117 { 1, 2, 1},118 { 10.1, 2, 2},115 { 1, 1, 1}, 116 {-1, 1, 2}, 117 {-1, 2, 1}, 118 { 1.0001, 2, 2}, 119 119 }; 120 120 121 121 var allVariables = new string[] { "y", "x1", "x2" }; 122 122 // (two possible solutions) 123 // x2 <= 1.5 -> 5.50124 // x2 > 1.5 -> 5.55123 // x2 <= 1.5 -> 0 124 // x2 > 1.5 -> 0 (not quite) 125 125 BuildTree(xy, allVariables, 3); 126 126 127 // x1 <= 1.5 AND x2 <= 1.5 -> 1 0128 // x1 <= 1.5 AND x2 > 1.5 -> 1129 // x1 > 1.5 AND x2 <= 1.5 -> 1130 // x1 > 1.5 AND x2 > 1.5 -> 1 0.1127 // x1 <= 1.5 AND x2 <= 1.5 -> 1 128 // x1 <= 1.5 AND x2 > 1.5 -> -1 129 // x1 > 1.5 AND x2 <= 1.5 -> -1 130 // x1 > 1.5 AND x2 > 1.5 -> 1 (not quite) 131 131 BuildTree(xy, allVariables, 7); 132 132 } … … 155 155 {-1, 1, 2}, 156 156 {-1, 2, 1}, 157 { 1, 2, 2},157 { 3, 2, 2}, 158 158 }; 159 159 … … 162 162 // x2 <= 1.5 -> -1.0 163 163 // x2 > 1.5 AND x1 <= 1.5 -> -1.0 164 // x2 > 1.5 AND x1 > 1.5 -> 1.0164 // x2 > 1.5 AND x1 > 1.5 -> 3.0 165 165 BuildTree(xy, allVariables, 10); 166 166 } … … 269 269 var builder = new RegressionTreeBuilder(problemData, rand); 270 270 var model = (GradientBoostedTreesModel)builder.CreateRegressionTree(maxDepth, 1, 1); // maximal depth and use all rows and cols 271 var constM = model.Models.First() as ConstantRegressionModel;272 271 var treeM = model.Models.Skip(1).First() as RegressionTreeModel; 273 WriteTree(treeM.tree, 0, "", constM.Constant); 272 273 Console.WriteLine(treeM.ToString()); 274 274 Console.WriteLine(); 275 }276 277 private void WriteTree(RegressionTreeModel.TreeNode[] tree, int idx, string partialRule, double offset) {278 var n = tree[idx];279 if (n.VarName == RegressionTreeModel.TreeNode.NO_VARIABLE) {280 Console.WriteLine("{0} -> {1:F}", partialRule, n.Val + offset);281 } else {282 WriteTree(tree, n.LeftIdx,283 string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}",284 partialRule,285 string.IsNullOrEmpty(partialRule) ? "" : " and ",286 n.VarName,287 n.Val), offset);288 WriteTree(tree, n.RightIdx,289 string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}",290 partialRule,291 string.IsNullOrEmpty(partialRule) ? "" : " and ",292 n.VarName,293 n.Val), offset);294 }295 275 } 296 276 #endregion
Note: See TracChangeset
for help on using the changeset viewer.