Free cookie consent management tool by TermsFeed Policy Generator

Changeset 5678 for branches


Ignore:
Timestamp:
03/14/11 19:00:05 (14 years ago)
Author:
gkronber
Message:

#1418 Worked on calculation of thresholds for classification solutions based on discriminant functions.

Location:
branches/DataAnalysis Refactoring
Files:
2 added
9 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/Linear/LinearDiscriminantAnalysis.cs

    r5664 r5678  
    107107      addition.AddSubTree(cNode);
    108108
    109       var model = new SymbolicDiscriminantFunctionClassificationModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), classValues);
     109
     110      var model = LinearDiscriminantAnalysis.CreateDiscriminantFunctionModel(tree, new SymbolicDataAnalysisExpressionTreeInterpreter(), problemData, rows);
    110111      SymbolicDiscriminantFunctionClassificationSolution solution = new SymbolicDiscriminantFunctionClassificationSolution(model, problemData);
     112
    111113      return solution;
    112114    }
    113115    #endregion
     116
     117    private static SymbolicDiscriminantFunctionClassificationModel CreateDiscriminantFunctionModel(ISymbolicExpressionTree tree,
     118      ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     119      IClassificationProblemData problemData,
     120      IEnumerable<int> rows) {
     121      string targetVariable = problemData.TargetVariable;
     122      List<double> originalClasses = problemData.ClassValues.ToList();
     123      int nClasses = problemData.Classes;
     124      List<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows).ToList();
     125      double maxEstimatedValue = estimatedValues.Max();
     126      double minEstimatedValue = estimatedValues.Min();
     127      var estimatedTargetValues =
     128         (from row in problemData.TrainingIndizes
     129          select new { EstimatedValue = estimatedValues[row], TargetValue = problemData.Dataset[targetVariable, row] })
     130         .ToList();
     131
     132      Dictionary<double, double> classMean = new Dictionary<double, double>();
     133      Dictionary<double, double> classStdDev = new Dictionary<double, double>();
     134      // calculate moments per class
     135      foreach (var classValue in originalClasses) {
     136        var estimatedValuesForClass = from x in estimatedTargetValues
     137                                      where x.TargetValue == classValue
     138                                      select x.EstimatedValue;
     139        double mean, variance;
     140        OnlineMeanAndVarianceCalculator.Calculate(estimatedValuesForClass, out mean, out variance);
     141        classMean[classValue] = mean;
     142        classStdDev[classValue] = Math.Sqrt(variance);
     143      }
     144      List<double> thresholds = new List<double>();
     145      for (int i = 0; i < nClasses - 1; i++) {
     146        for (int j = i + 1; j < nClasses; j++) {
     147          double x1, x2;
     148          double class0 = originalClasses[i];
     149          double class1 = originalClasses[j];
     150          // calculate all thresholds
     151          CalculateCutPoints(classMean[class0], classStdDev[class0], classMean[class1], classStdDev[class1], out x1, out x2);
     152          if (!thresholds.Any(x => x.IsAlmost(x1))) thresholds.Add(x1);
     153          if (!thresholds.Any(x => x.IsAlmost(x2))) thresholds.Add(x2);
     154        }
     155      }
     156      thresholds.Sort();
     157      thresholds.Insert(0, double.NegativeInfinity);
     158      thresholds.Add(double.PositiveInfinity);
     159      List<double> classValues = new List<double>();
     160      for (int i = 0; i < thresholds.Count - 1; i++) {
     161        double m;
     162        if (double.IsNegativeInfinity(thresholds[i])) {
     163          m = thresholds[i + 1] - 1.0;
     164        } else if (double.IsPositiveInfinity(thresholds[i + 1])) {
     165          m = thresholds[i] + 1.0;
     166        } else {
     167          m = thresholds[i] + (thresholds[i + 1] - thresholds[i]) / 2.0;
     168        }
     169
     170        double maxDensity = 0;
     171        double maxDensityClassValue = -1;
     172        foreach (var classValue in originalClasses) {
     173          double density = NormalDensity(m, classMean[classValue], classStdDev[classValue]);
     174          if (density > maxDensity) {
     175            maxDensity = density;
     176            maxDensityClassValue = classValue;
     177          }
     178        }
     179        classValues.Add(maxDensityClassValue);
     180      }
     181      List<double> filteredThresholds = new List<double>();
     182      List<double> filteredClassValues = new List<double>();
     183      filteredThresholds.Add(thresholds[0]);
     184      filteredClassValues.Add(classValues[0]);
     185      for (int i = 0; i < classValues.Count - 1; i++) {
     186        if (classValues[i] != classValues[i + 1]) {
     187          filteredThresholds.Add(thresholds[i + 1]);
     188          filteredClassValues.Add(classValues[i + 1]);
     189        }
     190      }
     191      filteredThresholds.Add(double.PositiveInfinity);
     192
     193      return new SymbolicDiscriminantFunctionClassificationModel(tree, interpreter, filteredClassValues, filteredThresholds);
     194    }
     195
     196    private static double NormalDensity(double x, double mu, double sigma) {
     197      return (1.0 / Math.Sqrt(2.0 * Math.PI * sigma * sigma)) * Math.Exp(-((x - mu) * (x - mu)) / (2.0 * sigma * sigma));
     198    }
     199
     200    private static void CalculateCutPoints(double m1, double s1, double m2, double s2, out double x1, out double x2) {
     201      double a = (s1 * s1 - s2 * s2);
     202      double b = (m1 * s2 * s2 - m2 * s1 * s1);
     203      double c = 2 * s1 * s1 * s2 * s2 * Math.Log(s2) - 2 * s1 * s1 * s2 * s2 * Math.Log(s1) - s1 * s1 * m2 * m2 + s2 * s2 * m1 * m1;
     204      x1 = -(-m2 * s1 * s1 + m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     205      x2 = (m2 * s1 * s1 - m1 * s2 * s2 + Math.Sqrt(s1 * s1 * s2 * s2 * ((m1 - m2) * (m1 - m2) + 2.0 * (-s1 * s1 + s2 * s2) * Math.Log(s2 / s1)))) / a;
     206    }
    114207  }
    115208}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/MultiObjective/SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer.cs

    r5649 r5678  
    7070
    7171    protected override ISymbolicClassificationSolution CreateSolution(ISymbolicExpressionTree bestTree, double[] bestQuality) {
    72       var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, ProblemData.ClassValues);
     72      double[] classValues;
     73      double[] thresholds;
     74      var estimatedValues = SymbolicDataAnalysisTreeInterpreter.GetSymbolicExpressionTreeValues(bestTree, ProblemData.Dataset, ProblemData.TrainingIndizes);
     75      var targetValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     76      DiscriminantFunctionClassificationSolution.CalculateClassThresholds(ProblemData, estimatedValues, targetValues, out classValues, out thresholds);
     77      var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, classValues, thresholds);
    7378      return new SymbolicDiscriminantFunctionClassificationSolution(model, ProblemData);
    74     }
     79    } 
    7580  }
    7681}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SingleObjective/SymbolicClassificationSingleObjectiveTrainingBestSolutionAnalyzer.cs

    r5649 r5678  
    6868
    6969    protected override ISymbolicClassificationSolution CreateSolution(ISymbolicExpressionTree bestTree, double bestQuality) {
    70       var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, ProblemData.ClassValues);
     70      double[] classValues;
     71      double[] thresholds;
     72      var estimatedValues = SymbolicDataAnalysisTreeInterpreter.GetSymbolicExpressionTreeValues(bestTree, ProblemData.Dataset, ProblemData.TrainingIndizes);
     73      var targetValues = ProblemData.Dataset.GetEnumeratedVariableValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
     74      DiscriminantFunctionClassificationSolution.CalculateClassThresholds(ProblemData, estimatedValues, targetValues, out classValues, out thresholds);
     75      var model = new SymbolicDiscriminantFunctionClassificationModel(bestTree, SymbolicDataAnalysisTreeInterpreter, classValues, thresholds);
    7176      return new SymbolicDiscriminantFunctionClassificationSolution(model, ProblemData);
    7277    }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs

    r5657 r5678  
    3939  [Item(Name = "SymbolicDiscriminantFunctionClassificationModel", Description = "Represents a symbolic classification model unsing a discriminant function.")]
    4040  public class SymbolicDiscriminantFunctionClassificationModel : SymbolicDataAnalysisModel, ISymbolicDiscriminantFunctionClassificationModel {
    41     [Storable]
    42     private double[] classValues;
    4341
    4442    [Storable]
     
    5149      }
    5250    }
    53 
     51    [Storable]
     52    private double[] classValues;
     53    public IEnumerable<double> ClassValues {
     54      get { return (IEnumerable<double>)classValues.Clone(); }
     55      set { classValues = value.ToArray(); }
     56    }
    5457    [StorableConstructor]
    5558    protected SymbolicDiscriminantFunctionClassificationModel(bool deserializing) : base(deserializing) { }
     
    5962      thresholds = (double[])original.thresholds.Clone();
    6063    }
    61     public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IEnumerable<double> classValues)
     64    public SymbolicDiscriminantFunctionClassificationModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IEnumerable<double> classValues, IEnumerable<double> thresholds)
    6265      : base(tree, interpreter) {
    6366      this.classValues = classValues.ToArray();
    64       this.thresholds = new double[0];
     67      this.thresholds = thresholds.ToArray();
    6568    }
    6669
     
    7679      foreach (var x in GetEstimatedValues(dataset, rows)) {
    7780        int classIndex = 0;
    78         // find first threshold value which is smaller than x => class index = threshold index + 1
     81        // find first threshold value which is larger than x => class index = threshold index + 1
    7982        for (int i = 0; i < thresholds.Length; i++) {
    8083          if (x > thresholds[i]) classIndex++;
     
    9194      if (listener != null) listener(this, e);
    9295    }
    93     #endregion
     96    #endregion   
    9497  }
    9598}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolution.cs

    r5657 r5678  
    4141    #region ISymbolicClassificationSolution Members
    4242
    43     public new ISymbolicClassificationModel Model {
    44       get { return (ISymbolicClassificationModel)base.Model; }
     43    public new IDiscriminantFunctionClassificationModel Model {
     44      get { return (IDiscriminantFunctionClassificationModel)base.Model; }
     45    }
     46
     47    ISymbolicClassificationModel ISymbolicClassificationSolution.Model {
     48      get { return (ISymbolicClassificationModel)Model; }
    4549    }
    4650
    4751    ISymbolicDataAnalysisModel ISymbolicDataAnalysisSolution.Model {
    48       get { return (ISymbolicDataAnalysisModel)base.Model; }
     52      get { return (ISymbolicDataAnalysisModel)Model; }
    4953    }
    5054
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Views/3.4/HeuristicLab.Problems.DataAnalysis.Views-3.4.csproj

    r5664 r5678  
    122122      <DependentUpon>ClassificationSolutionEstimatedClassValuesView.cs</DependentUpon>
    123123    </Compile>
     124    <Compile Include="Classification\DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.cs">
     125      <SubType>UserControl</SubType>
     126    </Compile>
     127    <Compile Include="Classification\DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.Designer.cs">
     128      <DependentUpon>DiscriminantFunctionClassificationSolutionEstimatedClassValuesView.cs</DependentUpon>
     129    </Compile>
    124130    <Compile Include="Classification\DiscriminantFunctionClassificationRocCurvesView.cs">
    125131      <SubType>UserControl</SubType>
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationModel.cs

    r5649 r5678  
    4242    [Storable]
    4343    private double[] classValues;
    44 
    45     [StorableConstructor]
    46     protected DiscriminantFunctionClassificationModel() : base() { }
    47     protected DiscriminantFunctionClassificationModel(DiscriminantFunctionClassificationModel original, Cloner cloner)
    48       : base(original, cloner) {
    49       model = cloner.Clone(original.model);
    50       classValues = (double[])original.classValues.Clone();
     44    // class values are not necessarily sorted in ascending order
     45    public IEnumerable<double> ClassValues {
     46      get { return (double[])classValues.Clone(); }
     47      set {
     48        if (value == null) throw new ArgumentException();
     49        double[] newValue = value.ToArray();
     50        if (newValue.Length != classValues.Length) throw new ArgumentException();
     51        classValues = newValue;
     52      }
    5153    }
    52     public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues)
    53       : base() {
    54       this.name = ItemName;
    55       this.description = ItemDescription;
    56       this.model = model;
    57       this.classValues = classValues.ToArray();
    58     }
    59 
    60     public override IDeepCloneable Clone(Cloner cloner) {
    61       return new DiscriminantFunctionClassificationModel(this, cloner);
    62     }
    63 
    64     #region IDiscriminantFunctionClassificationModel Members
    65 
     54    [Storable]
    6655    private double[] thresholds;
    6756    public IEnumerable<double> Thresholds {
     
    7362    }
    7463
    75     public event EventHandler ThresholdsChanged;
    76     protected virtual void OnThresholdsChanged(EventArgs e) {
    77       var listener = ThresholdsChanged;
    78       if (listener != null) listener(this, e);
     64
     65    [StorableConstructor]
     66    protected DiscriminantFunctionClassificationModel() : base() { }
     67    protected DiscriminantFunctionClassificationModel(DiscriminantFunctionClassificationModel original, Cloner cloner)
     68      : base(original, cloner) {
     69      model = cloner.Clone(original.model);
     70      classValues = (double[])original.classValues.Clone();
     71      thresholds = (double[])original.thresholds.Clone();
     72    }
     73    public DiscriminantFunctionClassificationModel(IRegressionModel model, IEnumerable<double> classValues, IEnumerable<double> thresholds)
     74      : base() {
     75      this.name = ItemName;
     76      this.description = ItemDescription;
     77      this.model = model;
     78      this.classValues = classValues.ToArray();
     79      this.thresholds = thresholds.ToArray();
     80    }
     81
     82    public override IDeepCloneable Clone(Cloner cloner) {
     83      return new DiscriminantFunctionClassificationModel(this, cloner);
    7984    }
    8085
     
    8691      foreach (var x in GetEstimatedValues(dataset, rows)) {
    8792        int classIndex = 0;
    88         // find first threshold value which is smaller than x => class index = threshold index + 1
     93        // find first threshold value which is larger than x => class index = threshold index + 1
    8994        for (int i = 0; i < thresholds.Length; i++) {
    9095          if (x > thresholds[i]) classIndex++;
     
    9499      }
    95100    }
    96 
     101    #region events
     102    public event EventHandler ThresholdsChanged;
     103    protected virtual void OnThresholdsChanged(EventArgs e) {
     104      var listener = ThresholdsChanged;
     105      if (listener != null) listener(this, e);
     106    }
    97107    #endregion
    98108  }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs

    r5664 r5678  
    4444    }
    4545    public DiscriminantFunctionClassificationSolution(IRegressionModel model, IClassificationProblemData problemData)
    46       : this(new DiscriminantFunctionClassificationModel(model, problemData.ClassValues), problemData) {
     46      : this(new DiscriminantFunctionClassificationModel(model, problemData.ClassValues, CalculateClassThresholds(model, problemData, problemData.TrainingIndizes)), problemData) {
    4747    }
    4848    public DiscriminantFunctionClassificationSolution(IDiscriminantFunctionClassificationModel model, IClassificationProblemData problemData)
     
    9292    #endregion
    9393
    94     public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
    95       if (Model.Thresholds == null || Model.Thresholds.Count() == 0) RecalculateClassIntermediates();
    96       return base.GetEstimatedClassValues(rows);
     94    private static double[] CalculateClassThresholds(IRegressionModel model, IClassificationProblemData problemData, IEnumerable<int> rows) {
     95      double[] thresholds;
     96      double[] classValues;
     97      CalculateClassThresholds(problemData, model.GetEstimatedValues(problemData.Dataset, rows), problemData.Dataset.GetEnumeratedVariableValues(problemData.TargetVariable, rows), out classValues, out thresholds);
     98      return thresholds;
    9799    }
    98100
    99     private void RecalculateClassIntermediates() {
     101    public static void CalculateClassThresholds(IClassificationProblemData problemData, IEnumerable<double> estimatedValues, IEnumerable<double> targetClassValues, out double[] classValues, out double[] thresholds) {
    100102      int slices = 100;
    101       List<double> estimatedValues = EstimatedValues.ToList();
    102       List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable)
    103                                   group classValue by classValue into grouping
    104                                   select grouping.Count()).ToList();
    105       double maxEstimatedValue = estimatedValues.Max();
    106       double minEstimatedValue = estimatedValues.Min();
    107       List<KeyValuePair<double, double>> estimatedTargetValues =
    108          (from row in ProblemData.TrainingIndizes
    109           select new KeyValuePair<double, double>(
    110             estimatedValues[row],
    111             ProblemData.Dataset[ProblemData.TargetVariable, row])).ToList();
     103      List<double> estimatedValuesList = estimatedValues.ToList();
     104      double maxEstimatedValue = estimatedValuesList.Max();
     105      double minEstimatedValue = estimatedValuesList.Min();
     106      double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;
     107      var estimatedAndTargetValuePairs =
     108        estimatedValuesList.Zip(targetClassValues, (x, y) => new { EstimatedValue = x, TargetClassValue = y })
     109        .OrderBy(x => x.EstimatedValue)
     110        .ToList();
    112111
    113       List<double> originalClasses = ProblemData.ClassValues.OrderBy(x => x).ToList();
    114       int nClasses = originalClasses.Distinct().Count();
    115       double[] thresholds = new double[nClasses + 1];
     112      classValues = problemData.ClassValues.OrderBy(x => x).ToArray();
     113      int nClasses = classValues.Length;
     114      thresholds = new double[nClasses + 1];
    116115      thresholds[0] = double.NegativeInfinity;
    117116      thresholds[thresholds.Length - 1] = double.PositiveInfinity;
    118117
    119       double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;
     118      // incrementally calculate accuracy of all possible thresholds
     119      int[,] confusionMatrix = new int[nClasses, nClasses];
    120120
     121      // one threshold is always treated as binary separation of the remaining classes
    121122      for (int i = 1; i < thresholds.Length - 1; i++) {
    122123        double lowerThreshold = thresholds[i - 1];
     
    130131          double classificationScore = 0.0;
    131132
    132           foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
     133          foreach (var pair in estimatedAndTargetValuePairs) {
    133134            //all positives
    134             if (estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) {
    135               if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
     135            if (pair.TargetClassValue.IsAlmost(classValues[i - 1])) {
     136              if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold)
    136137                //true positive
    137                 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i - 1]);
     138                classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i - 1]);
    138139              else
    139140                //false negative
    140                 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);
     141                classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i - 1]);
    141142            }
    142143              //all negatives
    143144            else {
    144               if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
     145              if (pair.EstimatedValue > lowerThreshold && pair.EstimatedValue < actualThreshold)
    145146                //false positive
    146                 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);
     147                classificationScore += problemData.GetClassificationPenalty(classValues[i - 1], classValues[i]);
    147148              else
    148149                //true negative, consider only upper class
    149                 classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i]);
     150                classificationScore += problemData.GetClassificationPenalty(classValues[i], classValues[i]);
    150151            }
    151152          }
     
    167168        }
    168169        //scale lowest thresholds and highest found optimal threshold according to the misclassification matrix
    169         double falseNegativePenalty = ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);
    170         double falsePositivePenalty = ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);
     170        double falseNegativePenalty = problemData.GetClassificationPenalty(classValues[i], classValues[i - 1]);
     171        double falsePositivePenalty = problemData.GetClassificationPenalty(classValues[i - 1], classValues[i]);
    171172        thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / (falseNegativePenalty + falsePositivePenalty);
    172173      }
    173       Thresholds = new List<double>(thresholds);
    174174    }
    175175  }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs

    r5657 r5678  
    2525  public interface IDiscriminantFunctionClassificationModel : IClassificationModel {
    2626    IEnumerable<double> Thresholds { get; set; }
     27    IEnumerable<double> ClassValues { get; set; }
    2728    event EventHandler ThresholdsChanged;
    2829    IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
Note: See TracChangeset for help on using the changeset viewer.