Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/09/15 16:39:37 (9 years ago)
Author:
gkronber
Message:

#2261 removed line search closure (binding y.ToArray(), and pred.ToArray())

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12696 r12697  
    4242    private readonly int nCols;
    4343    private readonly double[][] x; // all training data (original order from problemData), x is constant
     44    private double[] originalY; // the original target labels (from problemData), originalY is constant
     45    private double[] curPred; // current predictions for originalY (in case we are using gradient boosting, otherwise = zeros), only necessary for line search
     46
    4447    private double[] y; // training labels (original order from problemData), y can be changed
    4548
     
    102105
    103106      x = new double[nCols][];
    104       y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    105 
     107      originalY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
     108      y = new double[originalY.Length];
     109      Array.Copy(originalY, y, y.Length); // copy values (originalY is fixed, y is changed in gradient boosting)
     110      curPred = Enumerable.Repeat(0.0, y.Length).ToArray(); // zeros
    106111
    107112      int col = 0;
     
    127132
    128133      var seLoss = new SquaredErrorLoss();
    129       var zeros = Enumerable.Repeat(0.0, y.Length);
    130 
    131       var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros), r, m);
     134
     135      var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);
    132136
    133137      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
     
    135139
    136140    // specific interface that allows to specify the target labels and the training rows which is necessary for gradient boosted trees
    137     public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxSize, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
     141    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) {
    138142      Debug.Assert(maxSize > 0);
    139143      Debug.Assert(r > 0);
     
    143147      Debug.Assert(m <= 1.0);
    144148
    145       this.y = y; // y is changed in gradient boosting
     149      // y and curPred are changed in gradient boosting
     150      this.y = y;
     151      this.curPred = curPred;
    146152
    147153      // shuffle row idx
     
    157163      effectiveVars = (int)Math.Ceiling(nCols * m);
    158164
    159       // the which array is used for partining row idxs
     165      // the which array is used for partitioning row idxs
    160166      Array.Clear(which, 0, which.Length);
    161167
     
    184190      // and calculate the best split for this root node and enqueue it into a queue sorted by improvement through the split
    185191      // start and end idx are inclusive
    186       CreateLeafNode(0, effectiveRows - 1, lineSearch);
     192      CreateLeafNode(0, effectiveRows - 1, lossFunction);
    187193
    188194      // process the priority queue to complete the tree
    189       CreateRegressionTreeFromQueue(maxSize, lineSearch);
     195      CreateRegressionTreeFromQueue(maxSize, lossFunction);
    190196
    191197      return new RegressionTreeModel(tree.ToArray());
     
    194200
    195201    // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
    196     private void CreateRegressionTreeFromQueue(int maxNodes, LineSearchFunc lineSearch) {
     202    private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) {
    197203      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
    198204        var f = queue[queue.Count - 1]; // last element has the largest improvement
     
    213219
    214220        // create two leaf nodes (and enqueue best splits for both)
    215         var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
    216         var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     221        var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lossFunction);
     222        var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lossFunction);
    217223
    218224        // overwrite existing leaf node with an internal node
     
    223229
    224230    // returns the index of the newly created tree node
    225     private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
     231    private int CreateLeafNode(int startIdx, int endIdx, ILossFunction lossFunction) {
    226232      // write a leaf node
    227       var val = lineSearch(internalIdx, startIdx, endIdx);
     233      var val = lossFunction.LineSearch(originalY, curPred, internalIdx, startIdx, endIdx);
    228234      tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val);
    229235
Note: See TracChangeset for help on using the changeset viewer.