Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GBT/HeuristicLab.Algorithms.DataAnalysis/3.4/GradientBoostedTrees/LossFunctions/SquaredErrorLoss.cs @ 12332

Last change on this file since 12332 was 12332, checked in by gkronber, 9 years ago

#2261: initial import of gradient boosted trees for regression

File size: 2.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8
9namespace GradientBoostedTrees {
10  public class SquaredErrorLoss : ILossFunction {
11    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
12      var targetEnum = target.GetEnumerator();
13      var predEnum = pred.GetEnumerator();
14      var weightEnum = weight.GetEnumerator();
15
16      double s = 0;
17      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
18        double res = targetEnum.Current - predEnum.Current;
19        s += weightEnum.Current * res * res;
20      }
21      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
22        throw new ArgumentException("target, pred and weight have differing lengths");
23
24      return s;
25    }
26
27    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
28      var targetEnum = target.GetEnumerator();
29      var predEnum = pred.GetEnumerator();
30      var weightEnum = weight.GetEnumerator();
31
32      while (targetEnum.MoveNext() & predEnum.MoveNext() & weightEnum.MoveNext()) {
33        yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current);
34      }
35      if (targetEnum.MoveNext() | predEnum.MoveNext() | weightEnum.MoveNext())
36        throw new ArgumentException("target, pred and weight have differing lengths");
37    }
38
39    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
40      var targetArr = target.ToArray();
41      var predArr = pred.ToArray();
42      var weightArr = weight.ToArray();
43      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
44        throw new ArgumentException("target, pred and weight have differing lengths");
45
46      // line search for
47      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
48        double s = 0.0;
49        int n = 0;
50        for (int i = startIdx; i <= endIdx; i++) {
51          int row = idx[i];
52          s += (targetArr[row] - predArr[row]);
53          n++;
54        }
55        return s / n;
56      };
57      return lineSearch;
58
59    }
60
61    public override string ToString() {
62      return "Squared error loss";
63    }
64  }
65}
Note: See TracBrowser for help on using the repository browser.