- Timestamp:
- 05/01/15 18:30:56 (10 years ago)
- Location:
- branches/GBT
- Files:
-
- 4 added
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/GBT/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/HeuristicLab.Algorithms.DataAnalysis.Views-3.4.csproj
r11623 r12372 125 125 </ItemGroup> 126 126 <ItemGroup> 127 <Compile Include="RegressionTreeModelView.cs"> 128 <SubType>UserControl</SubType> 129 </Compile> 130 <Compile Include="RegressionTreeModelView.Designer.cs"> 131 <DependentUpon>RegressionTreeModelView.cs</DependentUpon> 132 </Compile> 133 <Compile Include="GradientBoostedTreesModelView.cs"> 134 <SubType>UserControl</SubType> 135 </Compile> 136 <Compile Include="GradientBoostedTreesModelView.Designer.cs"> 137 <DependentUpon>GradientBoostedTreesModelView.cs</DependentUpon> 138 </Compile> 127 139 <Compile Include="MeanProdView.cs"> 128 140 <SubType>UserControl</SubType> … … 244 256 <Name>HeuristicLab.Data-3.3</Name> 245 257 <Private>False</Private> 258 </ProjectReference> 259 <ProjectReference Include="..\..\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Views\3.4\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Views-3.4.csproj"> 260 <Project>{423bd94f-963a-438e-ba45-3bb3d61cd03b}</Project> 261 <Name>HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Views-3.4</Name> 262 </ProjectReference> 263 <ProjectReference Include="..\..\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding\3.4\HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4.csproj"> 264 <Project>{06D4A186-9319-48A0-BADE-A2058D462EEA}</Project> 265 <Name>HeuristicLab.Encodings.SymbolicExpressionTreeEncoding-3.4</Name> 246 266 </ProjectReference> 247 267 <ProjectReference Include="..\..\HeuristicLab.MainForm.WindowsForms\3.3\HeuristicLab.MainForm.WindowsForms-3.3.csproj"> -
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
r12332 r12372 8 8 9 9 namespace GradientBoostedTrees { 10 [Item("GradientBoostedTreesSolution", "")]11 10 [StorableClass] 11 [Item("Gradient boosted tree model", "")] 12 12 public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel { 13 13 14 14 [Storable] 15 15 private readonly IList<IRegressionModel> models; 16 public IEnumerable<IRegressionModel> Models { get { return models; } } 17 16 18 [Storable] 17 19 private readonly IList<double> weights; 20 public IEnumerable<double> Weights { get { return weights; } } 18 21 19 22 [StorableConstructor] … … 25 28 } 26 29 public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 27 : base( ) {30 : base("Gradient boosted tree model", string.Empty) { 28 31 this.models = new List<IRegressionModel>(models); 29 32 this.weights = new List<double>(weights); -
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12349 r12372 33 33 private readonly double[] outx; 34 34 private readonly int[] outSortedIdx; 35 36 private RegressionTreeModel.TreeNode[] tree; // tree is represented as a flat array of nodes 37 private int curTreeNodeIdx; // the index where the next tree node is stored 38 35 39 private readonly IList<RegressionTreeModel.TreeNode> nodeQueue; //TODO 36 40 … … 128 132 } 129 133 } 134 135 // prepare array for the tree nodes (a tree of maxDepth=1 has 1 node, a tree of maxDepth=d has 2^d - 1 nodes) 136 int numNodes = (int)Math.Pow(2, maxDepth) - 1; 137 //this.tree = new RegressionTreeModel.TreeNode[numNodes]; 138 this.tree = Enumerable.Range(0, numNodes).Select(_=>new RegressionTreeModel.TreeNode()).ToArray(); 139 this.curTreeNodeIdx = 0; 140 130 141 // start and end idx are inclusive 131 var tree =CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch);142 CreateRegressionTreeForIdx(maxDepth, 0, effectiveRows - 1, lineSearch); 132 143 return new RegressionTreeModel(tree); 133 144 } 134 145 135 146 // startIdx and endIdx are inclusive 136 private RegressionTreeModel.TreeNodeCreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) {147 private void CreateRegressionTreeForIdx(int maxDepth, int startIdx, int endIdx, LineSearchFunc lineSearch) { 137 148 Contract.Assert(endIdx - startIdx >= 0); 138 149 Contract.Assert(startIdx >= 0); 139 150 Contract.Assert(endIdx < internalIdx.Length); 140 151 141 RegressionTreeModel.TreeNode t;142 152 // TODO: stop when y is constant 143 153 // TODO: use priority queue of nodes to be expanded (sorted by improvement) instead of the recursion to maximum depth 144 154 if (maxDepth <= 1 || endIdx - startIdx == 0) { 145 // max depth reached or only one element 146 t = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx)); 147 return t; 155 // max depth reached or only one element 156 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 157 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 158 curTreeNodeIdx++; 148 159 } else { 149 160 int i, j; … … 154 165 // if bestVariableName is NO_VARIABLE then no split was possible anymore 155 166 if (bestVariableName == RegressionTreeModel.TreeNode.NO_VARIABLE) { 156 return new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, lineSearch(internalIdx, startIdx, endIdx)); 167 // max depth reached or only one element 168 tree[curTreeNodeIdx].varName = RegressionTreeModel.TreeNode.NO_VARIABLE; 169 tree[curTreeNodeIdx].val = lineSearch(internalIdx, startIdx, endIdx); 170 curTreeNodeIdx++; 157 171 } else { 158 172 … … 214 228 Debug.Assert(j <= endIdx); 215 229 216 t = new RegressionTreeModel.TreeNode(bestVariableName, 217 threshold, 218 CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch), 219 CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch)); 220 221 return t; 230 var parentIdx = curTreeNodeIdx; 231 tree[parentIdx].varName = bestVariableName; 232 tree[parentIdx].val = threshold; 233 curTreeNodeIdx++; 234 235 // create left subtree 236 tree[parentIdx].leftIdx = curTreeNodeIdx; 237 CreateRegressionTreeForIdx(maxDepth - 1, startIdx, j, lineSearch); 238 239 // create right subtree 240 tree[parentIdx].rightIdx = curTreeNodeIdx; 241 CreateRegressionTreeForIdx(maxDepth - 1, i, endIdx, lineSearch); 222 242 } 223 243 } … … 272 292 // assumption is that the Average(y) = 0 273 293 private void UpdateVariableRelevance(string bestVar, double sumY, double bestImprovement, int rows) { 294 if (string.IsNullOrEmpty(bestVar)) return; 274 295 // update variable relevance 275 296 double err = sumY * sumY / rows; -
branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r12349 r12372 1 using System.Collections.Generic; 1 using System; 2 using System.Collections.Generic; 2 3 using System.Linq; 3 4 using HeuristicLab.Common; … … 12 13 public class RegressionTreeModel : NamedItem, IRegressionModel { 13 14 15 // trees are represented as a flat array 16 // object-graph-travesal has problems if this is defined as a struct. TODO investigate... 14 17 [StorableClass] 15 18 public class TreeNode { 16 19 public readonly static string NO_VARIABLE = string.Empty; 17 20 [Storable] 18 public readonlystring varName; // name of the variable for splitting or -1 if terminal node21 public string varName; // name of the variable for splitting or -1 if terminal node 19 22 [Storable] 20 public readonlydouble val; // threshold23 public double val; // threshold 21 24 [Storable] 22 public readonly TreeNode left;25 public int leftIdx; 23 26 [Storable] 24 public readonly TreeNode right;27 public int rightIdx; 25 28 29 public TreeNode() { 30 varName = NO_VARIABLE; 31 leftIdx = -1; 32 rightIdx = -1; 33 } 26 34 [StorableConstructor] 27 35 private TreeNode(bool deserializing) { } 28 29 public TreeNode(string varName, double value, TreeNode left = null, TreeNode right = null) {30 this.varName = varName;31 this.val = value;32 this.left = left;33 this.right = right;34 }35 36 } 36 37 37 38 [Storable] 38 public readonly TreeNode tree;39 public readonly TreeNode[] tree; 39 40 40 41 [StorableConstructor] … … 43 44 public RegressionTreeModel(RegressionTreeModel original, Cloner cloner) 44 45 : base(original, cloner) { 45 this.tree = original.tree; 46 this.tree = original.tree; // shallow clone, tree must be readonly 46 47 } 47 48 48 public RegressionTreeModel(TreeNode tree) 49 : base() { 50 this.name = ItemName; 51 this.description = ItemDescription; 52 49 public RegressionTreeModel(TreeNode[] tree) 50 : base("RegressionTreeModel", "Represents a decision tree for regression.") { 53 51 this.tree = tree; 54 52 } 55 53 56 private static double GetPredictionForRow(TreeNode t, Dataset ds, int row) { 57 if (t.varName == TreeNode.NO_VARIABLE) 58 return t.val; 59 else if (ds.GetDoubleValue(t.varName, row) <= t.val) 60 return GetPredictionForRow(t.left, ds, row); 54 private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, Dataset ds, int row) { 55 var node = t[nodeIdx]; 56 if (node.varName == TreeNode.NO_VARIABLE) 57 return node.val; 58 else if (ds.GetDoubleValue(node.varName, row) <= node.val) 59 return GetPredictionForRow(t, node.leftIdx, ds, row); 61 60 else 62 return GetPredictionForRow(t .right, ds, row);61 return GetPredictionForRow(t, node.rightIdx, ds, row); 63 62 } 64 63 … … 68 67 69 68 public IEnumerable<double> GetEstimatedValues(Dataset ds, IEnumerable<int> rows) { 70 return rows.Select(r => GetPredictionForRow(tree, ds, r));69 return rows.Select(r => GetPredictionForRow(tree, 0, ds, r)); 71 70 } 72 71
Note: See TracChangeset
for help on using the changeset viewer.