1  #region License Information


/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */


20  #endregion


21 


22  using System;


23  using System.Collections.Generic;


24  using System.Linq;


25  using HeuristicLab.Common;


26  using HeuristicLab.Core;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28 


namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).
  /// </summary>
  [StorableClass]
  [Item("NormalDistributionCutPointsThresholdCalculator", "Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).")]
  public class NormalDistributionCutPointsThresholdCalculator : ThresholdCalculator {

    [StorableConstructor]
    protected NormalDistributionCutPointsThresholdCalculator(bool deserializing) : base(deserializing) { }
    protected NormalDistributionCutPointsThresholdCalculator(NormalDistributionCutPointsThresholdCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public NormalDistributionCutPointsThresholdCalculator()
      : base() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NormalDistributionCutPointsThresholdCalculator(this, cloner);
    }

    /// <summary>
    /// Calculates thresholds and the corresponding class values; see <see cref="CalculateThresholds"/>.
    /// </summary>
    public override void Calculate(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
      NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(problemData, estimatedValues, targetClassValues, out classValues, out thresholds);
    }

    /// <summary>
    /// Fits a normal distribution to the estimated values of every target class and uses the points at which
    /// the class densities intersect as classification thresholds.
    /// </summary>
    /// <param name="problemData">Not used by this calculation.</param>
    /// <param name="estimatedValues">Estimated (model output) values. The sequence is enumerated exactly once.</param>
    /// <param name="targetClassValues">Target class value of each estimated value, paired by position.</param>
    /// <param name="classValues">
    /// <c>classValues[k]</c> is the class chosen for the partition that starts at <c>thresholds[k]</c>.
    /// </param>
    /// <param name="thresholds">
    /// Ascending partition start points. The first entry is always <see cref="double.NegativeInfinity"/>;
    /// further entries are only kept where the chosen class changes.
    /// </param>
    /// <remarks>
    /// Classes for which the mean/variance calculation reports an error are ignored. At least one class must
    /// remain and <paramref name="estimatedValues"/> must not be empty; otherwise an exception is thrown.
    /// </remarks>
    public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
      // materialize once: the values are needed for pairing, range, minimum and maximum
      double[] estimatedValuesArray = estimatedValues.ToArray();
      var estimatedTargetValues = Enumerable.Zip(estimatedValuesArray, targetClassValues, (e, t) => new { EstimatedValue = e, TargetValue = t }).ToList();
      double estimatedValuesRange = estimatedValuesArray.Range();

      Dictionary<double, double> classMean = new Dictionary<double, double>();
      Dictionary<double, double> classStdDev = new Dictionary<double, double>();
      // calculate moments per class
      foreach (var group in estimatedTargetValues.GroupBy(p => p.TargetValue)) {
        IEnumerable<double> estimatedClassValues = group.Select(x => x.EstimatedValue);
        double classValue = group.Key;
        double mean, variance;
        OnlineCalculatorError meanErrorState, varianceErrorState;
        OnlineMeanAndVarianceCalculator.Calculate(estimatedClassValues, out mean, out variance, out meanErrorState, out varianceErrorState);

        // classes whose moments could not be calculated are left out entirely
        if (meanErrorState == OnlineCalculatorError.None && varianceErrorState == OnlineCalculatorError.None) {
          classMean[classValue] = mean;
          classStdDev[classValue] = Math.Sqrt(variance);
        }
      }
      double[] originalClasses = classMean.Keys.OrderBy(x => x).ToArray();
      int nClasses = originalClasses.Length;
      List<double> thresholdList = new List<double>();
      for (int i = 0; i < nClasses - 1; i++) {
        for (int j = i + 1; j < nClasses; j++) {
          double x1, x2;
          double class0 = originalClasses[i];
          double class1 = originalClasses[j];
          // calculate all thresholds
          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);

          // if the two cut points are too close (for instance because the stdDev=0)
          // then move them apart by 0.1% of the range of estimated values in each direction
          if (x1.IsAlmost(x2)) {
            x1 -= 0.001 * estimatedValuesRange;
            x2 += 0.001 * estimatedValuesRange;
          }
          // infinite values mean "no cut point"; (almost) duplicate cut points are added only once
          if (!double.IsInfinity(x1) && !thresholdList.Any(x => x.IsAlmost(x1))) thresholdList.Add(x1);
          if (!double.IsInfinity(x2) && !thresholdList.Any(x => x.IsAlmost(x2))) thresholdList.Add(x2);
        }
      }
      thresholdList.Sort();

      // add a small and a large value for the calculation of the most influential class in each thresholded section.
      // Cut points may lie far outside the range of the estimated values, so the bounds are widened when
      // necessary; otherwise the list would no longer be sorted and some partitions would be reversed.
      double lowerBound = estimatedValuesArray.Min() - 1;
      double upperBound = estimatedValuesArray.Max() + 1;
      if (thresholdList.Count > 0) {
        double smallestCutPoint = thresholdList[0];
        double largestCutPoint = thresholdList[thresholdList.Count - 1];
        if (smallestCutPoint <= lowerBound) lowerBound = smallestCutPoint - 1;
        if (largestCutPoint >= upperBound) upperBound = largestCutPoint + 1;
      }
      thresholdList.Insert(0, lowerBound);
      thresholdList.Add(upperBound);

      // determine class values for each partition separated by a threshold by calculating the density of all class distributions
      // all points in the partition are classified as the class with the maximal density in the partition
      List<double> classValuesList = new List<double>();
      if (thresholdList.Count == 2) {
        // this happens if there are no thresholds (distributions for all classes are exactly the same)
        // -> all samples should be classified as the first class
        classValuesList.Add(originalClasses[0]);
      } else {
        // at least one reasonable threshold ...
        // find the most likely class for the points between each pair of adjacent thresholds
        for (int i = 0; i < thresholdList.Count - 1; i++) {

          // determine class with maximal density mass between the thresholds
          double maxDensity = LogNormalDensityMass(thresholdList[i], thresholdList[i + 1], classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
          double maxDensityClassValue = originalClasses[0];
          foreach (var classValue in originalClasses.Skip(1)) {
            double density = LogNormalDensityMass(thresholdList[i], thresholdList[i + 1], classMean[classValue], classStdDev[classValue]);
            if (density > maxDensity) {
              maxDensity = density;
              maxDensityClassValue = classValue;
            }
          }
          classValuesList.Add(maxDensityClassValue);
        }
      }

      // only keep thresholds at which the class changes
      // class B overrides threshold s. So only thresholds r and t are relevant and have to be kept
      //
      //      A   B  C
      //       /\  /\/\
      //      / r\/ /\t\
      //     /   /\/  \ \
      //    /   / /\s  \ \
      //  -/---/-/--\---\-\----

      List<double> filteredThresholds = new List<double>();
      List<double> filteredClassValues = new List<double>();
      filteredThresholds.Add(double.NegativeInfinity); // the smallest possible threshold for the first class
      filteredClassValues.Add(classValuesList[0]);
      // do not include the last threshold which was just needed for the previous step
      for (int i = 0; i < classValuesList.Count - 1; i++) {
        if (!classValuesList[i].IsAlmost(classValuesList[i + 1])) {
          filteredThresholds.Add(thresholdList[i + 1]);
          filteredClassValues.Add(classValuesList[i + 1]);
        }
      }
      thresholds = filteredThresholds.ToArray();
      classValues = filteredClassValues.ToArray();
    }

    /// <summary>
    /// Scores how strongly the normal distribution N(<paramref name="mu"/>, <paramref name="sigma"/>²) is
    /// represented on the interval from <paramref name="lower"/> to <paramref name="upper"/>.
    /// </summary>
    /// <remarks>
    /// For a degenerate distribution (sigma almost 0) the result is the log of the probability mass inside the
    /// open interval: 0 (= log 1) if mu lies strictly inside, negative infinity (= log 0) otherwise.
    /// Otherwise the result is the integral of the log-density over the interval. In this class the value is
    /// only compared between classes on the same interval.
    /// </remarks>
    private static double LogNormalDensityMass(double lower, double upper, double mu, double sigma) {
      if (sigma.IsAlmost(0.0)) {
        if (lower < mu && mu < upper) return 0.0; // log(1)
        else return double.NegativeInfinity; // log(0)
      }

      // antiderivative of log N(x; mu, sigma): -x/2 * log(2 pi sigma^2) - (x - mu)^3 / (6 sigma^2)
      Func<double, double> f = (x) =>
        x * -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - Math.Pow(x - mu, 3) / (3 * 2.0 * sigma * sigma);

      if (double.IsNegativeInfinity(lower)) return f(upper);
      else return f(upper) - f(lower);
    }

    /// <summary>
    /// Calculates the points at which the densities of N(m1, s1²) and N(m2, s2²) are equal.
    /// </summary>
    /// <remarks>
    /// <see cref="double.NegativeInfinity"/> in an output means "no cut point". Equal distributions yield no cut
    /// point, equal standard deviations yield the single midpoint (m1 + m2) / 2, and a (nearly) zero standard
    /// deviation on one side yields that side's mean twice. Otherwise equating the log-densities gives the quadratic
    /// (s1² - s2²) x² - 2 (m2 s1² - m1 s2²) x + m2² s1² - m1² s2² + 2 s1² s2² ln(s2 / s1) = 0,
    /// whose two roots are returned.
    /// </remarks>
    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
      if (s1.IsAlmost(s2)) {
        if (m1.IsAlmost(m2)) {
          x1 = double.NegativeInfinity;
          x2 = double.NegativeInfinity;
        } else {
          x1 = (m1 + m2) / 2;
          x2 = double.NegativeInfinity;
        }
      } else if (s1.IsAlmost(0.0)) {
        x1 = m1;
        x2 = m1;
      } else if (s2.IsAlmost(0.0)) {
        x1 = m2;
        x2 = m2;
      } else {
        if (s2 < s1) {
          // make sure s2 is the larger std.dev.
          CalculateCutPoints(m2, s2, m1, s1, out x1, out x2);
        } else {
          // here 0 < s1 < s2, hence (s2² - s1²) * ln(s2 / s1) > 0 and the discriminant is never negative
          double a = (s1 + s2) * (s1 - s2); // s1² - s2²
          double g = Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (s2 * s2 - s1 * s1) * Math.Log(s2 / s1)));
          double m1s2 = m1 * s2 * s2;
          double m2s1 = m2 * s1 * s1;
          x1 = (m2s1 - m1s2 - g) / a;
          x2 = (m2s1 - m1s2 + g) / a;
        }
      }
    }
  }
}

