Changeset 12697 for branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
- Timestamp:
- 07/09/15 16:39:37 (9 years ago)
- Files:
- 1 edited
Legend:
- Unmodified
- Added
- Removed
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs
r12696 r12697 42 42 private readonly int nCols; 43 43 private readonly double[][] x; // all training data (original order from problemData), x is constant 44 private double[] originalY; // the original target labels (from problemData), originalY is constant 45 private double[] curPred; // current predictions for originalY (in case we are using gradient boosting, otherwise = zeros), only necessary for line search 46 44 47 private double[] y; // training labels (original order from problemData), y can be changed 45 48 … … 102 105 103 106 x = new double[nCols][]; 104 y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray(); 105 107 originalY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray(); 108 y = new double[originalY.Length]; 109 Array.Copy(originalY, y, y.Length); // copy values (originalY is fixed, y is changed in gradient boosting) 110 curPred = Enumerable.Repeat(0.0, y.Length).ToArray(); // zeros 106 111 107 112 int col = 0; … … 127 132 128 133 var seLoss = new SquaredErrorLoss(); 129 var zeros = Enumerable.Repeat(0.0, y.Length); 130 131 var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros), r, m); 134 135 var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m); 132 136 133 137 return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 }); … … 135 139 136 140 // specific interface that allows to specify the target labels and the training rows which is necessary when for gradient boosted trees 137 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxSize, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {141 public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int 
maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) { 138 142 Debug.Assert(maxSize > 0); 139 143 Debug.Assert(r > 0); … … 143 147 Debug.Assert(m <= 1.0); 144 148 145 this.y = y; // y is changed in gradient boosting 149 // y and curPred are changed in gradient boosting 150 this.y = y; 151 this.curPred = curPred; 146 152 147 153 // shuffle row idx … … 157 163 effectiveVars = (int)Math.Ceiling(nCols * m); 158 164 159 // the which array is used for parti ning row idxs165 // the which array is used for partitioing row idxs 160 166 Array.Clear(which, 0, which.Length); 161 167 … … 184 190 // and calculate the best split for this root node and enqueue it into a queue sorted by improvement throught the split 185 191 // start and end idx are inclusive 186 CreateLeafNode(0, effectiveRows - 1, l ineSearch);192 CreateLeafNode(0, effectiveRows - 1, lossFunction); 187 193 188 194 // process the priority queue to complete the tree 189 CreateRegressionTreeFromQueue(maxSize, l ineSearch);195 CreateRegressionTreeFromQueue(maxSize, lossFunction); 190 196 191 197 return new RegressionTreeModel(tree.ToArray()); … … 194 200 195 201 // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached 196 private void CreateRegressionTreeFromQueue(int maxNodes, LineSearchFunc lineSearch) {202 private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) { 197 203 while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop 198 204 var f = queue[queue.Count - 1]; // last element has the largest improvement … … 213 219 214 220 // create two leaf nodes (and enqueue best splits for both) 215 var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, l ineSearch);216 var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, l ineSearch);221 var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lossFunction); 222 var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, 
lossFunction); 217 223 218 224 // overwrite existing leaf node with an internal node … … 223 229 224 230 // returns the index of the newly created tree node 225 private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {231 private int CreateLeafNode(int startIdx, int endIdx, ILossFunction lossFunction) { 226 232 // write a leaf node 227 var val = l ineSearch(internalIdx, startIdx, endIdx);233 var val = lossFunction.LineSearch(originalY, curPred, internalIdx, startIdx, endIdx); 228 234 tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val); 229 235
Note: See TracChangeset
for help on using the changeset viewer.