Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/06/15 13:02:19 (9 years ago)
Author:
gkronber
Message:

#2261: comments

Location:
branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs

    r12590 r12597  
    5252
    5353      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     54        // weight * sign(res)
    5455        var res = targetEnum.Current - predEnum.Current;
    5556        if (res > 0) yield return weightEnum.Current;
     
    6162    }
    6263
     64    // return median of residuals
    6365    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
    6466      var targetArr = target.ToArray();
  • branches/GBT-trunkintegration/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/RelativeErrorLoss.cs

    r12590 r12597  
    2929namespace HeuristicLab.Algorithms.DataAnalysis {
    3030  // relative error loss is a special case of weighted absolute error loss
     31  // absolute loss is weighted by (1/target)
    3132  public class RelativeErrorLoss : ILossFunction {
    3233    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
     
    5253
    5354      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
     55        // weight * sign(res) * abs(1 / target)
    5456        var res = targetEnum.Current - predEnum.Current;
    5557        if (res > 0) yield return weightEnum.Current * 1.0 / Math.Abs(targetEnum.Current);
     
    7173
    7274      // line search for relative error
    73       // TODO: check and improve?
     75      // weighted median (weight = 1/target)
    7476      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
    7577        // weighted median calculation
    7678        int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
    77         if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
     79        if (nRows == 1) return targetArr[idx[startIdx]] - predArr[idx[startIdx]]; // res
    7880        else if (nRows == 2) {
     81          // weighted average of two residuals
    7982          var w0 = weightArr[idx[startIdx]] * Math.Abs(1.0 / targetArr[idx[startIdx]]);
    8083          var w1 = weightArr[idx[endIdx]] * Math.Abs(1.0 / targetArr[idx[endIdx]]);
    81           if (w0 > w1) {
    82             return targetArr[idx[startIdx]] - predArr[idx[startIdx]];
    83           } else if (w0 < w1) {
    84             return targetArr[idx[endIdx]] - predArr[idx[endIdx]];
    85           } else {
    86             // same weight
    87             return ((targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / 2;
    88           }
     84          return (w0 * (targetArr[idx[startIdx]] - predArr[idx[startIdx]]) + w1 * (targetArr[idx[endIdx]] - predArr[idx[endIdx]])) / (w0 + w1);
    8985        } else {
    9086          var ts = from offset in Enumerable.Range(0, nRows)
    9187                   let i = startIdx + offset
    92                    select new { res = targetArr[idx[i]] - predArr[idx[i]], weight = weightArr[idx[i]] * Math.Abs(1.0 / targetArr[idx[i]]) };
     88                   let row = idx[i]
     89                   select new { res = targetArr[row] - predArr[row], weight = weightArr[row] * Math.Abs(1.0 / targetArr[row]) };
    9390          ts = ts.OrderBy(t => t.res);
    9491          var totalWeight = ts.Sum(t => t.weight);
Note: See TracChangeset for help on using the changeset viewer.