Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12697


Ignore:
Timestamp:
07/09/15 16:39:37 (9 years ago)
Author:
gkronber
Message:

#2261: removed the line-search closure (which bound y.ToArray() and pred.ToArray()); loss functions now expose LineSearch directly

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12696 r12697  
    8181        activeIdx = Enumerable.Range(0, nRows).ToArray();
    8282
    83         var zeros = Enumerable.Repeat(0.0, nRows);
    84         double f0 = lossFunction.GetLineSearchFunc(y, zeros)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
     83        var zeros = Enumerable.Repeat(0.0, nRows).ToArray();
     84        double f0 = lossFunction.LineSearch(y, zeros, activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    8585        pred = Enumerable.Repeat(f0, nRows).ToArray();
    8686        predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
     
    166166      }
    167167
    168       var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, maxSize, activeIdx, lossFunction.GetLineSearchFunc(y, yPred), r, m);
     168      var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, yPred, maxSize, activeIdx, lossFunction, r, m);
    169169
    170170      int i = 0;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12696 r12697  
    6161
    6262    // return median of residuals
    63     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    64       var targetArr = target.ToArray();
    65       var predArr = pred.ToArray();
    66 
     63    // targetArr and predArr are not changed by LineSearch
     64    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
    6765      if (targetArr.Length != predArr.Length)
    6866        throw new ArgumentException("target and pred have differing lengths");
    6967
    70       // line search for abs error
    71       LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    72         // Median() is allocating an array anyway
    73         // It would be possible to pre-allocated an array for the residuals if Median() would allow specification of a sub-range
    74         int nRows = endIdx - startIdx + 1;
    75         var res = new double[nRows];
    76         for (int i = startIdx; i <= endIdx; i++) {
    77           var row = idx[i];
    78           res[i - startIdx] = targetArr[row] - predArr[row];
    79         }
    80         return res.Median(); // TODO: improve efficiency
    81       };
    82       return lineSearch;
    83 
     68      // Median() is allocating an array anyway
     69      // It would be possible to pre-allocate an array for the residuals if Median() allowed specification of a sub-range
     70      int nRows = endIdx - startIdx + 1;
     71      var res = new double[nRows];
     72      for (int i = startIdx; i <= endIdx; i++) {
     73        var row = idx[i];
     74        res[i - startIdx] = targetArr[row] - predArr[row];
     75      }
     76      return res.Median(); // TODO: improve efficiency
    8477    }
    8578
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12696 r12697  
    2424
    2525namespace HeuristicLab.Algorithms.DataAnalysis {
    26   // returns the optimal value for the partition of rows stored in idx[startIdx] .. idx[endIdx] inclusive
    27   public delegate double LineSearchFunc(int[] idx, int startIdx, int endIdx);
    28 
    2926  // represents an interface for loss functions used by gradient boosting
    3027  // target represents the target vector  (original targets from the problem data, never changed)
     
    3734    IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred);
    3835
    39     // returns a function that returns the optimal prediction value for a subset of rows from target and pred (see LineSearchFunc delegate above)
    40     LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred);
     36    // returns the optimal value for the partition of rows stored in idx[startIdx] .. idx[endIdx] inclusive
     37    double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx);
    4138  }
    4239}
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12696 r12697  
    6262    }
    6363
    64     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    65       var targetArr = target.ToArray();
    66       var predArr = pred.ToArray();
    67 
     64    // targetArr and predArr are not changed by LineSearch
     65    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
    6866      if (targetArr.Length != predArr.Length)
    6967        throw new ArgumentException("target and pred have different lengths");
    7068
    7169      // "Simple Newton-Raphson step" of eqn. 23
    72       LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    73         double sumY = 0.0;
    74         double sumDiff = 0.0;
    75         for (int i = startIdx; i <= endIdx; i++) {
    76           var row = idx[i];
    77           var y = targetArr[row] * 2 - 1; // y in {-1,1}
    78           var pseudoResponse = 2 * y / (1 + Math.Exp(2 * y * predArr[row]));
     70      double sumY = 0.0;
     71      double sumDiff = 0.0;
     72      for (int i = startIdx; i <= endIdx; i++) {
     73        var row = idx[i];
     74        var y = targetArr[row] * 2 - 1; // y in {-1,1}
     75        var pseudoResponse = 2 * y / (1 + Math.Exp(2 * y * predArr[row]));
    7976
    80           sumY += pseudoResponse;
    81           sumDiff += Math.Abs(pseudoResponse) * (2 - Math.Abs(pseudoResponse));
    82         }
    83         // prevent divByZero
    84         sumDiff = Math.Max(1E-12, sumDiff);
    85         return sumY / sumDiff;
    86       };
    87       return lineSearch;
    88 
     77        sumY += pseudoResponse;
     78        sumDiff += Math.Abs(pseudoResponse) * (2 - Math.Abs(pseudoResponse));
     79      }
     80      // prevent divByZero
     81      sumDiff = Math.Max(1E-12, sumDiff);
     82      return sumY / sumDiff;
    8983    }
    9084
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12696 r12697  
    6060    }
    6161
    62     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    63       var targetArr = target.ToArray();
    64       var predArr = pred.ToArray();
    65 
     62    // targetArr and predArr are not changed by LineSearch
     63    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
    6664      if (targetArr.Length != predArr.Length)
    6765        throw new ArgumentException("target and pred have different lengths");
     
    6967      // line search for relative error
    7068      // weighted median (weight = 1/target)
    71       LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    72         // weighted median calculation
    73         int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
    74         if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; // res
    75         else if (nRows == 2) {
    76           // weighted average of two residuals
    77           var w0 = Math.Abs(1.0 / targetArr[idx[startIdx]]);
    78           var w1 = Math.Abs(1.0 / targetArr[idx[endIdx]]);
    79           if (w0 > w1) {
    80             return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
    81           } else if (w0 < w1) {
    82             return targetArr[idx[endIdx]] - predArr[idx[endIdx]];
    83           } else {
    84             // same weight -> return average of both residuals
    85             return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2;
    86           }
     69      int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
     70      if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; // res
     71      else if (nRows == 2) {
     72        // weighted average of two residuals
     73        var w0 = Math.Abs(1.0 / targetArr[idx[startIdx]]);
     74        var w1 = Math.Abs(1.0 / targetArr[idx[endIdx]]);
     75        if (w0 > w1) {
     76          return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
     77        } else if (w0 < w1) {
     78          return targetArr[idx[endIdx]] - predArr[idx[endIdx]];
    8779        } else {
    88           // create an array of key-value pairs to be sorted (instead of using Array.Sort(res, weights))
    89           var res_w = new KeyValuePair<double, double>[nRows];
    90           var totalWeight = 0.0;
    91           for (int i = startIdx; i <= endIdx; i++) {
    92             int row = idx[i];
    93             var res = targetArr[row] - predArr[row];
    94             var w = Math.Abs(1.0 / targetArr[row]);
    95             res_w[i - startIdx] = new KeyValuePair<double, double>(res, w);
    96             totalWeight += w;
    97           }
    98           // TODO: improve efficiency (find median without sort)
    99           res_w.StableSort((a, b) => Math.Sign(a.Key - b.Key));
     80          // same weight -> return average of both residuals
     81          return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2;
     82        }
     83      } else {
     84        // create an array of key-value pairs to be sorted (instead of using Array.Sort(res, weights))
     85        var res_w = new KeyValuePair<double, double>[nRows];
     86        var totalWeight = 0.0;
     87        for (int i = startIdx; i <= endIdx; i++) {
     88          int row = idx[i];
     89          var res = targetArr[row] - predArr[row];
     90          var w = Math.Abs(1.0 / targetArr[row]);
     91          res_w[i - startIdx] = new KeyValuePair<double, double>(res, w);
     92          totalWeight += w;
     93        }
     94        // TODO: improve efficiency (find median without sort)
     95        res_w.StableSort((a, b) => Math.Sign(a.Key - b.Key));
    10096
    101           int k = 0;
    102           double sum = totalWeight - res_w[k].Value; // total - first weight
    103           while (sum > totalWeight / 2) {
    104             k++;
    105             sum -= res_w[k].Value;
    106           }
    107           return res_w[k].Key;
     97        int k = 0;
     98        double sum = totalWeight - res_w[k].Value; // total - first weight
     99        while (sum > totalWeight / 2) {
     100          k++;
     101          sum -= res_w[k].Value;
    108102        }
    109       };
    110       return lineSearch;
     103        return res_w[k].Key;
     104      }
    111105    }
    112106
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12696 r12697  
    5353    }
    5454
    55     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    56       var targetArr = target.ToArray();
    57       var predArr = pred.ToArray();
     55    // targetArr and predArr are not changed by LineSearch
     56    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
    5857      if (targetArr.Length != predArr.Length)
    5958        throw new ArgumentException("target and pred have different lengths");
     
    6160      // line search for squared error loss
    6261      // for a given partition of rows the optimal constant that should be added to the current prediction values is the average of the residuals
    63       LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    64         double s = 0.0;
    65         int n = 0;
    66         for (int i = startIdx; i <= endIdx; i++) {
    67           int row = idx[i];
    68           s += (targetArr[row] - predArr[row]);
    69           n++;
    70         }
    71         return s / n;
    72       };
    73       return lineSearch;
    74 
     62      double s = 0.0;
     63      int n = 0;
     64      for (int i = startIdx; i <= endIdx; i++) {
     65        int row = idx[i];
     66        s += (targetArr[row] - predArr[row]);
     67        n++;
     68      }
     69      return s / n;
    7570    }
    7671
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12696 r12697  
    4242    private readonly int nCols;
    4343    private readonly double[][] x; // all training data (original order from problemData), x is constant
     44    private double[] originalY; // the original target labels (from problemData), originalY is constant
     45    private double[] curPred; // current predictions for originalY (in case we are using gradient boosting, otherwise = zeros), only necessary for line search
     46
    4447    private double[] y; // training labels (original order from problemData), y can be changed
    4548
     
    102105
    103106      x = new double[nCols][];
    104       y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    105 
     107      originalY = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
     108      y = new double[originalY.Length];
     109      Array.Copy(originalY, y, y.Length); // copy values (originalY is fixed, y is changed in gradient boosting)
     110      curPred = Enumerable.Repeat(0.0, y.Length).ToArray(); // zeros
    106111
    107112      int col = 0;
     
    127132
    128133      var seLoss = new SquaredErrorLoss();
    129       var zeros = Enumerable.Repeat(0.0, y.Length);
    130 
    131       var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros), r, m);
     134
     135      var model = CreateRegressionTreeForGradientBoosting(y, curPred, maxSize, problemData.TrainingIndices.ToArray(), seLoss, r, m);
    132136
    133137      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
     
    135139
    136140    // specific interface that allows specifying the target labels and the training rows, which is necessary for gradient boosted trees
    137     public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, int maxSize, int[] idx, LineSearchFunc lineSearch, double r = 0.5, double m = 0.5) {
     141    public IRegressionModel CreateRegressionTreeForGradientBoosting(double[] y, double[] curPred, int maxSize, int[] idx, ILossFunction lossFunction, double r = 0.5, double m = 0.5) {
    138142      Debug.Assert(maxSize > 0);
    139143      Debug.Assert(r > 0);
     
    143147      Debug.Assert(m <= 1.0);
    144148
    145       this.y = y; // y is changed in gradient boosting
     149      // y and curPred are changed in gradient boosting
     150      this.y = y;
     151      this.curPred = curPred;
    146152
    147153      // shuffle row idx
     
    157163      effectiveVars = (int)Math.Ceiling(nCols * m);
    158164
    159       // the which array is used for partining row idxs
     165      // the which array is used for partitioning row idxs
    160166      Array.Clear(which, 0, which.Length);
    161167
     
    184190      // and calculate the best split for this root node and enqueue it into a queue sorted by improvement through the split
    185191      // start and end idx are inclusive
    186       CreateLeafNode(0, effectiveRows - 1, lineSearch);
     192      CreateLeafNode(0, effectiveRows - 1, lossFunction);
    187193
    188194      // process the priority queue to complete the tree
    189       CreateRegressionTreeFromQueue(maxSize, lineSearch);
     195      CreateRegressionTreeFromQueue(maxSize, lossFunction);
    190196
    191197      return new RegressionTreeModel(tree.ToArray());
     
    194200
    195201    // processes potential splits from the queue as long as splits are left and the maximum size of the tree is not reached
    196     private void CreateRegressionTreeFromQueue(int maxNodes, LineSearchFunc lineSearch) {
     202    private void CreateRegressionTreeFromQueue(int maxNodes, ILossFunction lossFunction) {
    197203      while (queue.Any() && curTreeNodeIdx + 1 < maxNodes) { // two nodes are created in each loop
    198204        var f = queue[queue.Count - 1]; // last element has the largest improvement
     
    213219
    214220        // create two leaf nodes (and enqueue best splits for both)
    215         var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lineSearch);
    216         var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lineSearch);
     221        var leftTreeIdx = CreateLeafNode(startIdx, splitIdx, lossFunction);
     222        var rightTreeIdx = CreateLeafNode(splitIdx + 1, endIdx, lossFunction);
    217223
    218224        // overwrite existing leaf node with an internal node
     
    223229
    224230    // returns the index of the newly created tree node
    225     private int CreateLeafNode(int startIdx, int endIdx, LineSearchFunc lineSearch) {
     231    private int CreateLeafNode(int startIdx, int endIdx, ILossFunction lossFunction) {
    226232      // write a leaf node
    227       var val = lineSearch(internalIdx, startIdx, endIdx);
     233      var val = lossFunction.LineSearch(originalY, curPred, internalIdx, startIdx, endIdx);
    228234      tree[curTreeNodeIdx] = new RegressionTreeModel.TreeNode(RegressionTreeModel.TreeNode.NO_VARIABLE, val);
    229235
Note: See TracChangeset for help on using the changeset viewer.