Free cookie consent management tool by TermsFeed Policy Generator

Changeset 8638


Ignore:
Timestamp:
09/12/12 17:28:08 (12 years ago)
Author:
gkronber
Message:

#1925 implemented improvements for the NormalDistributionCutPointsThresholdCalculator

Location:
trunk/sources
Files:
1 added
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/ThresholdCalculators/NormalDistributionCutPointsThresholdCalculator.cs

    r8623 r8638  
    5353
    5454    public static void CalculateThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
    55       double maxEstimatedValue = estimatedValues.Max();
    56       double minEstimatedValue = estimatedValues.Min();
    5755      var estimatedTargetValues = Enumerable.Zip(estimatedValues, targetClassValues, (e, t) => new { EstimatedValue = e, TargetValue = t }).ToList();
     56      double estimatedValuesRange = estimatedValues.Range();
    5857
    5958      Dictionary<double, double> classMean = new Dictionary<double, double>();
     
    8281          // calculate all thresholds
    8382          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);
    84           if (!thresholdList.Any(x => x.IsAlmost(x1))) thresholdList.Add(x1);
    85           if (!thresholdList.Any(x => x.IsAlmost(x2))) thresholdList.Add(x2);
     83
     84          // if the two cut points are too close (for instance because the stdDev=0)
     85          // then move them by 0.1% of the range of estimated values
     86          if (x1.IsAlmost(x2)) {
     87            x1 -= 0.001 * estimatedValuesRange;
     88            x2 += 0.001 * estimatedValuesRange;
     89          }
     90          if (!double.IsInfinity(x1) && !thresholdList.Any(x => x.IsAlmost(x1))) thresholdList.Add(x1);
     91          if (!double.IsInfinity(x2) && !thresholdList.Any(x => x.IsAlmost(x2))) thresholdList.Add(x2);
    8692        }
    8793      }
     
    9298      // all points in the partition are classified as the class with the maximal density in the parition
    9399      List<double> classValuesList = new List<double>();
    94       for (int i = 0; i < thresholdList.Count; i++) {
    95         double m;
    96         if (double.IsNegativeInfinity(thresholdList[i])) {
    97           m = thresholdList[i + 1] - 1.0; // smaller than the smalles non-infinity threshold
    98         } else if (i == thresholdList.Count - 1) {
    99           // last threshold
    100           m = thresholdList[i] + 1.0; // larger than the last threshold
    101         } else {
    102           m = thresholdList[i] + (thresholdList[i + 1] - thresholdList[i]) / 2.0; // middle of partition
     100      if (thresholdList.Count == 1) {
     101        // this happens if there are no thresholds (distributions for all classes are exactly the same)
     102        // -> all samples should be classified as the first class
     103        classValuesList.Add(originalClasses[0]);
     104      } else {
     105        // at least one reasonable threshold ...
     106        // find the most likely class for the points between thresholds m
     107        for (int i = 0; i < thresholdList.Count; i++) {
     108          double m;
     109          if (double.IsNegativeInfinity(thresholdList[i])) {
     110            m = thresholdList[i + 1] - 1.0; // smaller than the smallest non-infinity threshold
     111          } else if (i == thresholdList.Count - 1) {
     112            // last threshold
     113            m = thresholdList[i] + 1.0; // larger than the last threshold
     114          } else {
     115            m = thresholdList[i] + (thresholdList[i + 1] - thresholdList[i]) / 2.0; // middle of partition
     116          }
     117
     118          // determine class with maximal probability density in m
     119          double maxDensity = LogNormalDensity(m, classMean[originalClasses[0]], classStdDev[originalClasses[0]]);
     120          double maxDensityClassValue = originalClasses[0];
     121          foreach (var classValue in originalClasses.Skip(1)) {
     122            double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]);
     123            if (density > maxDensity) {
     124              maxDensity = density;
     125              maxDensityClassValue = classValue;
     126            }
     127          }
     128          classValuesList.Add(maxDensityClassValue);
    103129        }
    104 
    105         // determine class with maximal probability density in m
    106         double maxDensity = double.MinValue;
    107         double maxDensityClassValue = -1;
    108         foreach (var classValue in originalClasses) {
    109           double density = LogNormalDensity(m, classMean[classValue], classStdDev[classValue]);
    110           if (density > maxDensity) {
    111             maxDensity = density;
    112             maxDensityClassValue = classValue;
    113           }
    114         }
    115         classValuesList.Add(maxDensityClassValue);
    116130      }
    117131
     
    130144      filteredClassValues.Add(classValuesList[0]);
    131145      for (int i = 0; i < classValuesList.Count - 1; i++) {
    132         if (classValuesList[i] != classValuesList[i + 1]) {
     146        if (!classValuesList[i].IsAlmost(classValuesList[i + 1])) {
    133147          filteredThresholds.Add(thresholdList[i + 1]);
    134148          filteredClassValues.Add(classValuesList[i + 1]);
     
    140154
    141155    private static double LogNormalDensity(double x, double mu, double sigma) {
     156      if (sigma.IsAlmost(0.0))
     157        if (mu.IsAlmost(x)) return 0.0; // (log(1))
     158        else return double.NegativeInfinity;
    142159      return -0.5 * Math.Log(2.0 * Math.PI * sigma * sigma) - ((x - mu) * (x - mu)) / (2.0 * sigma * sigma);
    143160    }
    144161
    145162    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
    146       double a = (s1 * s1 - s2 * s2);
    147       x1 = -(-m2 * s1 * s1 + m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
    148       x2 = (m2 * s1 * s1 - m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     163      if (s1.IsAlmost(s2)) {
     164        if (m1.IsAlmost(m2)) {
     165          x1 = double.NegativeInfinity;
     166          x2 = double.NegativeInfinity;
     167        } else {
     168          x1 = (m1 + m2) / 2;
     169          x2 = double.NegativeInfinity;
     170        }
     171      } else if (s1.IsAlmost(0.0)) {
     172        x1 = m1;
     173        x2 = m1;
     174      } else if (s2.IsAlmost(0.0)) {
     175        x1 = m2;
     176        x2 = m2;
     177      } else {
     178        // scale s1 and s2 for numeric stability
     179        s2 = s2 / s1;
     180        s1 = 1.0;
     181        double a = (s1 + s2) * (s1 - s2);
     182        double g = Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (s1 * s1 + s2 * s2) * Math.Log(s2 / s1)));
     183        double m1s2 = m1 * s2 * s2;
     184        double m2s1 = m2 * s1 * s1;
     185        x1 = -(-m2s1 + m1s2 + g) / a;
     186        x2 = (m2s1 - m1s2 + g) / a;
     187      }
    149188    }
    150189  }
  • trunk/sources/HeuristicLab.Tests/HeuristicLab.Tests.csproj

    r8611 r8638  
    357357    <Compile Include="HeuristicLab.PluginInfraStructure-3.3\InstallationManagerTest.cs" />
    358358    <Compile Include="HeuristicLab.PluginInfraStructure-3.3\TypeDiscoveryTest.cs" />
     359    <Compile Include="HeuristicLab.Problems.DataAnalysis-3.4\ThresholdCalculatorsTest.cs" />
    359360    <Compile Include="HeuristicLab.Problems.DataAnalysis-3.4\OnlineCalculatorPerformanceTest.cs" />
    360361    <Compile Include="HeuristicLab.Problems.DataAnalysis-3.4\StatisticCalculatorsTest.cs" />
Note: See TracChangeset for help on using the changeset viewer.