Free cookie consent management tool by TermsFeed Policy Generator

Changeset 5657


Ignore:
Timestamp:
03/10/11 12:37:11 (14 years ago)
Author:
gkronber
Message:

#1418 Implemented calculation of thresholds.

Location:
branches/DataAnalysis Refactoring
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationModel.cs

    r5649 r5657  
    6262      : base(tree, interpreter) {
    6363      this.classValues = classValues.ToArray();
     64      this.thresholds = new double[0];
    6465    }
    6566
     
    8081          else break;
    8182        }
    82         yield return classValues.ElementAt(classIndex);
     83        yield return classValues.ElementAt(classIndex - 1);
    8384      }
    8485    }
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicDiscriminantFunctionClassificationSolution.cs

    r5649 r5657  
    6161    public override IDeepCloneable Clone(Cloner cloner) {
    6262      return new SymbolicDiscriminantFunctionClassificationSolution(this, cloner);
    63     }
     63    } 
    6464  }
    6565}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/DiscriminantFunctionClassificationSolution.cs

    r5649 r5657  
    7474
    7575    public IEnumerable<double> Thresholds {
    76       get { return Model.Thresholds; }
     76      get {
     77        return Model.Thresholds;
     78      }
     79      protected set { Model.Thresholds = value; }
    7780    }
    7881
     
    8891    }
    8992    #endregion
     93
     94    public override IEnumerable<double> GetEstimatedClassValues(IEnumerable<int> rows) {
     95      if (Model.Thresholds == null || Model.Thresholds.Count() == 0) RecalculateClassIntermediates();
     96      return base.GetEstimatedClassValues(rows);
     97    }
     98
     99    private void RecalculateClassIntermediates() {
     100      int slices = 100;
     101      List<double> estimatedValues = EstimatedValues.ToList();
     102      List<int> classInstances = (from classValue in ProblemData.Dataset.GetVariableValues(ProblemData.TargetVariable)
     103                                  group classValue by classValue into grouping
     104                                  select grouping.Count()).ToList();
     105      double maxEstimatedValue = estimatedValues.Max();
     106      double minEstimatedValue = estimatedValues.Min();
     107      List<KeyValuePair<double, double>> estimatedTargetValues =
     108         (from row in ProblemData.TrainingIndizes
     109          select new KeyValuePair<double, double>(
     110            estimatedValues[row],
     111            ProblemData.Dataset[ProblemData.TargetVariable, row])).ToList();
     112
     113      List<double> originalClasses = ProblemData.ClassValues.OrderBy(x => x).ToList();
     114      int nClasses = originalClasses.Distinct().Count();
     115      double[] thresholds = new double[nClasses + 1];
     116      thresholds[0] = double.NegativeInfinity;
     117      thresholds[thresholds.Length - 1] = double.PositiveInfinity;
     118
     119      for (int i = 1; i < thresholds.Length - 1; i++) {
     120        double lowerThreshold = thresholds[i - 1];
     121        double actualThreshold = minEstimatedValue;
     122        double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / slices;
     123
     124        double lowestBestThreshold = double.NaN;
     125        double highestBestThreshold = double.NaN;
     126        double bestClassificationScore = double.PositiveInfinity;
     127        bool seriesOfEqualClassificationScores = false;
     128
     129        while (actualThreshold < maxEstimatedValue) {
     130          double classificationScore = 0.0;
     131
     132          foreach (KeyValuePair<double, double> estimatedTarget in estimatedTargetValues) {
     133            //all positives
     134            if (estimatedTarget.Value.IsAlmost(originalClasses[i - 1])) {
     135              if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
     136                //true positive
     137                classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i - 1]);
     138              else
     139                //false negative
     140                classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);
     141            }
     142              //all negatives
     143            else {
     144              if (estimatedTarget.Key > lowerThreshold && estimatedTarget.Key < actualThreshold)
     145                //false positive
     146                classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);
     147              else
     148                //true negative, consider only upper class
     149                classificationScore += ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i]);
     150            }
     151          }
     152
     153          //new best classification score found
     154          if (classificationScore < bestClassificationScore) {
     155            bestClassificationScore = classificationScore;
     156            lowestBestThreshold = actualThreshold;
     157            highestBestThreshold = actualThreshold;
     158            seriesOfEqualClassificationScores = true;
     159          }
     160            //equal classification scores => if seriesOfEqualClassifcationScores == true update highest threshold
     161          else if (Math.Abs(classificationScore - bestClassificationScore) < double.Epsilon && seriesOfEqualClassificationScores)
     162            highestBestThreshold = actualThreshold;
     163          //worse classificatoin score found reset seriesOfEqualClassifcationScores
     164          else seriesOfEqualClassificationScores = false;
     165
     166          actualThreshold += thresholdIncrement;
     167        }
     168        //scale lowest thresholds and highest found optimal threshold according to the misclassification matrix
     169        double falseNegativePenalty = ProblemData.GetClassificationPenalty(originalClasses[i], originalClasses[i - 1]);
     170        double falsePositivePenalty = ProblemData.GetClassificationPenalty(originalClasses[i - 1], originalClasses[i]);
     171        thresholds[i] = (lowestBestThreshold * falsePositivePenalty + highestBestThreshold * falseNegativePenalty) / (falseNegativePenalty + falsePositivePenalty);
     172      }
     173      Thresholds = new List<double>(thresholds);
     174    }
    90175  }
    91176}
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/Interfaces/Classification/IDiscriminantFunctionClassificationModel.cs

    r5649 r5657  
    2424namespace HeuristicLab.Problems.DataAnalysis {
    2525  public interface IDiscriminantFunctionClassificationModel : IClassificationModel {
    26     IEnumerable<double> Thresholds { get; }
     26    IEnumerable<double> Thresholds { get; set; }
    2727    event EventHandler ThresholdsChanged;
    2828    IEnumerable<double> GetEstimatedValues(Dataset dataset, IEnumerable<int> rows);
  • branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis/3.4/OnlineEvaluators/OnlineAccuracyEvaluator.cs

    r5649 r5657  
    3434          throw new InvalidOperationException("No elements");
    3535        else
    36           return correctlyClassified / n;
     36          return correctlyClassified / (double)n;
    3737      }
    3838    }
Note: See TracChangeset for help on using the changeset viewer.