Changeset 12607


Ignore:
Timestamp:
07/06/15 15:20:28 (4 years ago)
Author:
gkronber
Message:

#2261: also use line search function for the initial estimation f0, changed logistic regression loss function to match description in GBM paper, comments and code improvements

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
6 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12597 r12607  
    8484        activeIdx = Enumerable.Range(0, nRows).ToArray();
    8585
    86         // prepare arrays (allocate only once)
    87         double f0 = y.Average(); // default prediction (constant)
     86        var zeros = Enumerable.Repeat(0.0, nRows);
     87        var ones = Enumerable.Repeat(1.0, nRows);
     88        double f0 = lossFunction.GetLineSearchFunc(y, zeros, ones)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    8889        pred = Enumerable.Repeat(f0, nRows).ToArray();
    8990        predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12597 r12607  
    3838      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    3939        double res = targetEnum.Current - predEnum.Current;
    40         s += weightEnum.Current * Math.Abs(res);
     40        s += weightEnum.Current * Math.Abs(res);  // w * |res|
    4141      }
    4242      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
     
    5252
    5353      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    54         // weight * sign(res)
     54        // dL(y, f(x)) / df(x) = weight * sign(res)
    5555        var res = targetEnum.Current - predEnum.Current;
    5656        if (res > 0) yield return weightEnum.Current;
     
    6767      var predArr = pred.ToArray();
    6868      var weightArr = weight.ToArray();
    69       // weights are not supported yet
    70       // when weights are supported we need to calculate a weighted median
     69      // the optimal constant value that should be added to the predictions is the median of the residuals
     70      // weights are not supported yet (need to calculate a weighted median)
    7171      Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));
    7272
     
    8080        int nRows = endIdx - startIdx + 1;
    8181        var res = new double[nRows];
    82         for (int offset = 0; offset < nRows; offset++) {
    83           var i = startIdx + offset;
     82        for (int i = startIdx; i <= endIdx; i++) {
    8483          var row = idx[i];
    85           res[offset] = targetArr[row] - predArr[row];
     84          res[i - startIdx] = targetArr[row] - predArr[row];
    8685        }
    8786        return res.Median();
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12590 r12607  
    2424
    2525namespace HeuristicLab.Algorithms.DataAnalysis {
     26  // returns the optimal value for the partition of rows stored in idx[startIdx] .. idx[endIdx] inclusive
    2627  public delegate double LineSearchFunc(int[] idx, int startIdx, int endIdx);
    2728
     29  // represents an interface for loss functions used by gradient boosting
     30  // target represents the target vector  (original targets from the problem data, never changed)
     31  // pred   represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step)
     32  // weight represents a weight vector for rows (this is not supported yet -> all weights are 1)
    2833  public interface ILossFunction {
    29     // returns the weighted loss
     34    // returns the weighted loss of the current prediction vector
    3035    double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight);
    3136
     
    3742  }
    3843}
     44
     45
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12590 r12607  
    2828
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
     30  // Greedy Function Approximation: A Gradient Boosting Machine (page 9)
    3031  public class LogisticRegressionLoss : ILossFunction {
    3132    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     
    3637      double s = 0;
    3738      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    38         // assert target == 0 or target == 1
    39         if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0))
    40           throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss");
    41         double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent
    42         var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f));
    43         s += weightEnum.Current * (-targetEnum.Current * Math.Log(probPos) - (1 - targetEnum.Current) * Math.Log(1 - probPos));
     39        Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss");
     40
     41        var y = targetEnum.Current * 2 - 1; // y in {-1,1}
     42        s += weightEnum.Current * Math.Log(1 + Math.Exp(-2 * y * predEnum.Current));
    4443      }
    4544      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    46         throw new ArgumentException("target, pred and weight have differing lengths");
     45        throw new ArgumentException("target, pred and weight have different lengths");
    4746
    4847      return s;
     
    5554
    5655      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    57         // assert target == 0 or target == 1
    58         if (!targetEnum.Current.IsAlmost(0.0) && !targetEnum.Current.IsAlmost(1.0))
    59           throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss");
    60         double f = Math.Max(-7, Math.Min(7, predEnum.Current)); // threshold for exponent
    61         var probPos = Math.Exp(2 * f) / (1 + Math.Exp(2 * f));
    62         yield return weightEnum.Current * (targetEnum.Current - probPos) / (probPos * probPos - probPos);
     56        Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss");
     57        var y = targetEnum.Current * 2 - 1; // y in {-1,1}
     58
     59        yield return weightEnum.Current * 2 * y / (1 + Math.Exp(2 * y * predEnum.Current));
     60
    6361      }
    6462      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    65         throw new ArgumentException("target, pred and weight have differing lengths");
     63        throw new ArgumentException("target, pred and weight have different lengths");
    6664    }
    6765
     
    7573
    7674      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    77         throw new ArgumentException("target, pred and weight have differing lengths");
     75        throw new ArgumentException("target, pred and weight have different lengths");
    7876
    79       // line search for abs error
     77      // "Simple Newton-Raphson step" of eqn. 23
    8078      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    8179        double sumY = 0.0;
    8280        double sumDiff = 0.0;
    8381        for (int i = startIdx; i <= endIdx; i++) {
    84           var yi = (targetArr[idx[i]] - predArr[idx[i]]);
    85           var wi = weightArr[idx[i]];
     82          var row = idx[i];
     83          var y = targetArr[row] * 2 - 1; // y in {-1,1}
     84          var pseudoResponse = weightArr[row] * 2 * y / (1 + Math.Exp(2 * y * predArr[row]));
    8685
    87           sumY += wi * yi;
    88           sumDiff += wi * Math.Abs(yi) * (1 - Math.Abs(yi));
    89 
     86          sumY += pseudoResponse;
     87          sumDiff += Math.Abs(pseudoResponse) * (2 - Math.Abs(pseudoResponse));
    9088        }
    9189        // prevent divByZero
    9290        sumDiff = Math.Max(1E-12, sumDiff);
    93         return 0.5 * sumY / sumDiff;
     91        return sumY / sumDiff;
    9492      };
    9593      return lineSearch;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12597 r12607  
    2828
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
    30   // relative error loss is a special case of weighted absolute error loss
    31   // absolute loss is weighted by (1/target)
     30  // relative error loss is a special case of weighted absolute error loss with weights = (1/target)
    3231  public class RelativeErrorLoss : ILossFunction {
    3332    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     
    4241      }
    4342      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    44         throw new ArgumentException("target, pred and weight have differing lengths");
     43        throw new ArgumentException("target, pred and weight have different lengths");
    4544
    4645      return s;
     
    6059      }
    6160      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    62         throw new ArgumentException("target, pred and weight have differing lengths");
     61        throw new ArgumentException("target, pred and weight have different lengths");
    6362    }
    6463
     
    7069
    7170      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    72         throw new ArgumentException("target, pred and weight have differing lengths");
     71        throw new ArgumentException("target, pred and weight have different lengths");
    7372
    7473      // line search for relative error
     
    8281          var w0 = weightArr[idx[startIdx]] * Math.Abs(1.0 / targetArr[idx[startIdx]]);
    8382          var w1 = weightArr[idx[endIdx]] * Math.Abs(1.0 / targetArr[idx[endIdx]]);
    84           return (w0 * (targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + w1 * (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / (w0 + w1);
     83          if (w0 > w1) {
     84            return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
     85          } else if (w0 < w1) {
     86            return targetArr[idx[endIdx]] - predArr[idx[endIdx]];
     87          } else {
     88            // same weight -> return average of both residuals
     89            return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2;
     90          }
    8591        } else {
    86           var ts = from offset in Enumerable.Range(0, nRows)
    87                    let i = startIdx + offset
    88                    let row = idx[i]
    89                    select new { res = targetArr[row] - predArr[row], weight = weightArr[row] * Math.Abs(1.0 / targetArr[row]) };
    90           ts = ts.OrderBy(t => t.res);
    91           var totalWeight = ts.Sum(t => t.weight);
    92           var tsEnumerator = ts.GetEnumerator();
    93           tsEnumerator.MoveNext();
     92          // create an array of key-value pairs to be sorted (instead of using Array.Sort(res, weights))
     93          var res_w = new KeyValuePair<double, double>[nRows];
     94          var totalWeight = 0.0;
     95          for (int i = startIdx; i <= endIdx; i++) {
     96            int row = idx[i];
     97            var res = targetArr[row] - predArr[row];
     98            var w = weightArr[row] * Math.Abs(1.0 / targetArr[row]);
     99            res_w[i - startIdx] = new KeyValuePair<double, double>(res, w);
     100            totalWeight += w;
     101          }
     102          res_w.StableSort((a, b) => Math.Sign(a.Key - b.Key));
    94103
    95           double aggWeight = tsEnumerator.Current.weight; // weight of first
    96 
    97           while (aggWeight < totalWeight / 2) {
    98             tsEnumerator.MoveNext();
    99             aggWeight += tsEnumerator.Current.weight;
     104          int k = 0;
     105          double sum = totalWeight - res_w[k].Value; // total - first weight
     106          while (sum > totalWeight / 2) {
     107            k++;
     108            sum -= res_w[k].Value;
    100109          }
    101           return tsEnumerator.Current.res;
     110          return res_w[k].Key;
    102111        }
    103112      };
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12590 r12607  
    3535      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    3636        double res = targetEnum.Current - predEnum.Current;
    37         s += weightEnum.Current * res * res;
     37        s += weightEnum.Current * res * res; // w * (res)^2
    3838      }
    3939      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    40         throw new ArgumentException("target, pred and weight have differing lengths");
     40        throw new ArgumentException("target, pred and weight have different lengths");
    4141
    4242      return s;
     
    4949
    5050      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    51         yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current);
     51        yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current); // dL(y, f(x)) / df(x)  = w * 2 * res
    5252      }
    5353      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    54         throw new ArgumentException("target, pred and weight have differing lengths");
     54        throw new ArgumentException("target, pred and weight have different lengths");
    5555    }
    5656
     
    6060      var weightArr = weight.ToArray();
    6161      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    62         throw new ArgumentException("target, pred and weight have differing lengths");
     62        throw new ArgumentException("target, pred and weight have different lengths");
    6363
    64       // line search for squared error loss => return the average value
     64      // line search for squared error loss
     65      // for a given partition of rows the optimal constant that should be added to the current prediction values is the average of the residuals
    6566      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    6667        double s = 0.0;
Note: See TracChangeset for help on using the changeset viewer.