1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  * and the BEACON Center for the Study of Evolution in Action.


5  *


6  * This file is part of HeuristicLab.


7  *


8  * HeuristicLab is free software: you can redistribute it and/or modify


9  * it under the terms of the GNU General Public License as published by


10  * the Free Software Foundation, either version 3 of the License, or


11  * (at your option) any later version.


12  *


13  * HeuristicLab is distributed in the hope that it will be useful,


14  * but WITHOUT ANY WARRANTY; without even the implied warranty of


15  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


16  * GNU General Public License for more details.


17  *


18  * You should have received a copy of the GNU General Public License


19  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


20  */


21  #endregion


22 


23  using System;


24  using System.Collections.Generic;


25  using System.Diagnostics;


26  using System.Linq;


27  using HeuristicLab.Common;


28 


namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Loss function for two-class problems whose labels are 0 or 1.
  /// A raw prediction f is mapped to the probability of the positive class via
  /// p = exp(2f) / (1 + exp(2f)), with f clamped to [-7, 7] before exponentiation.
  /// </summary>
  public class LogisticRegressionLoss : ILossFunction {
    // Raw predictions are clamped to [-PredictionLimit, PredictionLimit] so that
    // the probability never becomes exactly 0 or 1 (which would make Log / the
    // gradient denominator degenerate).
    private const double PredictionLimit = 7.0;

    // Lower bound for the line search denominator (prevents division by zero).
    private const double MinLineSearchDenominator = 1E-12;

    /// <summary>
    /// Calculates the weighted sum of -t * log(p) - (1 - t) * log(1 - p) over all elements,
    /// where t is the target label and p the probability derived from the prediction.
    /// </summary>
    /// <param name="target">Target labels; each must be (almost) 0 or 1.</param>
    /// <param name="pred">Raw predictions, one per target.</param>
    /// <param name="weight">Weights, one per target.</param>
    /// <returns>The accumulated weighted loss (0 for empty input).</returns>
    /// <exception cref="NotSupportedException">A label is neither 0 nor 1.</exception>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public double GetLoss(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        double s = 0;
        while (AdvanceAll(targetEnum, predEnum, weightEnum)) {
          double t = targetEnum.Current;
          EnsureBinaryLabel(t);
          double probPos = PositiveClassProbability(predEnum.Current);
          s += weightEnum.Current * (-t * Math.Log(probPos) - (1 - t) * Math.Log(1 - probPos));
        }
        return s;
      }
    }

    /// <summary>
    /// Lazily yields, for each element, weight * (t - p) / (p * p - p), where t is the
    /// target label and p the probability derived from the prediction.
    /// </summary>
    /// <remarks>
    /// This is an iterator: validation errors are raised while the result is enumerated,
    /// not when the method is called.
    /// </remarks>
    /// <param name="target">Target labels; each must be (almost) 0 or 1.</param>
    /// <param name="pred">Raw predictions, one per target.</param>
    /// <param name="weight">Weights, one per target.</param>
    /// <returns>One gradient value per input element.</returns>
    /// <exception cref="NotSupportedException">A label is neither 0 nor 1.</exception>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public IEnumerable<double> GetLossGradient(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      using (var targetEnum = target.GetEnumerator())
      using (var predEnum = pred.GetEnumerator())
      using (var weightEnum = weight.GetEnumerator()) {
        while (AdvanceAll(targetEnum, predEnum, weightEnum)) {
          double t = targetEnum.Current;
          EnsureBinaryLabel(t);
          double probPos = PositiveClassProbability(predEnum.Current);
          yield return weightEnum.Current * (t - probPos) / (probPos * probPos - probPos);
        }
      }
    }

    /// <summary>
    /// Creates a line search function over snapshots of the given sequences.
    /// For the rows idx[startIdx..endIdx] (both inclusive) it returns
    /// 0.5 * sum(w * y) / max(1E-12, sum(w * |y| * (1 - |y|))) with y = target - pred.
    /// </summary>
    /// <param name="target">Target values.</param>
    /// <param name="pred">Predicted values, one per target.</param>
    /// <param name="weight">Weights, one per target; expected to be 1 (asserted in debug builds only).</param>
    /// <returns>The line search delegate.</returns>
    /// <exception cref="ArgumentException">The three sequences have differing lengths.</exception>
    public LineSearchFunc GetLineSearchFunc(IEnumerable<double> target, IEnumerable<double> pred, IEnumerable<double> weight) {
      var targetArr = target.ToArray();
      var predArr = pred.ToArray();
      var weightArr = weight.ToArray();
      // weights are not supported yet (only verified in debug builds)
      Debug.Assert(weightArr.All(w => w.IsAlmost(1.0)));

      if (targetArr.Length != predArr.Length || predArr.Length != weightArr.Length)
        throw new ArgumentException("target, pred and weight have differing lengths");

      // step size estimate for the selected rows based on the residuals y = target - pred
      LineSearchFunc lineSearch = (idx, startIdx, endIdx) => {
        double sumY = 0.0;
        double sumDiff = 0.0;
        for (int i = startIdx; i <= endIdx; i++) {
          int row = idx[i];
          double yi = targetArr[row] - predArr[row];
          double wi = weightArr[row];

          sumY += wi * yi;
          sumDiff += wi * Math.Abs(yi) * (1 - Math.Abs(yi));
        }
        // prevent divByZero
        sumDiff = Math.Max(MinLineSearchDenominator, sumDiff);
        return 0.5 * sumY / sumDiff;
      };
      return lineSearch;
    }

    public override string ToString() {
      return "Logistic regression loss";
    }

    /// <summary>
    /// Advances all three enumerators by one step.
    /// Returns true if every enumerator has a current element and false if all of them
    /// are exhausted; throws if only some of them are exhausted.
    /// </summary>
    /// <remarks>
    /// All three are always advanced and their results compared within the same step.
    /// Re-checking with a second MoveNext() after the loop would miss sequences that
    /// are longer by exactly one element, because that element has already been consumed.
    /// </remarks>
    private static bool AdvanceAll(IEnumerator<double> targetEnum, IEnumerator<double> predEnum, IEnumerator<double> weightEnum) {
      bool hasTarget = targetEnum.MoveNext();
      bool hasPred = predEnum.MoveNext();
      bool hasWeight = weightEnum.MoveNext();

      if (hasTarget && hasPred && hasWeight) return true;
      if (hasTarget || hasPred || hasWeight)
        throw new ArgumentException("target, pred and weight have differing lengths");
      return false;
    }

    /// <summary>Throws if the label is neither (almost) 0 nor (almost) 1.</summary>
    private static void EnsureBinaryLabel(double label) {
      if (!label.IsAlmost(0.0) && !label.IsAlmost(1.0))
        throw new NotSupportedException("labels must be 0 or 1 for logistic regression loss");
    }

    /// <summary>
    /// Maps a raw prediction to exp(2f) / (1 + exp(2f)) after clamping f to
    /// [-PredictionLimit, PredictionLimit] (threshold for the exponent).
    /// </summary>
    private static double PositiveClassProbability(double pred) {
      double f = Math.Max(-PredictionLimit, Math.Min(PredictionLimit, pred));
      double e = Math.Exp(2 * f);
      return e / (1 + e);
    }
  }
}

