Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/15/11 08:25:27 (13 years ago)
Author:
gkronber
Message:

#1418 refactored threshold calculators.

Location:
branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4

    • Property svn:ignore
      •  

        old new  
        33HeuristicLabAlgorithmsDataAnalysisPlugin.cs
        44HeuristicLab.Algorithms.DataAnalysis-3.4.csproj.user
         5*.vs10x
  • branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r5680 r5681  
    120120      IEnumerable<int> rows) {
    121121      string targetVariable = problemData.TargetVariable;
    122       List<double> originalClasses = problemData.ClassValues.ToList();
    123       int nClasses = problemData.Classes;
    124       List<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows).ToList();
    125       double maxEstimatedValue = estimatedValues.Max();
    126       double minEstimatedValue = estimatedValues.Min();
    127       var estimatedTargetValues =
    128          (from row in problemData.TrainingIndizes
    129           select new { EstimatedValue = estimatedValues[row], TargetValue = problemData.Dataset[targetVariable, row] })
    130          .ToList();
     122      var estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows);
     123      var targetClassValues = problemData.Dataset.GetEnumeratedVariableValues(targetVariable, rows);
    131124
    132       Dictionary<double, double> classMean = new Dictionary<double, double>();
    133       Dictionary<double, double> classStdDev = new Dictionary<double, double>();
    134       // calculate moments per class
    135       foreach (var classValue in originalClasses) {
    136         var estimatedValuesForClass = from x in estimatedTargetValues
    137                                       where x.TargetValue == classValue
    138                                       select x.EstimatedValue;
    139         double mean, variance;
    140         OnlineMeanAndVarianceCalculator.Calculate(estimatedValuesForClass, out mean, out variance);
    141         classMean[classValue] = mean;
    142         classStdDev[classValue] = Math.Sqrt(variance);
    143       }
    144       List<double> thresholds = new List<double>();
    145       for (int i = 0; i < nClasses - 1; i++) {
    146         for (int j = i + 1; j < nClasses; j++) {
    147           double x1, x2;
    148           double class0 = originalClasses[i];
    149           double class1 = originalClasses[j];
    150           // calculate all thresholds
    151           CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);
    152           if (!thresholds.Any(x => x.IsAlmost(x1))) thresholds.Add(x1);
    153           if (!thresholds.Any(x => x.IsAlmost(x2))) thresholds.Add(x2);
    154         }
    155       }
    156       thresholds.Sort();
    157       thresholds.Insert(0, double.NegativeInfinity);
    158       thresholds.Add(double.PositiveInfinity);
    159       List<double> classValues = new List<double>();
    160       for (int i = 0; i < thresholds.Count - 1; i++) {
    161         double m;
    162         if (double.IsNegativeInfinity(thresholds[i])) {
    163           m = thresholds[i + 1] - 1.0;
    164         } else if (double.IsPositiveInfinity(thresholds[i + 1])) {
    165           m = thresholds[i] + 1.0;
    166         } else {
    167           m = thresholds[i] + (thresholds[i + 1] - thresholds[i]) / 2.0;
    168         }
    169 
    170         double maxDensity = 0;
    171         double maxDensityClassValue = -1;
    172         foreach (var classValue in originalClasses) {
    173           double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
    174           if (density > maxDensity) {
    175             maxDensity = density;
    176             maxDensityClassValue = classValue;
    177           }
    178         }
    179         classValues.Add(maxDensityClassValue);
    180       }
    181       List<double> filteredThresholds = new List<double>();
    182       List<double> filteredClassValues = new List<double>();
    183       filteredThresholds.Add(thresholds[0]);
    184       filteredClassValues.Add(classValues[0]);
    185       for (int i = 0; i < classValues.Count - 1; i++) {
    186         if (classValues[i] != classValues[i + 1]) {
    187           filteredThresholds.Add(thresholds[i + 1]);
    188           filteredClassValues.Add(classValues[i + 1]);
    189         }
    190       }
    191       filteredThresholds.Add(double.PositiveInfinity);
    192 
    193       return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, filteredClassValues, filteredThresholds);
    194     }
    195 
    196     private static double NormalDensity(double x, double mu, double sigma) {
    197       return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
    198     }
    199 
    200     private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
    201       double a = (s1 * s1 - s2 * s2);
    202       x1 = -(-m2 * s1 * s1 + m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
    203       x2 = (m2 * s1 * s1 - m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     125      double[] classValues;
     126      double[] thresholds;
     127      NormalDistributionCutPointsThresholdCalculator.CalculateThresholds(problemData, estimatedValues, targetClassValues, out classValues, out thresholds);
     128      return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, classValues, thresholds);
    204129    }
    205130  }
Note: See TracChangeset for help on using the changeset viewer.