1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Text;
|
---|
6 | using System.Threading.Tasks;
|
---|
7 | using HeuristicLab.Common;
|
---|
8 | using HeuristicLab.Core;
|
---|
9 |
|
---|
namespace GradientBoostedTrees {
  /// <summary>
  /// Absolute error (L1) loss: the loss of a prediction is the weighted sum of
  /// absolute residuals, sum_i weight[i] * |target[i] - pred[i]|.
  /// </summary>
  public class AbsoluteErrorLoss : ILossFunction {

    private const string DifferingLengthsMessage = "target, pred and weight have differing lengths";

    /// <summary>
    /// Computes the weighted absolute error sum_i weight[i] * |target[i] - pred[i]|.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, aligned with <paramref name="target"/>.</param>
    /// <param name="weight">Per-row weights, aligned with <paramref name="target"/>.</param>
    /// <returns>The weighted sum of absolute residuals (0 when all sequences are empty).</returns>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      double s = 0;
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          s += weightEnum.Current * Math.Abs(res);
        }
      }
      return s;
    }

    /// <summary>
    /// Lazily yields, per row, weight * sign(target - pred). This is the negative gradient
    /// of weight * |target - pred| with respect to pred (the pseudo-residual).
    /// A zero (or NaN) residual yields 0.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, aligned with <paramref name="target"/>.</param>
    /// <param name="weight">Per-row weights, aligned with <paramref name="target"/>.</param>
    /// <returns>One value per row.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown during enumeration, after the common prefix has been yielded, when the three
    /// sequences have differing lengths.
    /// </exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          var res = targetEnum.Current - predEnum.Current;
          if (res > 0) yield return weightEnum.Current;
          else if (res < 0) yield return -weightEnum.Current;
          else yield return 0.0; // res == 0 or NaN
        }
      }
    }

    /// <summary>
    /// Creates the line-search function for absolute error loss.
    /// For the rows idx[startIdx..endIdx] (both inclusive), the function returns the median of
    /// the residuals target - pred. The median is the constant step that minimizes the
    /// unweighted sum of absolute residuals over those rows.
    /// </summary>
    /// <param name="target">Target values; copied into an array captured by the returned function.</param>
    /// <param name="pred">Predicted values; copied into an array captured by the returned function.</param>
    /// <param name="weight">Per-row weights; only checked, not used (see note below).</param>
    /// <returns>A function (idx, startIdx, endIdx) -> median residual of the selected rows.</returns>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      var targetArr = target.ToArray();
      var predArr = pred.ToArray();
      var weightArr = weight.ToArray();
      // Weights are not supported yet: the step below is the unweighted median.
      // NOTE: this is only asserted in debug builds; release builds silently ignore non-unit weights.
      // A future weighted median must reduce to the ordinary median when all weights are 1.
      // For an even row count, the ordinary median is the average of the two middle residuals,
      // not the lower middle one.
      Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));

      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
        throw new ArgumentException(DifferingLengthsMessage);

      // line search for abs error
      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
        int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
        // Materialize once so that Median() does not re-evaluate the projection.
        var res = Enumerable.Range(startIdx, nRows)
          .Select(i => targetArr[idx[i]] - predArr[idx[i]])
          .ToArray();
        return res.Median();
      };
      return lineSearch;
    }

    public override string ToString() {
      return "Absolute error loss";
    }

    /// <summary>
    /// Advances all three enumerators in lock-step; every MoveNext is always called.
    /// Returns true when all three produced an element and false when all three are exhausted
    /// together. Throws when only some of them are exhausted. Because each MoveNext result is
    /// inspected, a length difference of exactly one element is detected as well.
    /// </summary>
    private static bool MoveNextAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      bool hasWeight = weightEnum.MoveNext();
      if (hasTarget && hasPred && hasWeight) return true;
      if (!hasTarget && !hasPred && !hasWeight) return false;
      throw new ArgumentException(DifferingLengthsMessage);
    }
  }
}