Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs @ 17541

Last change on this file since 17541 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 10.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HEAL.Attic;
28
namespace HeuristicLab.Problems.DataAnalysis {
  /// <summary>
  /// Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).
  /// </summary>
  [StorableType("D01CB5DC-606B-4CE9-B293-2D4D80A70BB8")]
  [Item("NormalDistributionCutPointsThresholdCalculator", "Represents a threshold calculator that calculates thresholds as the cutting points between the estimated class distributions (assuming normally distributed class values).")]
  public class NormalDistributionCutPointsThresholdCalculator : ThresholdCalculator {

    [StorableConstructor]
    protected NormalDistributionCutPointsThresholdCalculator(StorableConstructorFlag _) : base(_) { }
    protected NormalDistributionCutPointsThresholdCalculator(NormalDistributionCutPointsThresholdCalculator original, Cloner cloner)
      : base(original, cloner) {
    }
    public NormalDistributionCutPointsThresholdCalculator()
      : base() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NormalDistributionCutPointsThresholdCalculator(this, cloner);
    }

    /// <summary>
    /// Instance entry point; forwards to the static <see cref="CalculateThresholds"/>.
    /// </summary>
    public override void Calculate(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
      CalculateThresholds(problemData, estimatedValues, targetClassValues, out classValues, out thresholds);
    }

    /// <summary>
    /// Calculates class thresholds from the per-class mean and standard deviation of the estimated values.
    /// For every pair of classes the cut points of the two fitted normal distributions are candidate thresholds;
    /// each interval between neighboring candidates is assigned the class with the largest probability mass in it,
    /// and neighboring intervals with the same class are merged.
    /// </summary>
    /// <param name="problemData">Not used by this calculator.</param>
    /// <param name="estimatedValues">Estimated (model output) values; paired position-wise with <paramref name="targetClassValues"/>.</param>
    /// <param name="targetClassValues">Target class value of each sample.</param>
    /// <param name="classValues">Class assigned to each interval; classValues[i] belongs to the interval starting at thresholds[i].</param>
    /// <param name="thresholds">Ascending lower bounds of the intervals; the first element is always negative infinity.</param>
    public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
      // materialize once, both sequences are needed more than once below
      double[] estimated = estimatedValues.ToArray();
      double[] targets = targetClassValues.ToArray();

      var estimatedTargetValues = estimated.Zip(targets, (e, t) => new { EstimatedValue = e, TargetValue = t });
      double estimatedValuesRange = estimated.Range();

      Dictionary<double, double> classMean = new Dictionary<double, double>();
      Dictionary<double, double> classStdDev = new Dictionary<double, double>();
      // calculate moments per class; classes for which mean or variance cannot be calculated are ignored
      foreach (var group in estimatedTargetValues.GroupBy(p => p.TargetValue)) {
        IEnumerable<double> estimatedClassValues = group.Select(x => x.EstimatedValue);
        double classValue = group.Key;
        double mean, variance;
        OnlineCalculatorError meanErrorState, varianceErrorState;
        OnlineMeanAndVarianceCalculator.Calculate(estimatedClassValues, out mean, out variance, out meanErrorState, out varianceErrorState);

        if (meanErrorState == OnlineCalculatorError.None && varianceErrorState == OnlineCalculatorError.None) {
          classMean[classValue] = mean;
          classStdDev[classValue] = Math.Sqrt(variance);
        }
      }

      double[] originalClasses = classMean.Keys.OrderBy(x => x).ToArray();
      int nClasses = originalClasses.Length;
      List<double> thresholdList = new List<double>();
      // collect the cut points of all pairs of class distributions
      for (int i = 0; i < nClasses - 1; i++) {
        for (int j = i + 1; j < nClasses; j++) {
          double x1, x2;
          double class0 = originalClasses[i];
          double class1 = originalClasses[j];
          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);

          // if the two cut points are too close (for instance because the stdDev=0)
          // then move them by 0.1% of the range of estimated values
          if (x1.IsAlmost(x2)) {
            x1 -= 0.001 * estimatedValuesRange;
            x2 += 0.001 * estimatedValuesRange;
          }
          AddThreshold(thresholdList, x1);
          AddThreshold(thresholdList, x2);
        }
      }
      thresholdList.Sort();

      // add small value and large value for the calculation of most influential class in each thresholded section
      thresholdList.Insert(0, double.NegativeInfinity);
      thresholdList.Add(double.PositiveInfinity);

      // find the most likely class for the points between neighboring thresholds
      List<double> filteredThresholds = new List<double>();
      List<double> filteredClassValues = new List<double>();
      // without any class statistics there is nothing to compare; the fallback below handles this case
      if (nClasses > 0) {
        for (int i = 0; i < thresholdList.Count - 1; i++) {
          double lower = thresholdList[i];
          double upper = thresholdList[i + 1];
          // determine class with maximal (log) probability mass between the thresholds
          double maxDensity = DensityMass(lower, upper, classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
          double maxDensityClassValue = originalClasses[0];
          foreach (var classValue in originalClasses.Skip(1)) {
            double density = DensityMass(lower, upper, classMean[classValue], classStdDev[classValue]);
            if (density > maxDensity) {
              maxDensity = density;
              maxDensityClassValue = classValue;
            }
          }
          // skip intervals where no class has any mass and merge neighboring intervals of the same class
          if (maxDensity > double.NegativeInfinity &&
            (filteredClassValues.Count == 0 || !maxDensityClassValue.IsAlmost(filteredClassValues.Last()))) {
            filteredThresholds.Add(lower);
            filteredClassValues.Add(maxDensityClassValue);
          }
        }
      }

      if (filteredThresholds.Count == 0 || !double.IsNegativeInfinity(filteredThresholds.First())) {
        // this happens if there are no thresholds (distributions for all classes are exactly the same,
        // or no class statistics could be calculated) or when the CDF up to the first threshold is zero
        // -> all samples (below the first threshold) should be classified as the class with the most observations
        // group observations by target class and select the class with largest count
        double mostFrequentClass = targets.GroupBy(c => c)
                              .OrderBy(g => g.Count())
                              .Last().Key;
        filteredThresholds.Insert(0, double.NegativeInfinity);
        filteredClassValues.Insert(0, mostFrequentClass);
      }

      thresholds = filteredThresholds.ToArray();
      classValues = filteredClassValues.ToArray();
    }

    // adds the cut point to the list unless it is infinite, NaN, or almost equal to a threshold that is already in the list
    private static void AddThreshold(List<double> thresholdList, double cutPoint) {
      if (double.IsInfinity(cutPoint) || double.IsNaN(cutPoint)) return;
      if (thresholdList.Any(x => x.IsAlmost(cutPoint))) return;
      thresholdList.Add(cutPoint);
    }

    private static readonly double Sqrt2 = Math.Sqrt(2.0);
    // returns the cumulative distribution function of the standard normal distribution at x
    private static double NormalCDF(double x) {
      return 0.5 * (1 + alglib.errorfunction(x / Sqrt2));
    }

    // approximation of the log of the normal cumulative distribution from the lightspeed toolbox by Tom Minka
    // http://research.microsoft.com/en-us/um/people/minka/software/lightspeed/
    private static readonly double[] LogCdfSeriesCoefficients = new double[] { -1, 5 / 2.0, -37 / 3.0, 353 / 4.0, -4081 / 5.0, 55205 / 6.0, -854197 / 7.0 };
    private static double LogNormalCDF(double x) {
      if (x >= -6.5) {
        // calculate the log directly if x is large enough
        return Math.Log(NormalCDF(x));
      }

      double z = Math.Pow(x, -2);
      // asymptotic series for logcdf, evaluated with Horner's scheme starting at the highest-order coefficient
      double y = 0.0;
      for (int k = LogCdfSeriesCoefficients.Length - 1; k >= 0; k--) {
        y = z * (LogCdfSeriesCoefficients[k] + y);
      }
      return y - 0.5 * Math.Log(2 * Math.PI) - 0.5 * x * x - Math.Log(-x);
    }

    // determines the natural logarithm of NormalCDF(mu, sigma, upper) - NormalCDF(mu, sigma, lower)
    // = the log of the integral of the PDF of N(mu, sigma) in the range [lower, upper]
    // a result of 0.0 means that all mass is in the range, negative infinity means that no mass is in the range
    private static double DensityMass(double lower, double upper, double mu, double sigma) {
      if (sigma.IsAlmost(0.0)) {
        // degenerate distribution: the whole mass is concentrated at mu
        if (lower < mu && mu < upper) return 0.0; // all mass is between lower and upper (log(1))
        else return double.NegativeInfinity; // no mass is between lower and upper (log(0))
      }

      if (lower > mu) {
        // the range lies completely above the mean: mirror it at the origin so that
        // the CDF is only evaluated below the (mirrored) mean
        return DensityMass(-upper, -lower, -mu, sigma);
      }

      // standardize the bounds
      upper = (upper - mu) / sigma;
      lower = (lower - mu) / sigma;
      if (double.IsNegativeInfinity(lower)) return LogNormalCDF(upper);

      // log(cdf(upper) - cdf(lower)) = log(cdf(upper)) + log(1 - cdf(lower) / cdf(upper))
      return LogNormalCDF(upper) + Math.Log(1 - Math.Exp(LogNormalCDF(lower) - LogNormalCDF(upper)));
    }

    // Calculates the points x1 and x2 where the distributions N(m1, s1) == N(m2,s2).
    // In the general case there are two cut points. If exactly one of s1 or s2 is 0 then x1==x2 (the mean of that distribution).
    // If s1 and s2 are (almost) equal there is at most one cut point: half way between m1 and m2 if the means differ, none otherwise.
    // A cut point that does not exist is returned as negative infinity.
    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
      if (s1.IsAlmost(s2)) {
        if (m1.IsAlmost(m2)) {
          // identical distributions => no cut points
          x1 = double.NegativeInfinity;
          x2 = double.NegativeInfinity;
        } else {
          // s1==s2 and m1 != m2
          // return something reasonable. cut point should be half way between m1 and m2
          x1 = (m1 + m2) / 2;
          x2 = double.NegativeInfinity;
        }
      } else if (s1.IsAlmost(0.0)) {
        // when s1 is 0.0 the cut points are exactly at m1 ...
        x1 = m1;
        x2 = m1;
      } else if (s2.IsAlmost(0.0)) {
        // ... same for s2
        x1 = m2;
        x2 = m2;
      } else if (s2 < s1) {
        // make sure s2 is the larger std.dev.
        CalculateCutPoints(m2, s2, m1, s1, out x1, out x2);
      } else {
        // general case
        // calculate the solutions x1, x2 where N(m1,s1) == N(m2,s2)
        double logRatio = Math.Log(s2 / s1);
        double meanDiff = m1 - m2;
        // (m1 - m2)^2 is calculated from the difference instead of the expanded form m1^2 - 2*m1*m2 + m2^2,
        // because the expanded form can become negative through cancellation for large, similar means (=> NaN from Sqrt)
        double g = Math.Sqrt(2 * (s2 * s2 - s1 * s1) * logRatio + meanDiff * meanDiff);
        double s = (s1 * s1 - s2 * s2);
        x1 = (m2 * s1 * s1 - m1 * s2 * s2 + s1 * s2 * g) / s;
        x2 = -(m1 * s2 * s2 - m2 * s1 * s1 + s1 * s2 * g) / s;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.