Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/AbsoluteErrorLoss.cs @ 12374

Last change on this file since 12374 was 12374, checked in by gkronber, 9 years ago

#2261: added absolute and relative error loss functions for GBT.

File size: 4.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Core;
9
namespace GradientBoostedTrees {
  /// <summary>
  /// Weighted absolute error loss: the sum over all rows of weight * |target - pred|.
  /// </summary>
  public class AbsoluteErrorLoss : ILossFunction {
    /// <summary>
    /// Computes the weighted absolute error, i.e. the sum of weight[i] * |target[i] - pred[i]|.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values; must contain as many elements as <paramref name="target"/>.</param>
    /// <param name="weight">Per-row weights; must contain as many elements as <paramref name="target"/>.</param>
    /// <returns>The weighted sum of absolute residuals (0 for empty input).</returns>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        double s = 0;
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          s += weightEnum.Current * Math.Abs(res);
        }
        return s;
      }
    }

    /// <summary>
    /// Lazily yields, per row, weight * sign(target - pred).
    /// </summary>
    /// <remarks>
    /// Each yielded value is +weight when target &gt; pred, -weight when target &lt; pred,
    /// and 0 when they are equal or the residual is NaN. This is the negative of the
    /// derivative of weight * |target - pred| with respect to pred.
    /// </remarks>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values; must contain as many elements as <paramref name="target"/>.</param>
    /// <param name="weight">Per-row weights; must contain as many elements as <paramref name="target"/>.</param>
    /// <returns>One gradient value per row, produced on enumeration.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown during enumeration, after all rows common to the three sequences have been
    /// yielded, if the sequences have differing lengths.
    /// </exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          var res = targetEnum.Current - predEnum.Current;
          // explicit comparisons (rather than Math.Sign) so that a NaN residual yields 0
          // instead of throwing ArithmeticException
          if (res > 0) yield return weightEnum.Current;
          else if (res < 0) yield return -weightEnum.Current;
          else yield return 0.0;
        }
      }
    }

    /// <summary>
    /// Creates the line-search function for the absolute error loss.
    /// </summary>
    /// <remarks>
    /// For the rows idx[startIdx] .. idx[endIdx] (both indices inclusive), the returned
    /// function yields the median of the residuals target - pred. The median is the constant
    /// that minimizes the unweighted sum of absolute residuals over those rows.
    /// Weights are not supported yet: supporting them requires a weighted median. Unit weights
    /// are only checked via <see cref="Debug.Assert(bool)"/>, i.e. in DEBUG builds.
    /// </remarks>
    /// <param name="target">Target values; materialized once and captured by the returned function.</param>
    /// <param name="pred">Predicted values; must contain as many elements as <paramref name="target"/>.</param>
    /// <param name="weight">Per-row weights; must contain as many elements as <paramref name="target"/>.</param>
    /// <returns>A function mapping a row-index range to the optimal constant step.</returns>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      var targetArr = target.ToArray();
      var predArr = pred.ToArray();
      var weightArr = weight.ToArray();

      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
        throw new ArgumentException("target, pred and weight have differing lengths");

      // weights are not supported yet (see remarks)
      Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));

      // line search for abs error: median of the residuals of the selected rows
      return (idx, startIdx, endIdx) => {
        int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
        return Enumerable.Range(startIdx, nRows)
                         .Select(i => targetArr[idx[i]] - predArr[idx[i]])
                         .Median();
      };
    }

    /// <summary>Returns the display name of this loss function.</summary>
    public override string ToString() {
      return "Absolute error loss";
    }

    /// <summary>
    /// Advances the three enumerators in lockstep.
    /// </summary>
    /// <remarks>
    /// Each enumerator's MoveNext result is evaluated and compared before anything else is
    /// consumed. This catches length mismatches of any size, including sequences that differ
    /// by exactly one element.
    /// </remarks>
    /// <returns>true if all three have a current element; false if all three are exhausted.</returns>
    /// <exception cref="ArgumentException">Only some of the enumerators are exhausted.</exception>
    private static bool MoveNextAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      bool hasWeight = weightEnum.MoveNext();
      if (hasTarget != hasPred || hasPred != hasWeight)
        throw new ArgumentException("target, pred and weight have differing lengths");
      return hasTarget;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.