1  #region License Information


/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 * and the BEACON Center for the Study of Evolution in Action.
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using HeuristicLab.Common;


26  using HeuristicLab.Core;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28 


namespace HeuristicLab.Algorithms.DataAnalysis {
  // Relative error loss: L(y, f) = sum_i |y_i - f_i| * |1 / y_i|.
  // It is a special case of weighted absolute error loss with weights w_i = |1 / y_i|.
  // NOTE: rows whose target is 0 receive an infinite weight (1.0 / 0.0).
  [StorableClass]
  [Item("Relative error loss", "")]
  public sealed class RelativeErrorLoss : Item, ILossFunction {
    public RelativeErrorLoss() { }

    // Returns the loss sum_i |target_i - pred_i| * |1 / target_i|.
    // Throws ArgumentException when target and pred have different lengths.
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator()) {
        double s = 0;
        // non-short-circuit '&' so both enumerators are always advanced in lock step
        while (targetEnum.MoveNext() & predEnum.MoveNext()) {
          double res = targetEnum.Current - predEnum.Current;
          s += Math.Abs(res) * Math.Abs(1.0 / targetEnum.Current);
        }
        // if either sequence still has elements left, the lengths differ
        if (targetEnum.MoveNext() | predEnum.MoveNext())
          throw new ArgumentException("target and pred have different lengths");

        return s;
      }
    }

    // Lazily yields, per row, sign(res) * |1 / target| with res = target - pred (0 when res == 0),
    // i.e. the negative derivative of the loss with respect to pred.
    // Because this is an iterator, the length check (ArgumentException) only fires once the
    // returned sequence has been enumerated to the end.
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator()) {
        while (targetEnum.MoveNext() & predEnum.MoveNext()) {
          var res = targetEnum.Current - predEnum.Current;
          if (res > 0) yield return 1.0 / Math.Abs(targetEnum.Current);
          else if (res < 0) yield return -1.0 / Math.Abs(targetEnum.Current);
          else yield return 0.0;
        }
        if (targetEnum.MoveNext() | predEnum.MoveNext())
          throw new ArgumentException("target and pred have different lengths");
      }
    }

    // Line search for relative error over the rows idx[startIdx..endIdx] (both inclusive).
    // Returns the weighted median of the residuals (target - pred) with weights |1 / target|.
    // targetArr, predArr and idx are not changed by LineSearch.
    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
      if (targetArr.Length != predArr.Length)
        throw new ArgumentException("target and pred have different lengths");

      int nRows = endIdx - startIdx + 1; // startIdx and endIdx are inclusive
      if (nRows < 1)
        throw new ArgumentException("endIdx must not be smaller than startIdx");

      if (nRows == 1) {
        int row = idx[startIdx];
        return targetArr[row] - predArr[row]; // res
      } else if (nRows == 2) {
        // weighted median of two residuals: the residual with the larger weight wins;
        // for equal weights the average of both residuals is returned
        int row0 = idx[startIdx];
        int row1 = idx[endIdx];
        var res0 = targetArr[row0] - predArr[row0];
        var res1 = targetArr[row1] - predArr[row1];
        var w0 = Math.Abs(1.0 / targetArr[row0]);
        var w1 = Math.Abs(1.0 / targetArr[row1]);
        if (w0 > w1) return res0;
        if (w0 < w1) return res1;
        return (res0 + res1) / 2;
      } else {
        // create an array of (residual, weight) pairs to be sorted (instead of using Array.Sort(res, weights))
        var res_w = new KeyValuePair<double, double>[nRows];
        var totalWeight = 0.0;
        for (int i = startIdx; i <= endIdx; i++) {
          int row = idx[i];
          var res = targetArr[row] - predArr[row];
          var w = Math.Abs(1.0 / targetArr[row]);
          res_w[i - startIdx] = new KeyValuePair<double, double>(res, w);
          totalWeight += w;
        }
        // TODO: improve efficiency (find median without sort)
        // CompareTo instead of Math.Sign(a.Key - b.Key): Math.Sign throws for NaN differences (e.g. inf - inf)
        res_w.StableSort((a, b) => a.Key.CompareTo(b.Key));

        // walk upwards from the smallest residual until the weight of all larger residuals
        // no longer exceeds half of the total weight; the bound on k guards against
        // floating-point round-off running past the last element
        int k = 0;
        double sum = totalWeight - res_w[k].Value; // total - first weight
        while (sum > totalWeight / 2 && k < nRows - 1) {
          k++;
          sum -= res_w[k].Value;
        }
        return res_w[k].Key;
      }
    }

    #region item implementation
    [StorableConstructor]
    private RelativeErrorLoss(bool deserializing) : base(deserializing) { }

    private RelativeErrorLoss(RelativeErrorLoss original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new RelativeErrorLoss(this, cloner);
    }
    #endregion
  }
}

