Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ROCAnalyzer.cs @ 1325

Last change on this file since 1325 was 696, checked in by mkommend, 16 years ago

source improved through codereview with GK (ticket #308)

File size: 10.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
29
30
namespace HeuristicLab.GP.StructureIdentification.Classification {
  /// <summary>
  /// Operator that builds one ROC curve (pairs of false-positive rate and true-positive rate) and
  /// the matching area under the curve (AUC) for every class found in the "Values" list.
  /// </summary>
  /// <remarks>
  /// The classes are the distinct original values, kept in ascending order. The first class and
  /// the last class are each evaluated with a one-sided threshold sweep. For every class in
  /// between, each distinct estimated value of that class is tried as lower threshold and the
  /// curve with the largest AUC is kept.
  /// </remarks>
  public class ROCAnalyzer : OperatorBase {
    // Output lists of the current Apply call; CalculateBestROC appends one entry per class.
    private ItemList myRocValues;
    private ItemList<DoubleData> myAucValues;

    /// <summary>Short description of the operator.</summary>
    public override string Description {
      get { return @"Calculate TPR & FPR for various thresholds on dataset"; }
    }

    /// <summary>
    /// Registers the input variable "Values" and the output variables "ROCValues" and "AUCValues".
    /// </summary>
    public ROCAnalyzer()
      : base() {
      AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and original values for the ROCAnalyzer", typeof(ItemList), VariableKind.In));
      AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out));
      AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out));
    }

    /// <summary>
    /// Reads the "Values" list, groups the estimated values by their original value and stores one
    /// ROC curve and one AUC value per class in "ROCValues" and "AUCValues".
    /// </summary>
    /// <param name="scope">Scope the variables are read from and, unless declared local, written to.</param>
    /// <returns>Always <c>null</c>.</returns>
    public override IOperation Apply(IScope scope) {
      #region initialize HL-variables
      ItemList values = GetVariableValue<ItemList>("Values", scope, true);

      // Reuse an existing result list (emptied first) or create and register a new one.
      myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false);
      if (myRocValues == null) {
        myRocValues = new ItemList();
        IVariableInfo info = GetVariableInfo("ROCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues));
      } else {
        myRocValues.Clear();
      }

      myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false);
      if (myAucValues == null) {
        myAucValues = new ItemList<DoubleData>();
        IVariableInfo info = GetVariableInfo("AUCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues));
      } else {
        myAucValues.Clear();
      }
      #endregion

      // Group the estimated values by class. Every row holds the estimated value at index 0
      // and the original value (the class key) at index 1.
      SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>();
      foreach (ItemList value in values) {
        double estimated = ((DoubleData)value[0]).Data;
        double original = ((DoubleData)value[1]).Data;
        List<double> classEstimations;
        if (!classes.TryGetValue(original, out classEstimations)) {
          classEstimations = new List<double>();
          classes.Add(original, classEstimations);
        }
        classEstimations.Add(estimated);
      }
      foreach (List<double> estimations in classes.Values)
        estimations.Sort();

      //calculate ROC Curve
      foreach (double key in classes.Keys) {
        CalculateBestROC(key, classes);
      }

      return null;
    }

    /// <summary>
    /// Calculates the ROC curve and AUC for one class, treating the estimated values of that class
    /// as positives and those of all other classes as negatives, and appends both results to the
    /// output lists.
    /// </summary>
    /// <remarks>
    /// If <paramref name="classes"/> holds only a single class there are no negatives, so the
    /// rates are divided by a negatives count of zero and come out non-finite.
    /// NOTE(review): confirm whether that case can occur and what callers expect.
    /// </remarks>
    /// <param name="positiveClassKey">Key of the class regarded as positive; must be a key of <paramref name="classes"/>.</param>
    /// <param name="classes">Estimated values grouped by class key.</param>
    protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) {
      List<double> positives = classes[positiveClassKey];

      List<double> negatives = new List<double>();
      foreach (KeyValuePair<double, List<double>> entry in classes) {
        if (entry.Key != positiveClassKey)
          negatives.AddRange(entry.Value);
      }

      double firstKey = classes.Keys.First<double>();
      double lastKey = classes.Keys.Last<double>();

      List<KeyValuePair<double, double>> bestROC;
      double bestAUC;

      if (positiveClassKey != firstKey && positiveClassKey == lastKey) {
        //last class: thresholds are the negatives above the smallest positive, ascending
        double minPositive = positives.Min<double>();
        List<double> actNegatives = negatives.Where<double>(value => value > minPositive).ToList<double>();
        actNegatives.Add(minPositive);
        actNegatives.Sort();
        CalculateROCValuesAndAUCForLastClass(positives, actNegatives, negatives.Count, out bestROC, out bestAUC);
      } else {
        // first and middle classes: thresholds are the negatives below the largest positive, descending
        double maxPositive = positives.Max<double>();
        List<double> actNegatives = negatives.Where<double>(value => value < maxPositive).ToList<double>();
        actNegatives.Add(maxPositive);
        actNegatives.Sort();
        actNegatives.Reverse();

        // Filled by the first calculation and replayed by all following ones.
        List<KeyValuePair<double, double>> rocCharacteristics = null;

        if (positiveClassKey == firstKey) {
          //first class
          CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, double.MinValue, ref rocCharacteristics, out bestROC, out bestAUC);
        } else {
          //middle classes: try every distinct positive as lower threshold and keep the best curve
          bestAUC = double.MinValue;
          bestROC = new List<KeyValuePair<double, double>>();
          foreach (double minThreshold in positives.Distinct<double>()) {
            List<KeyValuePair<double, double>> actROC;
            double actAUC;
            CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, minThreshold, ref rocCharacteristics, out actROC, out actAUC);
            if (actAUC > bestAUC) {
              bestAUC = actAUC;
              bestROC = actROC;
            }
          }
        }
      }

      myAucValues.Add(new DoubleData(bestAUC));
      myRocValues.Add(Convert(bestROC));
    }

    /// <summary>
    /// Calculates ROC points and AUC by lowering an exclusive upper threshold through the distinct
    /// values of <paramref name="negatives"/> while <paramref name="minThreshold"/> stays fixed as
    /// inclusive lower threshold.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Upper thresholds in the order they are applied; the in-class caller
    /// passes them sorted descending. Must not be empty.</param>
    /// <param name="negativesCount">Divisor used to turn false-positive counts into rates.</param>
    /// <param name="minThreshold">Values below this threshold are never counted.</param>
    /// <param name="rocCharacteristics">If <c>null</c>, it is created and receives for every step the
    /// number of true positives (Key) and false positives (Value) removed by that step. Otherwise
    /// these stored differences are replayed instead of recounting the lists.</param>
    /// <param name="roc">Resulting points with the false-positive rate as Key and the true-positive rate as Value.</param>
    /// <param name="auc">Area under the resulting curve, summed up with the trapezoid rule.</param>
    /// <exception cref="InvalidOperationException"><paramref name="negatives"/> is empty.</exception>
    protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minThreshold,
      ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // Determined once; it does not change while the positives are counted.
      double maxNegative = negatives.Max<double>();
      double actTP = positives.Count<double>(value => minThreshold <= value && value <= maxNegative);
      double actFP = negatives.Count<double>(value => minThreshold <= value);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      if (rocCharacteristics == null) {
        rocCharacteristics = new List<KeyValuePair<double, double>>();
        foreach (double maxThreshold in negatives.Distinct<double>()) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = positives.Count<double>(value => minThreshold <= value && value < maxThreshold);
          actFP = negatives.Count<double>(value => minThreshold <= value && value < maxThreshold);
          rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP));
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

          //stop as soon as no true positives or no false positives are left => save runtime
          if ((actTP == 0) || (actFP == 0))
            break;
        }
      } else { //characteristics of ROCs calculated
        foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = oldTP - rocCharac.Key;
          actFP = oldFP - rocCharac.Value;
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
          if ((actTP == 0) || (actFP == 0))
            break;
        }
      }
      // area between the last two points
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    /// <summary>
    /// Calculates ROC points and AUC by raising an exclusive lower threshold through the distinct
    /// values of <paramref name="negatives"/>; there is no upper threshold.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Lower thresholds in the order they are applied; the in-class caller
    /// passes them sorted ascending. Must not be empty.</param>
    /// <param name="negativesCount">Divisor used to turn false-positive counts into rates.</param>
    /// <param name="roc">Resulting points with the false-positive rate as Key and the true-positive rate as Value.</param>
    /// <param name="auc">Area under the resulting curve, summed up with the trapezoid rule.</param>
    /// <exception cref="InvalidOperationException"><paramref name="negatives"/> is empty.</exception>
    protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount,
      out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // Determined once; it does not change while the lists are counted.
      double minNegative = negatives.Min<double>();
      double actTP = positives.Count<double>(value => value >= minNegative);
      double actFP = negatives.Count<double>(value => value >= minNegative);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      foreach (double minThreshold in negatives.Distinct<double>()) {
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
        oldTP = actTP;
        oldFP = actFP;
        actTP = positives.Count<double>(value => minThreshold < value);
        actFP = negatives.Count<double>(value => minThreshold < value);
        roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

        //stop as soon as no true positives or no false positives are left => save runtime
        if (actTP == 0 || actFP == 0)
          break;
      }
      // area between the last two points
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    /// <summary>
    /// Area of the trapezoid between two neighbouring ROC points, given as counts that are
    /// converted to rates by dividing through the respective totals.
    /// </summary>
    private static double TrapezoidArea(double oldTP, double actTP, double oldFP, double actFP, int positivesCount, int negativesCount) {
      return ((oldTP + actTP) / positivesCount) * ((oldFP - actFP) / negativesCount) / 2;
    }

    /// <summary>
    /// Converts a list of points into an item list of rows, each row holding the Key followed by
    /// the Value of one point.
    /// </summary>
    private ItemList Convert(List<KeyValuePair<double, double>> data) {
      ItemList list = new ItemList();
      foreach (KeyValuePair<double, double> dataPoint in data) {
        ItemList row = new ItemList();
        row.Add(new DoubleData(dataPoint.Key));
        row.Add(new DoubleData(dataPoint.Value));
        list.Add(row);
      }
      return list;
    }

  }

}
Note: See TracBrowser for help on using the repository browser.