Changeset 13948 for branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Timestamp:
- 06/29/16 10:36:52 (8 years ago)
- Location:
- branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
- Files:
-
- 5 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Algorithms.DataAnalysis (added) merged: 13889,13891,13895,13898,13917,13921-13922,13941
- Property svn:mergeinfo changed
-
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
r13157 r13948 33 33 [Item("Gradient boosted tree model", "")] 34 34 // this is essentially a collection of weighted regression models 35 public sealed class GradientBoostedTreesModel : NamedItem, IGradientBoostedTreesModel {35 public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel { 36 36 // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models 37 37 #region Backwards compatible code, remove with 3.5 … … 58 58 #endregion 59 59 60 public override IEnumerable<string> VariablesUsedForPrediction { 61 get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 62 } 63 60 64 private readonly IList<IRegressionModel> models; 61 65 public IEnumerable<IRegressionModel> Models { get { return models; } } … … 77 81 } 78 82 [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")] 79 publicGradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights)80 : base( "Gradient boosted tree model", string.Empty) {83 internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 84 : base(string.Empty, "Gradient boosted tree model", string.Empty) { 81 85 this.models = new List<IRegressionModel>(models); 82 86 this.weights = new List<double>(weights); … … 89 93 } 90 94 91 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {95 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 92 96 // allocate target array go over all models and add up weighted estimation for each row 93 97 if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable. 
… … 105 109 } 106 110 107 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {111 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 108 112 return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 109 113 } 114 110 115 } 111 116 } -
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r13157 r13948 22 22 23 23 using System.Collections.Generic; 24 using System.Linq; 24 25 using HeuristicLab.Common; 25 26 using HeuristicLab.Core; … … 33 34 // recalculate the actual GBT model on demand 34 35 [Item("Gradient boosted tree model", "")] 35 public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IGradientBoostedTreesModel {36 public sealed class GradientBoostedTreesModelSurrogate : RegressionModel, IGradientBoostedTreesModel { 36 37 // don't store the actual model! 37 38 private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary … … 55 56 56 57 58 public override IEnumerable<string> VariablesUsedForPrediction { 59 get { return actualModel.Models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 60 } 61 57 62 [StorableConstructor] 58 63 private GradientBoostedTreesModelSurrogate(bool deserializing) : base(deserializing) { } … … 73 78 74 79 // create only the surrogate model without an actual model 75 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 76 : base("Gradient boosted tree model", string.Empty) { 80 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 81 ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 82 : base(trainingProblemData.TargetVariable, "Gradient boosted tree model", string.Empty) { 77 83 this.trainingProblemData = trainingProblemData; 78 84 this.seed = seed; … … 86 92 87 93 // wrap an actual model in a surrograte 88 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IGradientBoostedTreesModel model) 94 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 95 
ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, 96 IGradientBoostedTreesModel model) 89 97 : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) { 90 98 this.actualModel = model; … … 96 104 97 105 // forward message to actual model (recalculate model first if necessary) 98 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {106 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 99 107 if (actualModel == null) actualModel = RecalculateModel(); 100 108 return actualModel.GetEstimatedValues(dataset, rows); 101 109 } 102 110 103 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {111 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 104 112 return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 105 113 } 106 107 114 108 115 private IGradientBoostedTreesModel RecalculateModel() { -
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r13065 r13948 180 180 181 181 182 // processes potential splits from the queue as long as splits are leftand the maximum size of the tree is not reached182 // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached 183 183 private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) { 184 184 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop … … 204 204 205 205 // overwrite existing leaf node with an internal node 206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx );206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1)); 207 207 } 208 208 } -
branches/HeuristicLab.RegressionSolutionGradientView/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r13030 r13948 34 34 [StorableClass] 35 35 [Item("RegressionTreeModel", "Represents a decision tree for regression.")] 36 public sealed class RegressionTreeModel : NamedItem, IRegressionModel { 36 public sealed class RegressionTreeModel : RegressionModel { 37 public override IEnumerable<string> VariablesUsedForPrediction { 38 get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); } 39 } 37 40 38 41 // trees are represented as a flat array … … 40 43 public readonly static string NO_VARIABLE = null; 41 44 42 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1 )45 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0) 43 46 : this() { 44 47 VarName = varName; … … 46 49 LeftIdx = leftIdx; 47 50 RightIdx = rightIdx; 48 } 49 50 public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node 51 public double Val { get; private set; } // threshold 52 public int LeftIdx { get; private set; } 53 public int RightIdx { get; private set; } 51 WeightLeft = weightLeft; 52 } 53 54 public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node 55 public double Val { get; internal set; } // threshold 56 public int LeftIdx { get; internal set; } 57 public int RightIdx { get; internal set; } 58 public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree 59 54 60 55 61 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here … … 64 70 LeftIdx.Equals(other.LeftIdx) && 65 71 RightIdx.Equals(other.RightIdx) && 72 WeightLeft.Equals(other.WeightLeft) && 66 73 EqualStrings(VarName, other.VarName); 67 74 } else { … … 79 86 private TreeNode[] tree; 80 87 81 [Storable] 88 #region old storable format 89 // remove with 
HL 3.4 90 [Storable(AllowOneWay = true)] 82 91 // to prevent storing the references to data caches in nodes 83 // TODO seeminglyit is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary)92 // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) 84 93 private Tuple<string, double, int, int>[] SerializedTree { 85 get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 86 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); } 87 } 94 // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 95 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models 96 } 97 #endregion 98 #region new storable format 99 [Storable] 100 private string[] SerializedTreeVarNames { 101 get { return tree.Select(t => t.VarName).ToArray(); } 102 set { 103 if (tree == null) tree = new TreeNode[value.Length]; 104 for (int i = 0; i < value.Length; i++) { 105 tree[i].VarName = value[i]; 106 } 107 } 108 } 109 [Storable] 110 private double[] SerializedTreeValues { 111 get { return tree.Select(t => t.Val).ToArray(); } 112 set { 113 if (tree == null) tree = new TreeNode[value.Length]; 114 for (int i = 0; i < value.Length; i++) { 115 tree[i].Val = value[i]; 116 } 117 } 118 } 119 [Storable] 120 private int[] SerializedTreeLeftIdx { 121 get { return tree.Select(t => t.LeftIdx).ToArray(); } 122 set { 123 if (tree == null) tree = new TreeNode[value.Length]; 124 for (int i = 0; i < value.Length; i++) { 125 tree[i].LeftIdx = value[i]; 126 } 127 } 128 } 129 [Storable] 130 private int[] SerializedTreeRightIdx { 131 get { return tree.Select(t => t.RightIdx).ToArray(); } 132 set { 133 if (tree == null) tree = new TreeNode[value.Length]; 134 for (int i = 0; i < 
value.Length; i++) { 135 tree[i].RightIdx = value[i]; 136 } 137 } 138 } 139 [Storable] 140 private double[] SerializedTreeWeightLeft { 141 get { return tree.Select(t => t.WeightLeft).ToArray(); } 142 set { 143 if (tree == null) tree = new TreeNode[value.Length]; 144 for (int i = 0; i < value.Length; i++) { 145 tree[i].WeightLeft = value[i]; 146 } 147 } 148 } 149 #endregion 150 151 152 88 153 89 154 [StorableConstructor] … … 98 163 } 99 164 100 internal RegressionTreeModel(TreeNode[] tree )101 : base( "RegressionTreeModel", "Represents a decision tree for regression.") {165 internal RegressionTreeModel(TreeNode[] tree, string target = "Target") 166 : base(target, "RegressionTreeModel", "Represents a decision tree for regression.") { 102 167 this.tree = tree; 103 168 } … … 108 173 if (node.VarName == TreeNode.NO_VARIABLE) 109 174 return node.Val; 110 111 if (columnCache[nodeIdx][row] <= node.Val) 175 if (columnCache[nodeIdx] == null) { 176 if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab."); 177 // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees) 178 return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) + 179 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row); 180 } else if (columnCache[nodeIdx][row] <= node.Val) 112 181 nodeIdx = node.LeftIdx; 113 182 else … … 121 190 } 122 191 123 public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {192 public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) { 124 193 // lookup columns for variableNames in one pass over the tree to speed up evaluation later on 125 194 ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length]; … … 127 196 for (int i = 0; i < tree.Length; i++) { 128 197 if (tree[i].VarName != 
TreeNode.NO_VARIABLE) { 129 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 198 // tree models also support calculating estimations if not all variables used for training are available in the dataset 199 if (ds.ColumnNames.Contains(tree[i].VarName)) 200 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 130 201 } 131 202 } … … 133 204 } 134 205 135 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {206 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 136 207 return new RegressionSolution(this, new RegressionProblemData(problemData)); 137 208 } … … 148 219 } else { 149 220 return 150 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)) 151 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)); 152 } 153 } 221 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft)) 222 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft)); 223 } 224 } 225 154 226 } 155 227 }
Note: See TracChangeset for help on using the changeset viewer.