1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  * and the BEACON Center for the Study of Evolution in Action.


5  *


6  * This file is part of HeuristicLab.


7  *


8  * HeuristicLab is free software: you can redistribute it and/or modify


9  * it under the terms of the GNU General Public License as published by


10  * the Free Software Foundation, either version 3 of the License, or


11  * (at your option) any later version.


12  *


13  * HeuristicLab is distributed in the hope that it will be useful,


14  * but WITHOUT ANY WARRANTY; without even the implied warranty of


15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


16  * GNU General Public License for more details.


17  *


18  * You should have received a copy of the GNU General Public License


19  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


20  */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using System.Diagnostics;


26  using System.Linq;


27  using HeuristicLab.Common;


28 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Relative error loss: the absolute residual of each row scaled by the row weight and by |1 / target|.
  /// It is a special case of weighted absolute error loss in which the effective weight of a row is
  /// weight * |1 / target|.
  /// </summary>
  /// <remarks>
  /// Rows with target == 0 get an infinite effective weight, so they contribute infinite or NaN values
  /// to the loss, the gradient and the line search.
  /// </remarks>
  public class RelativeErrorLoss : ILossFunction {

    /// <summary>
    /// Computes the total relative error: the sum over all rows of weight * |target - pred| * |1 / target|.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, one per target value.</param>
    /// <param name="weight">Row weights, one per target value.</param>
    /// <returns>The summed weighted relative error (0 for empty inputs).</returns>
    /// <exception cref="ArgumentException">The three sequences have different lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        double s = 0;
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          s += weightEnum.Current * Math.Abs(res) * Math.Abs(1.0 / targetEnum.Current);
        }
        return s;
      }
    }

    /// <summary>
    /// Yields, for each row, sign(target - pred) * weight / |target|, which is the negative derivative of
    /// the row's relative error with respect to the prediction. Rows whose residual is zero (or NaN) yield 0.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, one per target value.</param>
    /// <param name="weight">Row weights, one per target value.</param>
    /// <returns>A lazily evaluated sequence with one value per row.</returns>
    /// <exception cref="ArgumentException">
    /// The three sequences have different lengths. Because evaluation is lazy, this is thrown while the
    /// result is enumerated, not when the method is called.
    /// </exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (MoveNextAll(targetEnum, predEnum, weightEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          if (res > 0) yield return weightEnum.Current / Math.Abs(targetEnum.Current);
          else if (res < 0) yield return -weightEnum.Current / Math.Abs(targetEnum.Current);
          else yield return 0.0; // zero residual, or NaN (both comparisons false)
        }
      }
    }

    /// <summary>
    /// Creates the line-search function for this loss. For the rows idx[startIdx..endIdx] (both inclusive)
    /// it returns the weighted median of the residuals target - pred, using the effective weights
    /// weight * |1 / target|.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, one per target value.</param>
    /// <param name="weight">Row weights, one per target value.</param>
    /// <returns>The line-search delegate; it captures copies of the three input sequences.</returns>
    /// <exception cref="ArgumentException">The three sequences have different lengths.</exception>
    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      var targetArr = target.ToArray();
      var predArr = pred.ToArray();
      var weightArr = weight.ToArray();

      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
        throw new ArgumentException("target, pred and weight have differing lengths");

      // NOTE(review): this asserts unit weights, yet the weighted median below does use weightArr;
      // confirm whether callers may pass non-unit weights.
      Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));

      Func<int, double> residual = row => targetArr[row] - predArr[row];
      Func<int, double> relWeight = row => weightArr[row] * Math.Abs(1.0 / targetArr[row]);

      // TODO: check and improve?
      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
        int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
        if (nRows < 1)
          throw new ArgumentException("the row range must contain at least one row");

        if (nRows == 1) return residual(idx[startIdx]);

        if (nRows == 2) {
          int first = idx[startIdx];
          int second = idx[endIdx];
          double w0 = relWeight(first);
          double w1 = relWeight(second);
          if (w0 > w1) return residual(first);
          if (w0 < w1) return residual(second);
          // equal (or incomparable, e.g. NaN) weights: average the two residuals
          return (residual(first) + residual(second)) / 2;
        }

        // General case: sort the residuals ascending (OrderBy is stable) and materialize the result once,
        // so the sort is not repeated for the total and for the walk.
        var sorted = Enumerable.Range(startIdx, nRows)
          .Select(i => idx[i])
          .Select(row => new { res = residual(row), weight = relWeight(row) })
          .OrderBy(t => t.res)
          .ToArray();
        double totalWeight = sorted.Sum(t => t.weight);

        // Walk until at least half of the total weight is accumulated; never step past the last element.
        int k = 0;
        double aggWeight = sorted[0].weight;
        while (aggWeight < totalWeight / 2 && k < sorted.Length - 1) {
          k++;
          aggWeight += sorted[k].weight;
        }
        return sorted[k].res;
      };
      return lineSearch;
    }

    // Advances the three enumerators in lock step. Returns true while every sequence still has an element
    // and false once all of them are exhausted together. Throws as soon as only some of them are exhausted,
    // so a length difference of any size (including exactly one element) is detected.
    private static bool MoveNextAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      bool hasWeight = weightEnum.MoveNext();
      if (hasTarget != hasPred || hasPred != hasWeight)
        throw new ArgumentException("target, pred and weight have differing lengths");
      return hasTarget;
    }

    public override string ToString() {
      return "Relative error loss";
    }
  }
}

