Changeset 8658


Ignore:
Timestamp:
09/14/12 18:48:07 (8 years ago)
Author:
gkronber
Message:

#1925 fixed some bugs in NormalDistributionCutPointsThresholdCalculator (probably more remaining)

Location:
trunk/sources
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs

    r8638 r8658  
    9393      }
    9494      thresholdList.Sort();
    95       thresholdList.Insert(0, double.NegativeInfinity);
     95
     96      // add small value and large value for the calculation of most influential class in each thresholded section
     97      thresholdList.Insert(0, estimatedValues.Min() - 1);
     98      thresholdList.Add(estimatedValues.Max() + 1);
    9699
    97100      // determine class values for each partition separated by a threshold by calculating the density of all class distributions
    98101      // all points in the partition are classified as the class with the maximal density in the parition
    99102      List<double> classValuesList = new List<double>();
    100       if (thresholdList.Count == 1) {
     103      if (thresholdList.Count == 2) {
    101104        // this happens if there are no thresholds (distributions for all classes are exactly the same)
    102105        // -> all samples should be classified as the first class
     
    105108        // at least one reasonable threshold ...
    106109        // find the most likely class for the points between thresholds m
    107         for (int i = 0; i < thresholdList.Count; i++) {
    108           double m;
    109           if (double.IsNegativeInfinity(thresholdList[i])) {
    110             m = thresholdList[i + 1] - 1.0; // smaller than the smallest non-infinity threshold
    111           } else if (i == thresholdList.Count - 1) {
    112             // last threshold
    113             m = thresholdList[i] + 1.0; // larger than the last threshold
    114           } else {
    115             m = thresholdList[i] + (thresholdList[i + 1] - thresholdList[i]) / 2.0; // middle of partition
    116           }
    117 
    118           // determine class with maximal probability density in m
    119           double maxDensity = LogNormalDensity(m, classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
     110        for (int i = 0; i < thresholdList.Count - 1; i++) {
     111
     112          // determine class with maximal density mass between the thresholds
     113          double maxDensity = LogNormalDensityMass(thresholdList[i], thresholdList[i + 1], classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
    120114          double maxDensityClassValue = originalClasses[0];
    121115          foreach (var classValue in originalClasses.Skip(1)) {
    122             double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]);
     116            double density = LogNormalDensityMass(thresholdList[i], thresholdList[i + 1], classMean[classValue], classStdDev[classValue]);
    123117            if (density > maxDensity) {
    124118              maxDensity = density;
     
    139133      //    /   / /\s  \ \     
    140134      //  -/---/-/ -\---\-\----
     135
    141136      List<double> filteredThresholds = new List<double>();
    142137      List<double> filteredClassValues = new List<double>();
    143       filteredThresholds.Add(thresholdList[0]);
     138      filteredThresholds.Add(double.NegativeInfinity); // the smallest possible threshold for the first class
    144139      filteredClassValues.Add(classValuesList[0]);
     140      // do not include the last threshold which was just needed for the previous step
    145141      for (int i = 0; i < classValuesList.Count - 1; i++) {
    146142        if (!classValuesList[i].IsAlmost(classValuesList[i + 1])) {
     
    153149    }
    154150
     151    private static double LogNormalDensityMass(double lower, double upper, double mu, double sigma) {
     152      if (sigma.IsAlmost(0.0)) {
     153        if (lower < mu && mu < upper) return double.PositiveInfinity; // log(1)
     154        else return double.NegativeInfinity; // log(0)
     155      }
     156
     157      Func<double, double> f = (x) =>
     158        x * -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - Math.Pow(x - mu, 3) / (3 * 2.0 * sigma * sigma);
     159
     160      if (double.IsNegativeInfinity(lower)) return f(upper);
     161      else return f(upper) - f(lower);
     162    }
     163
    155164    private static double LogNormalDensity(double x, double mu, double sigma) {
    156       if (sigma.IsAlmost(0.0))
    157         if (mu.IsAlmost(x)) return 0.0; // (log(1))
     165      if (sigma.IsAlmost(0.0)) {
     166        if (x.IsAlmost(mu)) return 0.0; // log(1);
    158167        else return double.NegativeInfinity;
    159       return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - ((x - mu) * (x - mu)) / (2.0 * sigma * sigma);
     168      }
     169
     170      return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - Math.Pow(x - mu, 2) / (2.0 * sigma * sigma);
    160171    }
    161172
     
    176187        x2 = m2;
    177188      } else {
    178         // scale s1 and s2 for numeric stability
    179         s2 = s2 / s1;
    180         s1 = 1.0;
    181         double a = (s1 + s2) * (s1 - s2);
    182         double g = Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (s1 * s1 + s2 * s2) * Math.Log(s2 / s1)));
    183         double m1s2 = m1 * s2 * s2;
    184         double m2s1 = m2 * s1 * s1;
    185         x1 = -(-m2s1 + m1s2 + g) / a;
    186         x2 = (m2s1 - m1s2 + g) / a;
     189        if (s2 < s1) {
     190          // make sure s2 is the larger std.dev.
     191          CalculateCutPoints(m2, s2, m1, s1, out x1, out x2);
     192        } else {
     193          // scale s1 and s2 for numeric stability
     194          //s2 = s2 / s1;
     195          //s1 = 1.0;
     196          double a = (s1 + s2) * (s1 - s2);
     197          double g = Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (s1 * s1 + s2 * s2) * Math.Log(s2 / s1)));
     198          double m1s2 = m1 * s2 * s2;
     199          double m2s1 = m2 * s1 * s1;
     200          x1 = -(-m2s1 + m1s2 + g) / a;
     201          x2 = (m2s1 - m1s2 + g) / a;
     202        }
    187203      }
    188204    }
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/ThresholdCalculatorsTest.cs

    r8638 r8658  
    111111      }
    112112
     113
     114      {
     115        // normal operation
     116        double[] estimatedValues = new double[]
     117                                     {
     118                                       2.9937,
     119                                       2.9861,
     120                                       1.0202,
     121                                       0.9844,
     122                                       1.9912,
     123                                       1.9970,
     124                                       0.9776,
     125                                       0.9611,
     126                                       1.9882,
     127                                       1.9953,
     128                                       2.0147,
     129                                       2.0106,
     130                                       2.9949,
     131                                       0.9925,
     132                                       3.0050,
     133                                       1.9987,
     134                                       2.9973,
     135                                       1.0110,
     136                                       2.0160,
     137                                       2.9559,
     138                                       1.9943,
     139                                       2.9477,
     140                                       2.0158,
     141                                       2.0026,
     142                                       1.9837,
     143                                       3.0185,
     144                                     };
     145        double[] targetClassValues = new double[]
     146                                       {
     147                                          3,
     148                                          3,
     149                                          1,
     150                                          1,
     151                                          2,
     152                                          2,
     153                                          1,
     154                                          1,
     155                                          2,
     156                                          2,
     157                                          2,
     158                                          2,
     159                                          3,
     160                                          1,
     161                                          3,
     162                                          2,
     163                                          3,
     164                                          1,
     165                                          2,
     166                                          3,
     167                                          2,
     168                                          3,
     169                                          2,
     170                                          2,
     171                                          2,
     172                                          3,
     173                                       };
     174
     175        double[] classValues;
     176        double[] thresholds;
     177        NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(null, estimatedValues, targetClassValues,
     178                                                                           out classValues, out thresholds);
     179
     180
     181        var expectedClassValues = new double[] { 2.0, 1.0, 2.0, 3.0 };
     182        var expectedTresholds = new double[] { double.NegativeInfinity, -18.365068542315438, 1.6573010498191565, 2.314962133866949 };
     183
     184        AssertEqual(expectedClassValues, classValues);
     185        AssertEqual(expectedTresholds, thresholds);
     186      }
    113187    }
    114188
Note: See TracChangeset for help on using the changeset viewer.