Free cookie consent management tool by TermsFeed Policy Generator

Changeset 12696 for branches


Ignore:
Timestamp:
07/09/15 16:11:52 (9 years ago)
Author:
gkronber
Message:

#2261: killed all weights

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees
Files:
7 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/GradientBoostedTreesAlgorithmStatic.cs

    r12661 r12696  
    5555      internal double[] pred;
    5656      internal double[] predTest;
    57       internal double[] w;
    5857      internal double[] y;
    5958      internal int[] activeIdx;
     
    7776
    7877        y = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TrainingIndices).ToArray();
    79         // weights are all 1 for now (HL doesn't support weights yet)
    80         w = Enumerable.Repeat(1.0, nRows).ToArray();
    8178
    8279        treeBuilder = new RegressionTreeBuilder(problemData, random);
     
    8582
    8683        var zeros = Enumerable.Repeat(0.0, nRows);
    87         var ones = Enumerable.Repeat(1.0, nRows);
    88         double f0 = lossFunction.GetLineSearchFunc(y, zeros, ones)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
     84        double f0 = lossFunction.GetLineSearchFunc(y, zeros)(activeIdx, 0, nRows - 1); // initial constant value (mean for squared errors)
    8985        pred = Enumerable.Repeat(f0, nRows).ToArray();
    9086        predTest = Enumerable.Repeat(f0, problemData.TestIndices.Count()).ToArray();
     
    107103      public double GetTrainLoss() {
    108104        int nRows = y.Length;
    109         return lossFunction.GetLoss(y, pred, w) / nRows;
     105        return lossFunction.GetLoss(y, pred) / nRows;
    110106      }
    111107      public double GetTestLoss() {
    112108        var yTest = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, problemData.TestIndices);
    113         var wTest = problemData.TestIndices.Select(_ => 1.0); // ones
    114109        var nRows = problemData.TestIndices.Count();
    115         return lossFunction.GetLoss(yTest, predTest, wTest) / nRows;
     110        return lossFunction.GetLoss(yTest, predTest) / nRows;
    116111      }
    117112    }
     
    160155      var yPred = gbmState.pred;
    161156      var yPredTest = gbmState.predTest;
    162       var w = gbmState.w;
    163157      var treeBuilder = gbmState.treeBuilder;
    164158      var y = gbmState.y;
     
    168162      // copy output of gradient function to pre-allocated rim array (pseudo-residuals)
    169163      int rimIdx = 0;
    170       foreach (var g in lossFunction.GetLossGradient(y, yPred, w)) {
     164      foreach (var g in lossFunction.GetLossGradient(y, yPred)) {
    171165        pseudoRes[rimIdx++] = g;
    172166      }
    173167
    174       var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, maxSize, activeIdx, lossFunction.GetLineSearchFunc(y, yPred, w), r, m);
     168      var tree = treeBuilder.CreateRegressionTreeForGradientBoosting(pseudoRes, maxSize, activeIdx, lossFunction.GetLineSearchFunc(y, yPred), r, m);
    175169
    176170      int i = 0;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12635 r12696  
    3030  // loss function for the weighted absolute error
    3131  public class AbsoluteErrorLoss : ILossFunction {
    32     public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     32    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3333      var targetEnum = target.GetEnumerator();
    3434      var predEnum = pred.GetEnumerator();
    35       var weightEnum = weight.GetEnumerator();
    3635
    3736      double s = 0;
    38       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     37      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
    3938        double res = targetEnum.Current - predEnum.Current;
    40         s += weightEnum.Current * Math.Abs(res);  // w * |res|
     39        s += Math.Abs(res);  // |res|
    4140      }
    42       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    43         throw new ArgumentException("target, pred and weight have differing lengths");
     41      if (targetEnum.MoveNext() | predEnum.MoveNext())
     42        throw new ArgumentException("target and pred have differing lengths");
    4443
    4544      return s;
    4645    }
    4746
    48     public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     47    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
    4948      var targetEnum = target.GetEnumerator();
    5049      var predEnum = pred.GetEnumerator();
    51       var weightEnum = weight.GetEnumerator();
    5250
    53       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    54         // dL(y, f(x)) / df(x) = weight * sign(res)
     51      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
     52        // dL(y, f(x)) / df(x) = sign(res)
    5553        var res = targetEnum.Current - predEnum.Current;
    56         if (res > 0) yield return weightEnum.Current;
    57         else if (res < 0) yield return -weightEnum.Current;
     54        if (res > 0) yield return 1.0;
     55        else if (res < 0) yield return -1.0;
    5856        else yield return 0.0;
    5957      }
    60       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    61         throw new ArgumentException("target, pred and weight have differing lengths");
     58      if (targetEnum.MoveNext() | predEnum.MoveNext())
     59        throw new ArgumentException("target and pred have differing lengths");
    6260    }
    6361
    6462    // return median of residuals
    65     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     63    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    6664      var targetArr = target.ToArray();
    6765      var predArr = pred.ToArray();
    68       var weightArr = weight.ToArray();
    69       // the optimal constant value that should be added to the predictions is the median of the residuals
    70       // weights are not supported yet (need to calculate a weighted median)
    71       Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));
    7266
    73       if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    74         throw new ArgumentException("target, pred and weight have differing lengths");
     67      if (targetArr.Length != predArr.Length)
     68        throw new ArgumentException("target and pred have differing lengths");
    7569
    7670      // line search for abs error
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/ILossFunction.cs

    r12607 r12696  
    3030  // target represents the target vector  (original targets from the problem data, never changed)
    3131  // pred   represents the current vector of predictions (a weighted combination of models learned so far, this vector is updated after each step)
    32   // weight represents a weight vector for rows (this is not supported yet -> all weights are 1)
    3332  public interface ILossFunction {
    34     // returns the weighted loss of the current prediction vector
    35     double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight);
     33    // returns the loss of the current prediction vector
     34    double GetLoss(IEnumerable<double> target, IEnumerable<double> pred);
    3635
    37     // returns an enumerable of the weighted loss gradient for each row
    38     IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight);
     36    // returns an enumerable of the loss gradient for each row
     37    IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred);
    3938
    4039    // returns a function that returns the optimal prediction value for a subset of rows from target and pred (see LineSearchFunc delegate above)
    41     LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight);
     40    LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred);
    4241  }
    4342}
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/LogisticRegressionLoss.cs

    r12607 r12696  
    3030  // Greedy Function Approximation: A Gradient Boosting Machine (page 9)
    3131  public class LogisticRegressionLoss : ILossFunction {
    32     public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     32    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3333      var targetEnum = target.GetEnumerator();
    3434      var predEnum = pred.GetEnumerator();
    35       var weightEnum = weight.GetEnumerator();
    3635
    3736      double s = 0;
    38       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     37      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
    3938        Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss");
    4039
    4140        var y = targetEnum.Current * 2 - 1; // y in {-1,1}
    42         s += weightEnum.Current * Math.Log(1 + Math.Exp(-2 * y * predEnum.Current));
     41        s += Math.Log(1 + Math.Exp(-2 * y * predEnum.Current));
    4342      }
    44       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    45         throw new ArgumentException("target, pred and weight have different lengths");
     43      if (targetEnum.MoveNext() | predEnum.MoveNext())
     44        throw new ArgumentException("target and pred have different lengths");
    4645
    4746      return s;
    4847    }
    4948
    50     public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     49    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
    5150      var targetEnum = target.GetEnumerator();
    5251      var predEnum = pred.GetEnumerator();
    53       var weightEnum = weight.GetEnumerator();
    5452
    55       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     53      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
    5654        Debug.Assert(targetEnum.Current.IsAlmost(0.0) || targetEnum.Current.IsAlmost(1.0), "labels must be 0 or 1 for logistic regression loss");
    5755        var y = targetEnum.Current * 2 - 1; // y in {-1,1}
    5856
    59         yield return weightEnum.Current * 2 * y / (1 + Math.Exp(2 * y * predEnum.Current));
     57        yield return 2 * y / (1 + Math.Exp(2 * y * predEnum.Current));
    6058
    6159      }
    62       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    63         throw new ArgumentException("target, pred and weight have different lengths");
     60      if (targetEnum.MoveNext() | predEnum.MoveNext())
     61        throw new ArgumentException("target and pred have different lengths");
    6462    }
    6563
    66     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     64    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    6765      var targetArr = target.ToArray();
    6866      var predArr = pred.ToArray();
    69       var weightArr = weight.ToArray();
    70       // weights are not supported yet
    71       // when weights are supported we need to calculate a weighted median
    72       Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));
    7367
    74       if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    75         throw new ArgumentException("target, pred and weight have different lengths");
     68      if (targetArr.Length != predArr.Length)
     69        throw new ArgumentException("target and pred have different lengths");
    7670
    7771      // "Simple Newton-Raphson step" of eqn. 23
     
    8276          var row = idx[i];
    8377          var y = targetArr[row] * 2 - 1; // y in {-1,1}
    84           var pseudoResponse = weightArr[row] * 2 * y / (1 + Math.Exp(2 * y * predArr[row]));
     78          var pseudoResponse = 2 * y / (1 + Math.Exp(2 * y * predArr[row]));
    8579
    8680          sumY += pseudoResponse;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12635 r12696  
    3030  // relative error loss is a special case of weighted absolute error loss with weights = (1/target)
    3131  public class RelativeErrorLoss : ILossFunction {
    32     public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     32    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3333      var targetEnum = target.GetEnumerator();
    3434      var predEnum = pred.GetEnumerator();
    35       var weightEnum = weight.GetEnumerator();
    3635
    3736      double s = 0;
    38       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     37      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
    3938        double res = targetEnum.Current - predEnum.Current;
    40         s += weightEnum.Current * Math.Abs(res) * Math.Abs(1.0 / targetEnum.Current);
     39        s += Math.Abs(res) * Math.Abs(1.0 / targetEnum.Current);
    4140      }
    42       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    43         throw new ArgumentException("target, pred and weight have different lengths");
     41      if (targetEnum.MoveNext() | predEnum.MoveNext())
     42        throw new ArgumentException("target and pred have different lengths");
    4443
    4544      return s;
    4645    }
    4746
    48     public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     47    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
    4948      var targetEnum = target.GetEnumerator();
    5049      var predEnum = pred.GetEnumerator();
    51       var weightEnum = weight.GetEnumerator();
    5250
    53       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    54         // weight * sign(res) * abs(1 / target)
     51      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
     52        // sign(res) * abs(1 / target)
    5553        var res = targetEnum.Current - predEnum.Current;
    56         if (res > 0) yield return weightEnum.Current * 1.0 / Math.Abs(targetEnum.Current);
    57         else if (res < 0) yield return -weightEnum.Current * 1.0 / Math.Abs(targetEnum.Current);
     54        if (res > 0) yield return 1.0 / Math.Abs(targetEnum.Current);
     55        else if (res < 0) yield return -1.0 / Math.Abs(targetEnum.Current);
    5856        else yield return 0.0;
    5957      }
    60       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    61         throw new ArgumentException("target, pred and weight have different lengths");
     58      if (targetEnum.MoveNext() | predEnum.MoveNext())
     59        throw new ArgumentException("target and pred have different lengths");
    6260    }
    6361
    64     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     62    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    6563      var targetArr = target.ToArray();
    6664      var predArr = pred.ToArray();
    67       var weightArr = weight.ToArray();
    68       Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));
    6965
    70       if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    71         throw new ArgumentException("target, pred and weight have different lengths");
     66      if (targetArr.Length != predArr.Length)
     67        throw new ArgumentException("target and pred have different lengths");
    7268
    7369      // line search for relative error
     
    7975        else if (nRows == 2) {
    8076          // weighted average of two residuals
    81           var w0 = weightArr[idx[startIdx]] * Math.Abs(1.0 / targetArr[idx[startIdx]]);
    82           var w1 = weightArr[idx[endIdx]] * Math.Abs(1.0 / targetArr[idx[endIdx]]);
     77          var w0 = Math.Abs(1.0 / targetArr[idx[startIdx]]);
     78          var w1 = Math.Abs(1.0 / targetArr[idx[endIdx]]);
    8379          if (w0 > w1) {
    8480            return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
     
    9692            int row = idx[i];
    9793            var res = targetArr[row] - predArr[row];
    98             var w = weightArr[row] * Math.Abs(1.0 / targetArr[row]);
     94            var w = Math.Abs(1.0 / targetArr[row]);
    9995            res_w[i - startIdx] = new KeyValuePair<double, double>(res, w);
    10096            totalWeight += w;
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs

    r12607 r12696  
    2727namespace HeuristicLab.Algorithms.DataAnalysis {
    2828  public class SquaredErrorLoss : ILossFunction {
    29     public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     29    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
    3030      var targetEnum = target.GetEnumerator();
    3131      var predEnum = pred.GetEnumerator();
    32       var weightEnum = weight.GetEnumerator();
    3332
    3433      double s = 0;
    35       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     34      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
    3635        double res = targetEnum.Current - predEnum.Current;
    37         s += weightEnum.Current * res * res; // w * (res)^2
     36        s += res * res; // (res)^2
    3837      }
    39       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    40         throw new ArgumentException("target, pred and weight have different lengths");
     38      if (targetEnum.MoveNext() | predEnum.MoveNext())
     39        throw new ArgumentException("target and pred have different lengths");
    4140
    4241      return s;
    4342    }
    4443
    45     public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     44    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
    4645      var targetEnum = target.GetEnumerator();
    4746      var predEnum = pred.GetEnumerator();
    48       var weightEnum = weight.GetEnumerator();
    4947
    50       while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
    51         yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current); // dL(y, f(x)) / df(x)  = w * 2 * res
     48      while (targetEnum.MoveNext() & predEnum.MoveNext()) {
     49        yield return 2.0 * (targetEnum.Current - predEnum.Current); // dL(y, f(x)) / df(x)  = 2 * res
    5250      }
    53       if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
    54         throw new ArgumentException("target, pred and weight have different lengths");
     51      if (targetEnum.MoveNext() | predEnum.MoveNext())
     52        throw new ArgumentException("target and pred have different lengths");
    5553    }
    5654
    57     public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     55    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred) {
    5856      var targetArr = target.ToArray();
    5957      var predArr = pred.ToArray();
    60       var weightArr = weight.ToArray();
    61       if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
    62         throw new ArgumentException("target, pred and weight have different lengths");
     58      if (targetArr.Length != predArr.Length)
     59        throw new ArgumentException("target and pred have different lengths");
    6360
    6461      // line search for squared error loss
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/RegressionTreeBuilder.cs

    r12661 r12696  
    128128      var seLoss = new SquaredErrorLoss();
    129129      var zeros = Enumerable.Repeat(0.0, y.Length);
    130       var ones = Enumerable.Repeat(1.0, y.Length);
    131 
    132       var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros, ones), r, m);
     130
     131      var model = CreateRegressionTreeForGradientBoosting(y, maxSize, problemData.TrainingIndices.ToArray(), seLoss.GetLineSearchFunc(y, zeros), r, m);
    133132
    134133      return new GradientBoostedTreesModel(new[] { new ConstantRegressionModel(yAvg), model }, new[] { 1.0, 1.0 });
Note: See TracChangeset for help on using the changeset viewer.