Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/14/11 19:00:05 (13 years ago)
Author:
gkronber
Message:

#1418 Worked on calculation of thresholds for classification solutions based on discriminant functions.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r5664 r5678  
    107107      addition.AddSubTree(cNode);
    108108
    109       var model = new SymbolicDiscriminantFunctionClassificationModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), classValues);
     109
     110      var model = LinearDiscriminantAnalysis.CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), problemData, rows);
    110111      SymbolicDiscriminantFunctionClassificationSolution solution = new SymbolicDiscriminantFunctionClassificationSolution(model, problemData);
     112
    111113      return solution;
    112114    }
    113115    #endregion
     116
     117    private static SymbolicDiscriminantFunctionClassificationModel CreateDiscriminantFunctionModel(ISymbolicExpressionTree tree,
     118      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     119      IClassificationProblemData problemData,
     120      IEnumerable<int> rows) {
     121      string targetVariable = problemData.TargetVariable;
     122      List<double> originalClasses = problemData.ClassValues.ToList();
     123      int nClasses = problemData.Classes;
     124      List<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows).ToList();
     125      double maxEstimatedValue = estimatedValues.Max();
     126      double minEstimatedValue = estimatedValues.Min();
     127      var estimatedTargetValues =
     128         (from row in problemData.TrainingIndizes
     129          select new { EstimatedValue = estimatedValues[row], TargetValue = problemData.Dataset[targetVariable, row] })
     130         .ToList();
     131
     132      Dictionary<double, double> classMean = new Dictionary<double, double>();
     133      Dictionary<double, double> classStdDev = new Dictionary<double, double>();
     134      // calculate moments per class
     135      foreach (var classValue in originalClasses) {
     136        var estimatedValuesForClass = from x in estimatedTargetValues
     137                                      where x.TargetValue == classValue
     138                                      select x.EstimatedValue;
     139        double mean, variance;
     140        OnlineMeanAndVarianceCalculator.Calculate(estimatedValuesForClass, out mean, out variance);
     141        classMean[classValue] = mean;
     142        classStdDev[classValue] = Math.Sqrt(variance);
     143      }
     144      List<double> thresholds = new List<double>();
     145      for (int i = 0; i < nClasses - 1; i++) {
     146        for (int j = i + 1; j < nClasses; j++) {
     147          double x1, x2;
     148          double class0 = originalClasses[i];
     149          double class1 = originalClasses[j];
     150          // calculate all thresholds
     151          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);
     152          if (!thresholds.Any(x => x.IsAlmost(x1))) thresholds.Add(x1);
     153          if (!thresholds.Any(x => x.IsAlmost(x2))) thresholds.Add(x2);
     154        }
     155      }
     156      thresholds.Sort();
     157      thresholds.Insert(0, double.NegativeInfinity);
     158      thresholds.Add(double.PositiveInfinity);
     159      List<double> classValues = new List<double>();
     160      for (int i = 0; i < thresholds.Count - 1; i++) {
     161        double m;
     162        if (double.IsNegativeInfinity(thresholds[i])) {
     163          m = thresholds[i + 1] - 1.0;
     164        } else if (double.IsPositiveInfinity(thresholds[i + 1])) {
     165          m = thresholds[i] + 1.0;
     166        } else {
     167          m = thresholds[i] + (thresholds[i + 1] - thresholds[i]) / 2.0;
     168        }
     169
     170        double maxDensity = 0;
     171        double maxDensityClassValue = -1;
     172        foreach (var classValue in originalClasses) {
     173          double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
     174          if (density > maxDensity) {
     175            maxDensity = density;
     176            maxDensityClassValue = classValue;
     177          }
     178        }
     179        classValues.Add(maxDensityClassValue);
     180      }
     181      List<double> filteredThresholds = new List<double>();
     182      List<double> filteredClassValues = new List<double>();
     183      filteredThresholds.Add(thresholds[0]);
     184      filteredClassValues.Add(classValues[0]);
     185      for (int i = 0; i < classValues.Count - 1; i++) {
     186        if (classValues[i] != classValues[i + 1]) {
     187          filteredThresholds.Add(thresholds[i + 1]);
     188          filteredClassValues.Add(classValues[i + 1]);
     189        }
     190      }
     191      filteredThresholds.Add(double.PositiveInfinity);
     192
     193      return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, filteredClassValues, filteredThresholds);
     194    }
     195
     196    private static double NormalDensity(double x, double mu, double sigma) {
     197      return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
     198    }
     199
     200    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
     201      double a = (s1 * s1 - s2 * s2);
     202      double b = (m1 * s2 * s2 - m2 * s1 * s1);
     203      double c = 2 * s1 * s1 * s2 * s2 * Math.Log(s2) - 2 * s1 * s1 * s2 * s2 * Math.Log(s1) - s1 * s1 * m2 * m2 + s2 * s2 * m1 * m1;
     204      x1 = -(-m2 * s1 * s1 + m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     205      x2 = (m2 * s1 * s1 - m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     206    }
    114207  }
    115208}
Note: See TracChangeset for help on using the changeset viewer.