1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  * and the BEACON Center for the Study of Evolution in Action.


5  *


6  * This file is part of HeuristicLab.


7  *


8  * HeuristicLab is free software: you can redistribute it and/or modify


9  * it under the terms of the GNU General Public License as published by


10  * the Free Software Foundation, either version 3 of the License, or


11  * (at your option) any later version.


12  *


13  * HeuristicLab is distributed in the hope that it will be useful,


14  * but WITHOUT ANY WARRANTY; without even the implied warranty of


15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


16  * GNU General Public License for more details.


17  *


18  * You should have received a copy of the GNU General Public License


19  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


20  */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using System.Linq;


26  using HeuristicLab.Common;


27  using HeuristicLab.Core;


28  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


29 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Squared error loss L(y, f) = sum_i (y_i - f_i)^2.
  /// </summary>
  [StorableClass]
  [Item("Squared error loss", "")]
  public sealed class SquaredErrorLoss : Item, ILossFunction {
    public SquaredErrorLoss() { }

    /// <summary>
    /// Computes the sum of squared residuals (target - pred)^2 over position-wise pairs of elements.
    /// </summary>
    /// <param name="target">Observed target values.</param>
    /// <param name="pred">Predicted values, paired with <paramref name="target"/> by position.</param>
    /// <returns>The sum of squared residuals; 0 when both sequences are empty.</returns>
    /// <exception cref="ArgumentException">The sequences have different lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred) {
      double s = 0;
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator()) {
        while (MoveNextPaired(targetEnum, predEnum)) {
          double res = targetEnum.Current - predEnum.Current;
          s += res * res; // (res)^2
        }
      }
      return s;
    }

    /// <summary>
    /// Yields 2 * (target - pred) for each position-wise pair of elements.
    /// For L = (y - f)^2 this is -dL/df, i.e. the NEGATIVE gradient of the loss with respect to the prediction.
    /// </summary>
    /// <param name="target">Observed target values.</param>
    /// <param name="pred">Predicted values, paired with <paramref name="target"/> by position.</param>
    /// <returns>A lazily evaluated sequence with one value per pair.</returns>
    /// <exception cref="ArgumentException">
    /// The sequences have different lengths. Because this is an iterator, the exception is raised during
    /// enumeration, after all paired values have been yielded.
    /// </exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator()) {
        while (MoveNextPaired(targetEnum, predEnum)) {
          yield return 2.0 * (targetEnum.Current - predEnum.Current); // -dL(y, f(x)) / df(x) = 2 * res
        }
      }
    }

    /// <summary>
    /// Line search for squared error loss. For the rows idx[startIdx..endIdx] (both ends inclusive), the constant c
    /// minimizing sum (res_i - c)^2 is the mean residual, so that mean is the optimal value to add to the current
    /// predictions. <paramref name="targetArr"/> and <paramref name="predArr"/> are not modified.
    /// </summary>
    /// <param name="targetArr">Observed target values, indexed by row.</param>
    /// <param name="predArr">Current predictions, indexed by row.</param>
    /// <param name="idx">Row indices; only the entries idx[startIdx..endIdx] are used.</param>
    /// <param name="startIdx">First position in <paramref name="idx"/> (inclusive).</param>
    /// <param name="endIdx">Last position in <paramref name="idx"/> (inclusive).</param>
    /// <returns>The mean residual of the selected rows; NaN when startIdx &gt; endIdx (no rows, 0.0 / 0).</returns>
    /// <exception cref="ArgumentException">The target and prediction arrays have different lengths.</exception>
    public double LineSearch(double[] targetArr, double[] predArr, int[] idx, int startIdx, int endIdx) {
      if (targetArr.Length != predArr.Length) {
        throw new ArgumentException("target and pred have different lengths");
      }

      double s = 0.0;
      int n = 0;
      for (int i = startIdx; i <= endIdx; i++) {
        int row = idx[i];
        s += targetArr[row] - predArr[row];
        n++;
      }
      return s / n;
    }

    // Advances both enumerators in lock-step. Returns true while both have a current element and false when both
    // are exhausted at the same time. Throws as soon as exactly one of them runs out. Comparing the two MoveNext
    // results directly also catches a length difference of exactly one element, which a separate post-loop
    // MoveNext check misses because the loop has already consumed the extra element.
    private static bool MoveNextPaired(IEnumerator<double> targetEnum, IEnumerator<double> predEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      if (hasTarget != hasPred) {
        throw new ArgumentException("target and pred have different lengths");
      }
      return hasTarget;
    }

    #region item implementation
    [StorableConstructor]
    private SquaredErrorLoss(bool deserializing) : base(deserializing) { }

    private SquaredErrorLoss(SquaredErrorLoss original, Cloner cloner) : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SquaredErrorLoss(this, cloner);
    }
    #endregion
  }
}

