Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/28/12 16:13:02 (12 years ago)
Author:
sforsten
Message:

#1776:

  • merged r8508:8533 from trunk into branch
  • AverageThresholdCalculator and MedianThresholdCalculator can now handle multi class classification
  • changed combo boxes in ClassificationEnsembleSolutionView to drop down list
Location:
branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis

  • branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis/3.4/Implementation/Classification/WeightCalculators/AverageThresholdCalculator.cs

    r8297 r8534  
    5858
    5959    protected override double GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, int index, double estimatedClassValue) {
    60       // only works with binary classification
    61       if (!classValues.Count().Equals(2))
    62         return double.NaN;
    6360      Dataset dataset = solutions.First().ProblemData.Dataset;
    6461      IList<double> values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, Enumerable.Repeat(index, 1)).First()).ToList();
     
    7067
    7168    public override IEnumerable<double> GetDiscriminantConfidence(IEnumerable<IDiscriminantFunctionClassificationSolution> solutions, IEnumerable<int> indices, IEnumerable<double> estimatedClassValue) {
    72       if (!classValues.Count().Equals(2))
    73         return Enumerable.Repeat(double.NaN, indices.Count());
    74 
    7569      Dataset dataset = solutions.First().ProblemData.Dataset;
    7670      double[][] values = solutions.Select(s => s.Model.GetEstimatedValues(dataset, indices).ToArray()).ToArray();
     
    8781
    8882    protected double GetAverageConfidence(double avg, double estimatedClassValue) {
    89       if (estimatedClassValue.Equals(classValues[0])) {
    90         if (avg < estimatedClassValue)
    91           return 1;
    92         else if (avg >= threshold[1])
    93           return 0;
    94         else {
    95           double distance = threshold[1] - classValues[0];
    96           return (1 / distance) * (threshold[1] - avg);
     83      for (int i = 0; i < classValues.Length; i++) {
     84        if (estimatedClassValue.Equals(classValues[i])) {
     85          //special case: avgerage is higher than value of highest class
     86          if (i == classValues.Length - 1 && avg > estimatedClassValue) {
     87            return 1;
     88          }
     89          //special case: average is lower than value of lowest class
     90          if (i == 0 && avg < estimatedClassValue) {
     91            return 1;
     92          }
     93          //special case: average is not between threshold of estimated class value
     94          if ((i < classValues.Length - 1 && avg >= threshold[i + 1]) || avg <= threshold[i]) {
     95            return 0;
     96          }
     97
     98          double thresholdToClassDistance, thresholdToAverageValueDistance;
     99          if (avg >= classValues[i]) {
     100            thresholdToClassDistance = threshold[i + 1] - classValues[i];
     101            thresholdToAverageValueDistance = threshold[i + 1] - avg;
     102          } else {
     103            thresholdToClassDistance = classValues[i] - threshold[i];
     104            thresholdToAverageValueDistance = avg - threshold[i];
     105          }
     106          return (1 / thresholdToClassDistance) * thresholdToAverageValueDistance;
    97107        }
    98       } else if (estimatedClassValue.Equals(classValues[1])) {
    99         if (avg > estimatedClassValue)
    100           return 1;
    101         else if (avg <= threshold[1])
    102           return 0;
    103         else {
    104           double distance = classValues[1] - threshold[1];
    105           return (1 / distance) * (avg - threshold[1]);
    106         }
    107       } else
    108         return double.NaN;
     108      }
     109      return double.NaN;
    109110    }
    110111  }
Note: See TracChangeset for help on using the changeset viewer.