Free cookie consent management tool by TermsFeed Policy Generator

Changeset 678


Ignore:
Timestamp:
10/17/08 10:56:51 (16 years ago)
Author:
mkommend
Message:

support for multiclass classification added; AUC calculated (ticket #308)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ROCAnalyzer.cs

    r672 r678  
    3131namespace HeuristicLab.GP.StructureIdentification.Classification {
    3232  public class ROCAnalyzer : OperatorBase {
     33    private ItemList myRocValues;
     34    private ItemList<DoubleData> myAucValues;
     35
    3336
    3437    public override string Description {
     
    3942      : base() {
    4043      AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and orignial values for the ROCAnalyzer", typeof(ItemList), VariableKind.In));
    41       AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList<ItemList<DoubleArrayData>>), VariableKind.New | VariableKind.Out));
     44      AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out));
     45      AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out));
    4246    }
    4347
    4448    public override IOperation Apply(IScope scope) {
     49      #region initialize HL-variables
    4550      ItemList values = GetVariableValue<ItemList>("Values", scope, true);
    46       ItemList<ItemList<DoubleArrayData>> rocValues = GetVariableValue<ItemList<ItemList<DoubleArrayData>>>("ROCValues", scope, false, false);
    47       if (rocValues == null) {
    48         rocValues = new ItemList<ItemList<DoubleArrayData>>();
     51      myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false);
     52      if (myRocValues == null) {
     53        myRocValues = new ItemList();
    4954        IVariableInfo info = GetVariableInfo("ROCValues");
    5055        if (info.Local)
    51           AddVariable(new HeuristicLab.Core.Variable(info.ActualName, rocValues));
     56          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues));
    5257        else
    53           scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), rocValues));
     58          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues));
    5459      } else {
    55         rocValues.Clear();
    56       }
    57 
    58       rocValues.Add(new ItemList<DoubleArrayData>());
    59       //ROC Curve starts at 0,0
    60       DoubleArrayData point = new DoubleArrayData();
    61       point.Data = new double[2] { 0, 0 };
    62       rocValues[0].Add(point);
     60        myRocValues.Clear();
     61      }
     62
     63      myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false);
     64      if (myAucValues == null) {
     65        myAucValues = new ItemList<DoubleData>();
     66        IVariableInfo info = GetVariableInfo("AUCValues");
     67        if (info.Local)
     68          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues));
     69        else
     70          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues));
     71      } else {
     72        myAucValues.Clear();
     73      }
     74      #endregion
    6375
    6476      //calculate new ROC Values
    6577      double estimated = 0.0;
    6678      double original = 0.0;
    67       double positiveClassKey;
    68       double negativeClassKey;
    69       double truePositiveRate = 0.0;
    70       double falsePositiveRate = 0.0;
    7179
    7280      //initialize classes dictionary
    73       Dictionary<double, List<double>> classes = new Dictionary<double, List<double>>();
     81      SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>();
    7482      foreach (ItemList value in values) {
    7583        estimated = ((DoubleData)value[0]).Data;
     
    7987        classes[original].Add(estimated);
    8088      }
     89      foreach (double key in classes.Keys)
     90        classes[key].Sort();
    8191
    8292      //check for 2 classes classification problem
    83       if (classes.Keys.Count != 2)
    84         throw new Exception("ROCAnalyser only handles  2 class classification problems");
    85 
    86       //sort estimated values in classes dictionary
    87       foreach (List<double> estimatedValues in classes.Values)
    88         estimatedValues.Sort();
    89 
    90       //calculate truePosivite- & falsePositiveRate
    91       positiveClassKey = classes.Keys.Min<double>();
    92       negativeClassKey = classes.Keys.Max<double>();
    93       foreach (double treshold in classes[negativeClassKey].Distinct<double>()) {
    94         truePositiveRate = ((double)classes[positiveClassKey].Count<double>(value => value < treshold)) / classes[positiveClassKey].Count;
    95         falsePositiveRate = ((double)classes[negativeClassKey].Count<double>(value => value < treshold)) / classes[negativeClassKey].Count;
    96         point = new DoubleArrayData(new double[2] { falsePositiveRate, truePositiveRate });
    97         rocValues[0].Add(point);
    98 
    99         //stop calculation if truePositiveRate = 1; save runtime
    100         if (truePositiveRate == 1)
     93      //if (classes.Keys.Count != 2)
     94      //  throw new Exception("ROCAnalyser only handles  2 class classification problems");
     95
     96      //calculate ROC Curve
     97      foreach (double key in classes.Keys) {
     98        CalculateBestROC(key, classes);
     99      }
     100
     101      return null;
     102    }
     103
     104    protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) {
     105
     106      int rocIndex = myRocValues.Count - 1;
     107      List<KeyValuePair<double, double>> rocCharacteristics;
     108      List<KeyValuePair<double, double>> bestROC;
     109      List<KeyValuePair<double, double>> actROC;
     110
     111      List<double> negatives = new List<double>();
     112      foreach (double key in classes.Keys) {
     113        if (key != positiveClassKey)
     114          negatives.AddRange(classes[key]);
     115      }
     116      List<double> actNegatives = negatives.Where<double>(value => value < classes[positiveClassKey].Max<double>()).ToList<double>();
     117      actNegatives.Add(classes[positiveClassKey].Max<double>());
     118      actNegatives.Sort();
     119      actNegatives = actNegatives.Reverse<double>().ToList<double>();
     120
     121      double bestAUC = double.MinValue;
     122      double actAUC = 0;
     123      //first class
     124      if (classes.Keys.ElementAt<double>(0) == positiveClassKey) {
     125        rocCharacteristics = null;
     126        CalculateROCValuesAndAUC(classes[positiveClassKey], actNegatives, negatives.Count, double.MinValue, ref rocCharacteristics, out  actROC, out actAUC);
     127        myAucValues.Add(new DoubleData(actAUC));
     128        myRocValues.Add(Convert(actROC));
     129      }
     130        //middle classes 
     131      else if (classes.Keys.ElementAt<double>(classes.Keys.Count - 1) != positiveClassKey) {
     132        rocCharacteristics = null;
     133        bestROC = new List<KeyValuePair<double, double>>();
     134        foreach (double minTreshold in classes[positiveClassKey].Distinct<double>()) {
     135          CalculateROCValuesAndAUC(classes[positiveClassKey], actNegatives, negatives.Count, minTreshold, ref rocCharacteristics, out  actROC, out actAUC);
     136          if (actAUC > bestAUC) {
     137            bestAUC = actAUC;
     138            bestROC = actROC;
     139          }
     140        }
     141          myAucValues.Add(new DoubleData(bestAUC));
     142          myRocValues.Add(Convert(bestROC));
     143       
     144      } else { //last class
     145        actNegatives = negatives.Where<double>(value => value > classes[positiveClassKey].Min<double>()).ToList<double>();
     146        actNegatives.Add(classes[positiveClassKey].Min<double>());
     147        actNegatives.Sort();
     148        CalculateROCValuesAndAUCForLastClass(classes[positiveClassKey], actNegatives, negatives.Count, out bestROC, out bestAUC);
     149        myAucValues.Add(new DoubleData(bestAUC));
     150        myRocValues.Add(Convert(bestROC));
     151
     152      }
     153
     154    }
     155
     156    protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minTreshold,
     157      ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) {
     158      double actTP = -1;
     159      double actFP = -1;
     160      double oldTP = -1;
     161      double oldFP = -1;
     162      auc = 0;
     163      roc = new List<KeyValuePair<double, double>>();
     164
     165      actTP = positives.Count<double>(value => minTreshold <= value && value <= negatives.Max<double>());
     166      actFP = negatives.Count<double>(value => minTreshold <= value && value <= negatives.Max<double>());
     167      //add point (1,TPR) for AUC 'correct' calculation
     168      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
     169      oldTP = actTP;
     170      oldFP = negativesCount;
     171      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
     172
     173      if (rocCharacteristics == null) {
     174        rocCharacteristics = new List<KeyValuePair<double, double>>();
     175        foreach (double maxTreshold in negatives.Distinct<double>()) {
     176          auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     177          oldTP = actTP;
     178          oldFP = actFP;
     179          actTP = positives.Count<double>(value => minTreshold <= value && value < maxTreshold);
     180          actFP = negatives.Count<double>(value => minTreshold <= value && value < maxTreshold);
     181          rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP));
     182          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
     183
     184          //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
     185          if ((actTP / positives.Count == 0) || (actFP / negatives.Count == 0))
     186            break;
     187        }
     188        auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     189      } else { //characteristics of ROCs calculated
     190        foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) {
     191          auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     192          oldTP = actTP;
     193          oldFP = actFP;
     194          actTP = oldTP - rocCharac.Key;
     195          actFP = oldFP - rocCharac.Value;
     196          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
     197          if (actTP / positives.Count == 0)
     198            break;
     199        }
     200        auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     201      }
     202    }
     203
     204    protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount,
     205      out List<KeyValuePair<double, double>> roc, out double auc) {
     206      double actTP = -1;
     207      double actFP = -1;
     208      double oldTP = -1;
     209      double oldFP = -1;
     210      auc = 0;
     211      roc = new List<KeyValuePair<double, double>>();
     212
     213      actTP = positives.Count<double>(value => value >= negatives.Min<double>());
     214      actFP = negatives.Count<double>(value => value >= negatives.Min<double>());
     215      //add point (1,TPR) for AUC 'correct' calculation
     216      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
     217      oldTP = actTP;
     218      oldFP = negativesCount;
     219      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
     220
     221      foreach (double minTreshold in negatives.Distinct<double>()) {
     222        auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     223        oldTP = actTP;
     224        oldFP = actFP;
     225        actTP = positives.Count<double>(value => minTreshold <= value);
     226        actFP = negatives.Count<double>(value => minTreshold <= value);
     227        roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
     228
     229        //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
     230        if ((actTP / positives.Count == 0) || (actFP / negatives.Count == 0))
    101231          break;
    102232      }
    103 
    104       //add case when treshold == max negative class value => falsePositiveRate ==1
    105       if (truePositiveRate != 1.0) {
    106         truePositiveRate = ((double)classes[positiveClassKey].Count<double>(value => value <= classes[negativeClassKey][classes[negativeClassKey].Count - 1])) / classes[positiveClassKey].Count;
    107         falsePositiveRate = 1;
    108         point = new DoubleArrayData(new double[2] { falsePositiveRate, truePositiveRate });
    109         rocValues[0].Add(point);
    110       } else {
    111         //ROC ends at 1,1
    112         point = new DoubleArrayData(new double[2] { 1, 1 });
    113         rocValues[0].Add(point);
    114       }
    115 
    116       return null;
    117     }
     233      auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2;
     234
     235    }
     236
     237    private ItemList Convert(List<KeyValuePair<double, double>> data) {
     238      ItemList list = new ItemList();
     239      ItemList row;
     240      foreach (KeyValuePair<double, double> dataPoint in data) {
     241        row = new ItemList();
     242        row.Add(new DoubleData(dataPoint.Key));
     243        row.Add(new DoubleData(dataPoint.Value));
     244        list.Add(row);
     245      }
     246      return list;
     247    }
     248
    118249  }
     250
    119251}
Note: See TracChangeset for help on using the changeset viewer.