1 | using System;
2 | using System.Collections.Generic;
3 | using System.Diagnostics;
4 | using System.Linq;
5 | using System.Text;
6 | using System.Threading.Tasks;
7 | using HeuristicLab.Common;
8 | using HeuristicLab.Core;
9 |
10 | namespace GradientBoostedTrees {
/// <summary>
/// Loss function based on the weighted absolute error,
/// i.e. the sum over all rows of weight * |target - pred|.
/// </summary>
public class AbsoluteErrorLoss : ILossFunction {
  /// <summary>
  /// Calculates the weighted sum of absolute errors: sum_i weight_i * |target_i - pred_i|.
  /// </summary>
  /// <param name="target">Target values.</param>
  /// <param name="pred">Predicted values, one per target value.</param>
  /// <param name="weight">Weights, one per target value.</param>
  /// <returns>The weighted sum of absolute residuals (0 for empty input).</returns>
  /// <exception cref="ArgumentException">Thrown when the three sequences do not have the same length.</exception>
  public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
    // IEnumerator<T> is IDisposable; dispose the enumerators even when an exception is thrown
    using (var targetEnum = target.GetEnumerator())
    using (var predEnum = pred.GetEnumerator())
    using (var weightEnum = weight.GetEnumerator()) {
      double s = 0;
      while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
        double res = targetEnum.Current - predEnum.Current;
        s += weightEnum.Current * Math.Abs(res);
      }
      return s;
    }
  }

  /// <summary>
  /// Lazily yields, for each row, weight * sign(target - pred):
  /// +weight when target &gt; pred, -weight when target &lt; pred and 0.0 otherwise
  /// (this includes the case where the residual is NaN).
  /// Note the sign convention: this is the negated derivative of weight * |target - pred| with respect to pred.
  /// </summary>
  /// <param name="target">Target values.</param>
  /// <param name="pred">Predicted values, one per target value.</param>
  /// <param name="weight">Weights, one per target value.</param>
  /// <returns>A lazily evaluated sequence with one value per row.</returns>
  /// <exception cref="ArgumentException">
  /// Thrown while enumerating the result (not when this method is called)
  /// as soon as one of the sequences ends before the others.
  /// </exception>
  public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
    using (var targetEnum = target.GetEnumerator())
    using (var predEnum = pred.GetEnumerator())
    using (var weightEnum = weight.GetEnumerator()) {
      while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
        var res = targetEnum.Current - predEnum.Current;
        if (res > 0) yield return weightEnum.Current;
        else if (res < 0) yield return -weightEnum.Current;
        else yield return 0.0;
      }
    }
  }

  /// <summary>
  /// Creates the line search function for the absolute error loss.
  /// The returned function takes (idx, startIdx, endIdx) and returns the median of the residuals
  /// target[row] - pred[row] over the rows idx[startIdx] .. idx[endIdx] (both indexes inclusive).
  /// </summary>
  /// <remarks>
  /// The inputs are copied into arrays when this method is called, so the returned function
  /// works on a snapshot of the data.
  /// Weights are not supported yet. This is only verified by a debug assertion, so in builds
  /// without DEBUG the weights are ignored.
  /// </remarks>
  /// <param name="target">Target values.</param>
  /// <param name="pred">Predicted values, one per target value.</param>
  /// <param name="weight">Weights, one per target value; all of them are expected to be (almost) 1.0.</param>
  /// <returns>The line search function.</returns>
  /// <exception cref="ArgumentException">Thrown when the three sequences do not have the same length.</exception>
  public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
    var targetArr = target.ToArray();
    var predArr = pred.ToArray();
    var weightArr = weight.ToArray();

    if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
      throw new ArgumentException("target, pred and weight have differing lengths");

    // weights are not supported yet
    // when weights are supported we need to calculate a weighted median of the residuals instead
    Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)), "weights are not supported by the absolute error line search");

    // line search for abs error: the constant that minimizes the sum of absolute residuals is their median
    LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
      int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
      var res = from offset in Enumerable.Range(0, nRows)
                let i = startIdx + offset
                let row = idx[i]
                select (targetArr[row] - predArr[row]);
      return res.Median(); // Median() is an extension method defined outside of this file
    };
    return lineSearch;
  }

  /// <summary>
  /// Returns the display name of this loss function.
  /// </summary>
  public override string ToString() {
    return "Absolute error loss";
  }

  /// <summary>
  /// Advances all three enumerators by one element and checks that they agree on whether
  /// another element is available.
  /// </summary>
  /// <returns>True if all enumerators moved to a new element, false if all of them reached the end.</returns>
  /// <exception cref="ArgumentException">Thrown when only some of the enumerators reached the end.</exception>
  private static bool MoveNextAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
    // Each MoveNext() result must be captured exactly once: advancing an enumerator consumes
    // the element, so re-checking afterwards would miss sequences that are longer by exactly one element.
    bool hasTarget = targetEnum.MoveNext();
    bool hasPred = predEnum.MoveNext();
    bool hasWeight = weightEnum.MoveNext();

    if (hasTarget != hasPred || hasPred != hasWeight)
      throw new ArgumentException("target, pred and weight have differing lengths");

    return hasTarget;
  }
}
102 | }