#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion


21 


22  using System;


23  using System.Collections.Generic;


24  using System.Linq;


25  using HeuristicLab.Common;


26  using HeuristicLab.Core;


27  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


28 


29  namespace HeuristicLab.Problems.DataAnalysis {


/// <summary>
/// Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).
/// </summary>
[StorableClass]
[Item("NormalDistributionCutPointsThresholdCalculator", "Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).")]
public class NormalDistributionCutPointsThresholdCalculator : ThresholdCalculator {

  /// <summary>Constructor marked with <c>StorableConstructor</c>; forwards the flag to the base class.</summary>
  [StorableConstructor]
  protected NormalDistributionCutPointsThresholdCalculator(bool deserializing) : base(deserializing) { }

  /// <summary>Cloning constructor; forwards <paramref name="original"/> and <paramref name="cloner"/> to the base class.</summary>
  protected NormalDistributionCutPointsThresholdCalculator(NormalDistributionCutPointsThresholdCalculator original, Cloner cloner)
    : base(original, cloner) {
  }

  /// <summary>Creates a new calculator instance.</summary>
  public NormalDistributionCutPointsThresholdCalculator()
    : base() {
  }

  /// <summary>Creates a copy of this instance through the cloning constructor.</summary>
  public override IDeepCloneable Clone(Cloner cloner) {
    return new NormalDistributionCutPointsThresholdCalculator(this, cloner);
  }

  /// <summary>
  /// Instance entry point; delegates unchanged to the static <see cref="CalculateThresholds"/>.
  /// </summary>
  public override void Calculate(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
    NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(problemData, estimatedValues, targetClassValues, out classValues, out thresholds);
  }

  /// <summary>
  /// Calculates class thresholds as the cut points of normal distributions that are fitted
  /// (mean and standard deviation) to the estimated values of each target class.
  /// </summary>
  /// <param name="problemData">Not used by this implementation.</param>
  /// <param name="estimatedValues">The estimated (continuous) values.</param>
  /// <param name="targetClassValues">The target class value of each estimated value (paired by position).</param>
  /// <param name="classValues">Receives the class value assigned to each partition; same length as <paramref name="thresholds"/>.</param>
  /// <param name="thresholds">Receives the lower bound of each partition in ascending order.</param>
  public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
    // materialize both inputs once because they are traversed several times below
    double[] estimated = estimatedValues.ToArray();
    double[] targets = targetClassValues.ToArray();

    var estimatedTargetValues = estimated.Zip(targets, (e, t) => new { EstimatedValue = e, TargetValue = t }).ToList();
    double estimatedValuesRange = estimated.Range();

    Dictionary<double, double> classMean = new Dictionary<double, double>();
    Dictionary<double, double> classStdDev = new Dictionary<double, double>();
    // calculate moments per class
    foreach (var group in estimatedTargetValues.GroupBy(p => p.TargetValue)) {
      IEnumerable<double> estimatedClassValues = group.Select(x => x.EstimatedValue);
      double classValue = group.Key;
      double mean, variance;
      OnlineCalculatorError meanErrorState, varianceErrorState;
      OnlineMeanAndVarianceCalculator.Calculate(estimatedClassValues, out mean, out variance, out meanErrorState, out varianceErrorState);

      // classes for which the moments could not be calculated are left out
      if (meanErrorState == OnlineCalculatorError.None && varianceErrorState == OnlineCalculatorError.None) {
        classMean[classValue] = mean;
        classStdDev[classValue] = Math.Sqrt(variance);
      }
    }

    double[] originalClasses = classMean.Keys.OrderBy(x => x).ToArray();
    int nClasses = originalClasses.Length;
    List<double> thresholdList = new List<double>();
    // calculate the cut points for every pair of classes
    for (int i = 0; i < nClasses - 1; i++) {
      for (int j = i + 1; j < nClasses; j++) {
        double x1, x2;
        double class0 = originalClasses[i];
        double class1 = originalClasses[j];
        CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);

        // if the two cut points are too close (for instance because the stdDev=0)
        // then move them apart by 0.1% of the range of estimated values
        if (x1.IsAlmost(x2)) {
          x1 -= 0.001 * estimatedValuesRange;
          x2 += 0.001 * estimatedValuesRange;
        }
        // keep only finite cut points that are not (almost) equal to an already known threshold
        if (!double.IsInfinity(x1) && !thresholdList.Any(x => x.IsAlmost(x1))) thresholdList.Add(x1);
        if (!double.IsInfinity(x2) && !thresholdList.Any(x => x.IsAlmost(x2))) thresholdList.Add(x2);
      }
    }
    thresholdList.Sort();

    // add small value and large value for the calculation of most influential class in each thresholded section
    thresholdList.Insert(0, double.NegativeInfinity);
    thresholdList.Add(double.PositiveInfinity);

    // determine class values for each partition separated by a threshold by calculating the density of all class distributions
    // all points in the partition are classified as the class with the maximal density in the partition
    if (thresholdList.Count == 2) {
      // this happens if there are no thresholds (distributions for all classes are exactly the same)
      // -> all samples should be classified as the class with the most observations
      // group observations by target class and select the class with largest count
      double mostFrequentClass = targets.GroupBy(t => t)
        .OrderBy(g => g.Count())
        .Last().Key;
      thresholds = new double[] { double.NegativeInfinity };
      classValues = new double[] { mostFrequentClass };
    } else {
      // at least one reasonable threshold ...
      // find the most likely class for the points between two neighbouring thresholds
      List<double> filteredThresholds = new List<double>();
      List<double> filteredClassValues = new List<double>();
      for (int i = 0; i < thresholdList.Count - 1; i++) {
        // determine class with maximal (log) density mass between the thresholds
        double maxDensity = DensityMass(thresholdList[i], thresholdList[i + 1], classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
        double maxDensityClassValue = originalClasses[0];
        foreach (var classValue in originalClasses.Skip(1)) {
          double density = DensityMass(thresholdList[i], thresholdList[i + 1], classMean[classValue], classStdDev[classValue]);
          if (density > maxDensity) {
            maxDensity = density;
            maxDensityClassValue = classValue;
          }
        }
        // emit a threshold only if some class has mass in the partition and
        // the winning class differs from the class of the previously emitted partition
        if (maxDensity > double.NegativeInfinity &&
            (filteredClassValues.Count == 0 || !maxDensityClassValue.IsAlmost(filteredClassValues.Last()))) {
          filteredThresholds.Add(thresholdList[i]);
          filteredClassValues.Add(maxDensityClassValue);
        }
      }
      thresholds = filteredThresholds.ToArray();
      classValues = filteredClassValues.ToArray();
    }
  }

  private static readonly double sqr2 = Math.Sqrt(2.0);

  // returns the cumulative distribution function of the standard normal distribution at x
  private static double NormalCDF(double x) {
    return 0.5 * (1 + alglib.errorfunction(x / sqr2));
  }

  // approximation of the log of the normal cumulative distribution from the lightspeed toolbox by Tom Minka
  // http://research.microsoft.com/en-us/um/people/minka/software/lightspeed/
  // coefficients of the asymptotic series in z = x^-2 (alternating signs)
  private static readonly double[] logNormalCdfCoefficients = new double[] { -1, 5 / 2.0, -37 / 3.0, 353 / 4.0, -4081 / 5.0, 55205 / 6.0, -854197 / 7.0 };

  // returns log(NormalCDF(x)); uses the asymptotic series for x < -6.5
  private static double LogNormalCDF(double x) {
    if (x >= -6.5) {
      // calculate the log directly if x is large enough
      return Math.Log(NormalCDF(x));
    } else {
      double[] c = logNormalCdfCoefficients;
      double z = Math.Pow(x, -2);
      // asymptotic series for logcdf
      double y = z * (c[0] + z * (c[1] + z * (c[2] + z * (c[3] + z * (c[4] + z * (c[5] + z * c[6]))))));
      // x is negative in this branch, so log(-x) is well defined
      return y - 0.5 * Math.Log(2 * Math.PI) - 0.5 * x * x - Math.Log(-x);
    }
  }

  // determines the logarithm of NormalCDF(mu, sigma, upper) - NormalCDF(mu, sigma, lower)
  // = the log of the integral of the PDF of N(mu, sigma) in the range [lower, upper]
  private static double DensityMass(double lower, double upper, double mu, double sigma) {
    if (sigma.IsAlmost(0.0)) {
      if (lower < mu && mu < upper) return 0.0; // log(1): all mass is between lower and upper
      else return double.NegativeInfinity; // log(0): no mass is between lower and upper
    }

    if (lower > mu) {
      // mirror the interval and the mean about zero so that the interval starts at or below the mean;
      // the mass is unchanged because the normal distribution is symmetric
      return DensityMass(-upper, -lower, -mu, sigma);
    }

    // standardize the bounds
    upper = (upper - mu) / sigma;
    lower = (lower - mu) / sigma;
    if (double.IsNegativeInfinity(lower)) return LogNormalCDF(upper);

    // log(cdf(upper) - cdf(lower)) = logcdf(upper) + log(1 - exp(logcdf(lower) - logcdf(upper)))
    return LogNormalCDF(upper) + Math.Log(1 - Math.Exp(LogNormalCDF(lower) - LogNormalCDF(upper)));
  }

  // Calculates the points x1 and x2 where the densities of N(m1, s1) and N(m2, s2) are equal.
  // In the general case there are two cut points. If exactly one of s1, s2 is 0 then x1 == x2 (the mean of that class).
  // If s1 and s2 are (almost) equal there is at most one cut point: (m1 + m2) / 2 is returned in x1 when the means differ;
  // cut points that do not exist are returned as negative infinity.
  private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
    if (s1.IsAlmost(s2)) {
      if (m1.IsAlmost(m2)) {
        // identical distributions: no cut point
        x1 = double.NegativeInfinity;
        x2 = double.NegativeInfinity;
      } else {
        // s1==s2 and m1 != m2
        // return something reasonable. cut point should be half way between m1 and m2
        x1 = (m1 + m2) / 2;
        x2 = double.NegativeInfinity;
      }
    } else if (s1.IsAlmost(0.0)) {
      // when s1 is 0.0 the cut points are exactly at m1 ...
      x1 = m1;
      x2 = m1;
    } else if (s2.IsAlmost(0.0)) {
      // ... same for s2
      x1 = m2;
      x2 = m2;
    } else {
      if (s2 < s1) {
        // make sure s2 is the larger std.dev.
        CalculateCutPoints(m2, s2, m1, s1, out x1, out x2);
      } else {
        // general case
        // calculate the solutions x1, x2 where N(m1,s1) == N(m2,s2)
        // (the two roots of the quadratic obtained by equating the log densities)
        double g = Math.Sqrt(2 * s2 * s2 * Math.Log(s2 / s1) - 2 * s1 * s1 * Math.Log(s2 / s1) - 2 * m1 * m2 + m1 * m1 + m2 * m2);
        double s = (s1 * s1 - s2 * s2);
        x1 = (m2 * s1 * s1 - m1 * s2 * s2 + s1 * s2 * g) / s;
        x2 = -(m1 * s2 * s2 - m2 * s1 * s1 + s1 * s2 * g) / s;
      }
    }
  }
}


216  }

