Changeset 14029 for branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
- Timestamp:
- 07/08/16 14:40:02 (8 years ago)
- Location:
- branches/crossvalidation-2434
- Files:
-
- 8 edited
- 3 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/crossvalidation-2434
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis
- Property svn:mergeinfo changed
-
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithm.cs
r12874 r14029 35 35 36 36 namespace HeuristicLab.Algorithms.DataAnalysis { 37 [Item("Gradient Boosted Trees ", "Gradient boosted trees algorithm. Friedman, J. \"Greedy Function Approximation: A Gradient Boosting Machine\", IMS 1999 Reitz Lecture.")]37 [Item("Gradient Boosted Trees (GBT)", "Gradient boosted trees algorithm. Specific implementation of gradient boosting for regression trees. Friedman, J. \"Greedy Function Approximation: A Gradient Boosting Machine\", IMS 1999 Reitz Lecture.")] 38 38 [StorableClass] 39 39 [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 125)] … … 255 255 // produce solution 256 256 if (CreateSolution) { 257 var surrogateModel = new GradientBoostedTreesModelSurrogate(problemData, (uint)Seed, lossFunction, 258 Iterations, MaxSize, R, M, Nu, state.GetModel()); 257 var model = state.GetModel(); 259 258 260 259 // for logistic regression we produce a classification solution 261 260 if (lossFunction is LogisticRegressionLoss) { 262 var model = new DiscriminantFunctionClassificationModel(surrogateModel,261 var classificationModel = new DiscriminantFunctionClassificationModel(model, 263 262 new AccuracyMaximizationThresholdCalculator()); 264 263 var classificationProblemData = new ClassificationProblemData(problemData.Dataset, 265 264 problemData.AllowedInputVariables, problemData.TargetVariable, problemData.Transformations); 266 model.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices);267 268 var classificationSolution = new DiscriminantFunctionClassificationSolution( model, classificationProblemData);265 classificationModel.RecalculateModelParameters(classificationProblemData, classificationProblemData.TrainingIndices); 266 267 var classificationSolution = new DiscriminantFunctionClassificationSolution(classificationModel, classificationProblemData); 269 268 Results.Add(new Result("Solution", classificationSolution)); 270 269 } else { 271 270 // otherwise we produce a regression solution 272 Results.Add(new Result("Solution", new RegressionSolution( surrogateModel, problemData)));271 Results.Add(new Result("Solution", new RegressionSolution(model, problemData))); 273 272 } 274 273 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs
r12710 r14029 52 52 internal RegressionTreeBuilder treeBuilder { get; private set; } 53 53 54 private readonly uint randSeed; 54 55 private MersenneTwister random { get; set; } 55 56 … … 71 72 this.m = m; 72 73 74 this.randSeed = randSeed; 73 75 random = new MersenneTwister(randSeed); 74 76 this.problemData = problemData; … … 94 96 weights = new List<double>(); 95 97 // add constant model 96 models.Add(new Constant RegressionModel(f0));98 models.Add(new ConstantModel(f0, problemData.TargetVariable)); 97 99 weights.Add(1.0); 98 100 } 99 101 100 102 public IRegressionModel GetModel() { 101 return new GradientBoostedTreesModel(models, weights); 103 #pragma warning disable 618 104 var model = new GradientBoostedTreesModel(models, weights); 105 #pragma warning restore 618 106 // we don't know the number of iterations here but the number of weights is equal 107 // to the number of iterations + 1 (for the constant model) 108 // wrap the actual model in a surrogate that enables persistence and lazy recalculation of the model if necessary 109 return new GradientBoostedTreesModelSurrogate(problemData, randSeed, lossFunction, weights.Count - 1, maxSize, r, m, nu, model); 102 110 } 103 111 public IEnumerable<KeyValuePair<string, double>> GetVariableRelevance() { … … 122 130 123 131 // simple interface 124 public static IRegressionSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) {132 public static GradientBoostedTreesSolution TrainGbm(IRegressionProblemData problemData, ILossFunction lossFunction, int maxSize, double nu, double r, double m, int maxIterations, uint randSeed = 31415) { 125 133 Contract.Assert(r > 0); 126 134 Contract.Assert(r <= 1.0); … … 135 143 136 144 var model = state.GetModel(); 137 return new RegressionSolution(model, (IRegressionProblemData)problemData.Clone());145 return new GradientBoostedTreesSolution(model, (IRegressionProblemData)problemData.Clone()); 138 146 } 139 147 -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModel.cs
r12869 r14029 33 33 [Item("Gradient boosted tree model", "")] 34 34 // this is essentially a collection of weighted regression models 35 public sealed class GradientBoostedTreesModel : NamedItem, IRegressionModel {35 public sealed class GradientBoostedTreesModel : RegressionModel, IGradientBoostedTreesModel { 36 36 // BackwardsCompatibility3.4 for allowing deserialization & serialization of old models 37 37 #region Backwards compatible code, remove with 3.5 … … 58 58 #endregion 59 59 60 public override IEnumerable<string> VariablesUsedForPrediction { 61 get { return models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 62 } 63 60 64 private readonly IList<IRegressionModel> models; 61 65 public IEnumerable<IRegressionModel> Models { get { return models; } } … … 76 80 this.isCompatibilityLoaded = original.isCompatibilityLoaded; 77 81 } 78 public GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 79 : base("Gradient boosted tree model", string.Empty) { 82 [Obsolete("The constructor of GBTModel should not be used directly anymore (use GBTModelSurrogate instead)")] 83 internal GradientBoostedTreesModel(IEnumerable<IRegressionModel> models, IEnumerable<double> weights) 84 : base(string.Empty, "Gradient boosted tree model", string.Empty) { 80 85 this.models = new List<IRegressionModel>(models); 81 86 this.weights = new List<double>(weights); … … 88 93 } 89 94 90 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {95 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 91 96 // allocate target array go over all models and add up weighted estimation for each row 92 97 if (!rows.Any()) return Enumerable.Empty<double>(); // return immediately if rows is empty. This prevents multiple iteration over lazy rows enumerable. … … 104 109 } 105 110 106 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {111 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 107 112 return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 108 113 } 114 109 115 } 110 116 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesModelSurrogate.cs
r12874 r14029 21 21 #endregion 22 22 23 using System;24 23 using System.Collections.Generic; 25 24 using System.Linq; … … 27 26 using HeuristicLab.Core; 28 27 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 29 using HeuristicLab.PluginInfrastructure;30 28 using HeuristicLab.Problems.DataAnalysis; 31 29 … … 36 34 // recalculate the actual GBT model on demand 37 35 [Item("Gradient boosted tree model", "")] 38 public sealed class GradientBoostedTreesModelSurrogate : NamedItem, IRegressionModel {36 public sealed class GradientBoostedTreesModelSurrogate : RegressionModel, IGradientBoostedTreesModel { 39 37 // don't store the actual model! 40 private I RegressionModel actualModel; // the actual model is only recalculated when necessary38 private IGradientBoostedTreesModel actualModel; // the actual model is only recalculated when necessary 41 39 42 40 [Storable] … … 58 56 59 57 58 public override IEnumerable<string> VariablesUsedForPrediction { 59 get { return actualModel.Models.SelectMany(x => x.VariablesUsedForPrediction).Distinct().OrderBy(x => x); } 60 } 61 60 62 [StorableConstructor] 61 63 private GradientBoostedTreesModelSurrogate(bool deserializing) : base(deserializing) { } … … 76 78 77 79 // create only the surrogate model without an actual model 78 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 79 : base("Gradient boosted tree model", string.Empty) { 80 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 81 ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu) 82 : base(trainingProblemData.TargetVariable, "Gradient boosted tree model", string.Empty) { 80 83 this.trainingProblemData = trainingProblemData; 81 84 this.seed = seed; … … 89 92 90 93 // wrap an actual model in a surrograte 91 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, IRegressionModel model) 94 public GradientBoostedTreesModelSurrogate(IRegressionProblemData trainingProblemData, uint seed, 95 ILossFunction lossFunction, int iterations, int maxSize, double r, double m, double nu, 96 IGradientBoostedTreesModel model) 92 97 : this(trainingProblemData, seed, lossFunction, iterations, maxSize, r, m, nu) { 93 98 this.actualModel = model; … … 99 104 100 105 // forward message to actual model (recalculate model first if necessary) 101 public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {106 public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) { 102 107 if (actualModel == null) actualModel = RecalculateModel(); 103 108 return actualModel.GetEstimatedValues(dataset, rows); 104 109 } 105 110 106 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {111 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 107 112 return new RegressionSolution(this, (IRegressionProblemData)problemData.Clone()); 108 113 } 109 114 115 private IGradientBoostedTreesModel RecalculateModel() { 116 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 117 } 110 118 111 private IRegressionModel RecalculateModel() { 112 return GradientBoostedTreesAlgorithmStatic.TrainGbm(trainingProblemData, lossFunction, maxSize, nu, r, m, iterations, seed).Model; 119 public IEnumerable<IRegressionModel> Models { 120 get { 121 if (actualModel == null) actualModel = RecalculateModel(); 122 return actualModel.Models; 123 } 124 } 125 126 public IEnumerable<double> Weights { 127 get { 128 if (actualModel == null) actualModel = RecalculateModel(); 129 return actualModel.Weights; 130 } 113 131 } 114 132 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12700 r14029 119 119 } 120 120 121 // simple API produces a single regression tree optimizing sum of squared errors122 // this can be used if only a simple regression tree should be produced123 // for a set of trees use the method CreateRegressionTreeForGradientBoosting below124 //125 // r and m work in the same way as for alglib random forest126 // r is fraction of rows to use for training127 // m is fraction of variables to use for training128 public IRegressionModel CreateRegressionTree(int maxSize, double r = 0.5, double m = 0.5) {129 // subtract mean of y first130 var yAvg = y.Average();131 for (int i = 0; i < y.Length; i++) y[i] -= yAvg;132 133 var seLoss = new SquaredErrorLoss();134 135 var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);136 137 return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });138 }139 140 121 // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees 141 122 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) { … … 156 137 int nRows = idx.Count(); 157 138 158 // shuffle variable idx139 // shuffle variable names 159 140 HeuristicLab.Random.ListExtensions.ShuffleInPlace(allowedVariables, random); 160 141 … … 195 176 CreateRegressionTreeFromQueue(maxSize, lossFunction); 196 177 197 return new RegressionTreeModel(tree.ToArray() );198 } 199 200 201 // processes potential splits from the queue as long as splits are leftand the maximum size of the tree is not reached178 return new RegressionTreeModel(tree.ToArray(), problemData.TargetVariable); 179 } 180 181 182 // processes potential splits from the queue as long as splits are remaining and the maximum size of the tree is not reached 202 183 private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) { 203 184 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop … … 223 204 224 205 // overwrite existing leaf node with an internal node 225 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx );206 tree[f.ParentNodeIdx] = new RegressionTreeModel.TreeNode(f.SplittingVariable, f.SplittingThreshold, leftTreeIdx, rightTreeIdx, weightLeft: (splitIdx - startIdx + 1) / (double)(endIdx - startIdx + 1)); 226 207 } 227 208 } -
branches/crossvalidation-2434/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeModel.cs
r12869 r14029 23 23 using System; 24 24 using System.Collections.Generic; 25 using System.Collections.ObjectModel; 25 26 using System.Globalization; 26 27 using System.Linq; … … 33 34 [StorableClass] 34 35 [Item("RegressionTreeModel", "Represents a decision tree for regression.")] 35 public sealed class RegressionTreeModel : NamedItem, IRegressionModel { 36 public sealed class RegressionTreeModel : RegressionModel { 37 public override IEnumerable<string> VariablesUsedForPrediction { 38 get { return tree.Select(t => t.VarName).Where(v => v != TreeNode.NO_VARIABLE); } 39 } 36 40 37 41 // trees are represented as a flat array … … 39 43 public readonly static string NO_VARIABLE = null; 40 44 41 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1 )45 public TreeNode(string varName, double val, int leftIdx = -1, int rightIdx = -1, double weightLeft = -1.0) 42 46 : this() { 43 47 VarName = varName; … … 45 49 LeftIdx = leftIdx; 46 50 RightIdx = rightIdx; 47 } 48 49 public string VarName { get; private set; } // name of the variable for splitting or NO_VARIABLE if terminal node 50 public double Val { get; private set; } // threshold 51 public int LeftIdx { get; private set; } 52 public int RightIdx { get; private set; } 53 54 internal IList<double> Data { get; set; } // only necessary to improve efficiency of evaluation 51 WeightLeft = weightLeft; 52 } 53 54 public string VarName { get; internal set; } // name of the variable for splitting or NO_VARIABLE if terminal node 55 public double Val { get; internal set; } // threshold 56 public int LeftIdx { get; internal set; } 57 public int RightIdx { get; internal set; } 58 public double WeightLeft { get; internal set; } // for partial dependence plots (value in range [0..1] describes the fraction of training samples for the left sub-tree 59 55 60 56 61 // necessary because the default implementation of GetHashCode for structs in .NET would only return the hashcode of val here … … 65 70 LeftIdx.Equals(other.LeftIdx) && 66 71 RightIdx.Equals(other.RightIdx) && 72 WeightLeft.Equals(other.WeightLeft) && 67 73 EqualStrings(VarName, other.VarName); 68 74 } else { … … 80 86 private TreeNode[] tree; 81 87 82 [Storable] 88 #region old storable format 89 // remove with HL 3.4 90 [Storable(AllowOneWay = true)] 83 91 // to prevent storing the references to data caches in nodes 84 // seemingly it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) TODO92 // seemingly, it is bad (performance-wise) to persist tuples (tuples are used as keys in a dictionary) 85 93 private Tuple<string, double, int, int>[] SerializedTree { 86 get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 87 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4)).ToArray(); } 88 } 94 // get { return tree.Select(t => Tuple.Create(t.VarName, t.Val, t.LeftIdx, t.RightIdx)).ToArray(); } 95 set { this.tree = value.Select(t => new TreeNode(t.Item1, t.Item2, t.Item3, t.Item4, -1.0)).ToArray(); } // use a weight of -1.0 to indicate that partial dependence cannot be calculated for old models 96 } 97 #endregion 98 #region new storable format 99 [Storable] 100 private string[] SerializedTreeVarNames { 101 get { return tree.Select(t => t.VarName).ToArray(); } 102 set { 103 if (tree == null) tree = new TreeNode[value.Length]; 104 for (int i = 0; i < value.Length; i++) { 105 tree[i].VarName = value[i]; 106 } 107 } 108 } 109 [Storable] 110 private double[] SerializedTreeValues { 111 get { return tree.Select(t => t.Val).ToArray(); } 112 set { 113 if (tree == null) tree = new TreeNode[value.Length]; 114 for (int i = 0; i < value.Length; i++) { 115 tree[i].Val = value[i]; 116 } 117 } 118 } 119 [Storable] 120 private int[] SerializedTreeLeftIdx { 121 get { return tree.Select(t => t.LeftIdx).ToArray(); } 122 set { 123 if (tree == null) tree = new TreeNode[value.Length]; 124 for (int i = 0; i < value.Length; i++) { 125 tree[i].LeftIdx = value[i]; 126 } 127 } 128 } 129 [Storable] 130 private int[] SerializedTreeRightIdx { 131 get { return tree.Select(t => t.RightIdx).ToArray(); } 132 set { 133 if (tree == null) tree = new TreeNode[value.Length]; 134 for (int i = 0; i < value.Length; i++) { 135 tree[i].RightIdx = value[i]; 136 } 137 } 138 } 139 [Storable] 140 private double[] SerializedTreeWeightLeft { 141 get { return tree.Select(t => t.WeightLeft).ToArray(); } 142 set { 143 if (tree == null) tree = new TreeNode[value.Length]; 144 for (int i = 0; i < value.Length; i++) { 145 tree[i].WeightLeft = value[i]; 146 } 147 } 148 } 149 #endregion 89 150 90 151 [StorableConstructor] … … 99 160 } 100 161 101 internal RegressionTreeModel(TreeNode[] tree )102 : base( "RegressionTreeModel", "Represents a decision tree for regression.") {162 internal RegressionTreeModel(TreeNode[] tree, string targetVariable) 163 : base(targetVariable, "RegressionTreeModel", "Represents a decision tree for regression.") { 103 164 this.tree = tree; 104 165 } 105 166 106 private static double GetPredictionForRow(TreeNode[] t, int nodeIdx, int row) {167 private static double GetPredictionForRow(TreeNode[] t, ReadOnlyCollection<double>[] columnCache, int nodeIdx, int row) { 107 168 while (nodeIdx != -1) { 108 169 var node = t[nodeIdx]; 109 170 if (node.VarName == TreeNode.NO_VARIABLE) 110 171 return node.Val; 111 112 if (node.Data[row] <= node.Val) 172 if (columnCache[nodeIdx] == null || double.IsNaN(columnCache[nodeIdx][row])) { 173 if (node.WeightLeft.IsAlmost(-1.0)) throw new InvalidOperationException("Cannot calculate partial dependence for trees loaded from older versions of HeuristicLab."); 174 // weighted average for partial dependence plot (recursive here because we need to calculate both sub-trees) 175 return node.WeightLeft * GetPredictionForRow(t, columnCache, node.LeftIdx, row) + 176 (1.0 - node.WeightLeft) * GetPredictionForRow(t, columnCache, node.RightIdx, row); 177 } else if (columnCache[nodeIdx][row] <= node.Val) 113 178 nodeIdx = node.LeftIdx; 114 179 else … … 122 187 } 123 188 124 public IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) {189 public override IEnumerable<double> GetEstimatedValues(IDataset ds, IEnumerable<int> rows) { 125 190 // lookup columns for variableNames in one pass over the tree to speed up evaluation later on 191 ReadOnlyCollection<double>[] columnCache = new ReadOnlyCollection<double>[tree.Length]; 192 126 193 for (int i = 0; i < tree.Length; i++) { 127 194 if (tree[i].VarName != TreeNode.NO_VARIABLE) { 128 tree[i].Data = ds.GetReadOnlyDoubleValues(tree[i].VarName); 129 } 130 } 131 return rows.Select(r => GetPredictionForRow(tree, 0, r)); 132 } 133 134 public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 195 // tree models also support calculating estimations if not all variables used for training are available in the dataset 196 if (ds.ColumnNames.Contains(tree[i].VarName)) 197 columnCache[i] = ds.GetReadOnlyDoubleValues(tree[i].VarName); 198 } 199 } 200 return rows.Select(r => GetPredictionForRow(tree, columnCache, 0, r)); 201 } 202 203 public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) { 135 204 return new RegressionSolution(this, new RegressionProblemData(problemData)); 136 205 } … … 147 216 } else { 148 217 return 149 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)) 150 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F}", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val)); 151 } 152 } 218 TreeToString(n.LeftIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} <= {3:F} ({4:N3})", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, n.WeightLeft)) 219 + TreeToString(n.RightIdx, string.Format(CultureInfo.InvariantCulture, "{0}{1}{2} > {3:F} ({4:N3}))", part, string.IsNullOrEmpty(part) ? "" : " and ", n.VarName, n.Val, 1.0 - n.WeightLeft)); 220 } 221 } 222 153 223 } 154 224 }
Note: See TracChangeset
for help on using the changeset viewer.