using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Core;



namespace GradientBoostedTrees {
  /// <summary>
  /// Weighted squared error loss: L(y, f, w) = w * (y - f)^2, summed over all rows.
  /// </summary>
  public class SquaredErrorLoss : ILossFunction {
    /// <summary>
    /// Computes the weighted sum of squared residuals, sum_i w_i * (y_i - f_i)^2.
    /// </summary>
    /// <param name="target">Observed target values y.</param>
    /// <param name="pred">Predicted values f.</param>
    /// <param name="weight">Per-row weights w.</param>
    /// <returns>The weighted sum of squared errors (0 when all sequences are empty).</returns>
    /// <exception cref="ArgumentNullException">If any argument is null.</exception>
    /// <exception cref="ArgumentException">If the three sequences have differing lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      CheckNotNull(target, pred, weight);

      double s = 0.0;
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          s += weightEnum.Current * res * res;
        }
      }
      return s;
    }

    /// <summary>
    /// Returns the pseudo-residuals 2 * w * (y - f) for every row, i.e. the negative gradient
    /// of the weighted squared error with respect to the prediction f.
    /// </summary>
    /// <param name="target">Observed target values y.</param>
    /// <param name="pred">Predicted values f.</param>
    /// <param name="weight">Per-row weights w.</param>
    /// <returns>A lazily evaluated sequence with one value per row.</returns>
    /// <exception cref="ArgumentNullException">If any argument is null (thrown immediately).</exception>
    /// <exception cref="ArgumentException">
    /// If the three sequences have differing lengths. Because the result is streamed, this is
    /// raised during enumeration, when the mismatch is reached.
    /// </exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      // Validate eagerly: the body of an iterator method only runs on the first MoveNext call,
      // so checks placed there would be deferred until the caller starts enumerating.
      CheckNotNull(target, pred, weight);
      return GetLossGradientIterator(target, pred, weight);
    }

    // Streams 2 * w * (y - f) row by row; arguments are already null-checked by the caller.
    private static IEnumerable<double> GetLossGradientIterator(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          yield return weightEnum.Current * 2.0 * (targetEnum.Current - predEnum.Current);
        }
      }
    }

    /// <summary>
    /// Creates the line-search function used to compute the constant step for a subset of rows.
    /// </summary>
    /// <param name="target">Observed target values y.</param>
    /// <param name="pred">Predicted values f.</param>
    /// <param name="weight">Per-row weights w.</param>
    /// <returns>
    /// A function (idx, startIdx, endIdx) that returns the weighted mean residual
    /// sum w_r * (y_r - f_r) / sum w_r over the rows r = idx[startIdx..endIdx] (inclusive).
    /// It returns 0 when that range is empty or its weights sum to zero.
    /// </returns>
    /// <exception cref="ArgumentNullException">If any argument is null.</exception>
    /// <exception cref="ArgumentException">If the three sequences have differing lengths.</exception>
    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      CheckNotNull(target, pred, weight);

      // Materialize once: the returned function is evaluated many times with random row access.
      var targetArr = target.ToArray();
      var predArr = pred.ToArray();
      var weightArr = weight.ToArray();
      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
        throw new ArgumentException("target, pred and weight have differing lengths");

      // Line search for the constant c minimizing sum_i w_i * (r_i - c)^2 with r_i = y_i - f_i.
      // Setting the derivative to zero gives the weighted mean c = sum_i w_i r_i / sum_i w_i.
      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
        double weightedResidualSum = 0.0;
        double weightSum = 0.0;
        for (int i = startIdx; i <= endIdx; i++) {
          int row = idx[i];
          weightedResidualSum += weightArr[row] * (targetArr[row] - predArr[row]);
          weightSum += weightArr[row];
        }
        // No (weighted) rows: the loss does not depend on c, so leave the predictions unchanged
        // instead of returning NaN from 0/0.
        if (weightSum == 0.0) return 0.0;
        return weightedResidualSum / weightSum;
      };
      return lineSearch;
    }

    public override string ToString() {
      return "Squared error loss";
    }

    // Throws ArgumentNullException for the first null sequence argument.
    private static void CheckNotNull(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      if (target == null) throw new ArgumentNullException("target");
      if (pred == null) throw new ArgumentNullException("pred");
      if (weight == null) throw new ArgumentNullException("weight");
    }

    // Advances all three enumerators in lock step. Returns true when every enumerator produced
    // an element and false when all are exhausted together. Each result is checked separately:
    // a single combined condition cannot tell "all ended" from "some ended one element early",
    // because the longer sequences have already consumed their extra element.
    private static bool MoveNextAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      bool hasWeight = weightEnum.MoveNext();
      if (hasTarget != hasPred || hasPred != hasWeight)
        throw new ArgumentException("target, pred and weight have differing lengths");
      return hasTarget;
    }
  }
}

