Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/HeuristicLab.GP.StructureIdentification.Classification/3.3/ROCAnalyzer.cs @ 2491

Last change on this file since 2491 was 2222, checked in by gkronber, 15 years ago

Merged changes from GP-refactoring branch back into the trunk #713.

File size: 10.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26
27
namespace HeuristicLab.GP.StructureIdentification.Classification {
  /// <summary>
  /// Operator that computes, for every class occurring in the "Values" list, a ROC curve
  /// (points of false positive rate and true positive rate) and the area under that curve,
  /// treating the respective class as positive and all other classes as negatives.
  /// </summary>
  public class ROCAnalyzer : OperatorBase {
    // Output lists of the current Apply call. They are instance fields so that the helper
    // methods can append to them; concurrent Apply calls on one instance would share them.
    private ItemList myRocValues;
    private ItemList<DoubleData> myAucValues;

    public override string Description {
      get { return @"Calculate TPR & FPR for various thresholds on dataset"; }
    }

    /// <summary>
    /// Registers the input variable "Values" and the output variables "ROCValues" and "AUCValues".
    /// </summary>
    public ROCAnalyzer()
      : base() {
      AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and original values for the ROCAnalyzer", typeof(ItemList), VariableKind.In));
      AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out));
      AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out));
    }

    /// <summary>
    /// Groups the estimated values by their original class and stores one ROC curve and one
    /// AUC value per class (in ascending order of the class values) in "ROCValues" and "AUCValues".
    /// </summary>
    /// <param name="scope">Scope in which the variables are looked up.</param>
    /// <returns>Always <c>null</c>.</returns>
    /// <remarks>
    /// Each entry of "Values" is cast to an <see cref="ItemList"/> whose element 0 is the estimated
    /// value and element 1 the original (class) value, both <see cref="DoubleData"/>.
    /// Existing output lists are cleared and refilled; missing ones are created either as a local
    /// operator variable or in <paramref name="scope"/>, depending on the variable info.
    /// NOTE(review): with only one class there are no negatives, so the FPR/AUC computations
    /// divide by zero and yield non-finite values.
    /// </remarks>
    public override IOperation Apply(IScope scope) {
      #region initialize HL-variables
      ItemList values = GetVariableValue<ItemList>("Values", scope, true);
      myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false);
      if (myRocValues == null) {
        myRocValues = new ItemList();
        IVariableInfo info = GetVariableInfo("ROCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues));
      } else {
        myRocValues.Clear();
      }

      myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false);
      if (myAucValues == null) {
        myAucValues = new ItemList<DoubleData>();
        IVariableInfo info = GetVariableInfo("AUCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues));
      } else {
        myAucValues.Clear();
      }
      #endregion

      // group the estimated values by original class; SortedDictionary keeps the classes ascending
      SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>();
      foreach (ItemList value in values) {
        double estimated = ((DoubleData)value[0]).Data;
        double original = ((DoubleData)value[1]).Data;
        List<double> estimatedValues;
        if (!classes.TryGetValue(original, out estimatedValues)) {
          estimatedValues = new List<double>();
          classes[original] = estimatedValues;
        }
        estimatedValues.Add(estimated);
      }
      // the threshold logic in CalculateBestROC relies on ascending estimates per class
      foreach (List<double> estimatedValues in classes.Values)
        estimatedValues.Sort();

      // one ROC curve per class, each class acting as the positive class once
      foreach (double key in classes.Keys) {
        CalculateBestROC(key, classes);
      }

      return null;
    }

    /// <summary>
    /// Computes the ROC curve and AUC of one class against all other classes and appends them
    /// to the output lists.
    /// </summary>
    /// <param name="positiveClassKey">Original value of the class treated as positive; must be a key of <paramref name="classes"/>.</param>
    /// <param name="classes">Estimated values grouped by class, as built by <see cref="Apply"/> (non-empty lists sorted ascending).</param>
    /// <remarks>
    /// The classification rule depends on the position of the class among the sorted class values:
    /// the first class is positive below a threshold, a middle class is positive inside a window
    /// [minThreshold, maxThreshold) where the lower bound with the highest AUC is kept, and the
    /// last class is positive above a threshold.
    /// </remarks>
    protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) {
      List<double> positives = classes[positiveClassKey];

      // the estimates of all other classes are the negatives
      List<double> negatives = new List<double>();
      foreach (KeyValuePair<double, List<double>> entry in classes) {
        if (entry.Key != positiveClassKey)
          negatives.AddRange(entry.Value);
      }

      List<KeyValuePair<double, double>> bestROC;
      double bestAUC;
      if (positiveClassKey == classes.Keys.First()) {
        // first class: no lower bound on the positive window
        List<KeyValuePair<double, double>> rocCharacteristics = null;
        CalculateROCValuesAndAUC(positives, GetDescendingThresholds(negatives, positives.Max()), negatives.Count, double.MinValue,
          ref rocCharacteristics, out bestROC, out bestAUC);
      } else if (positiveClassKey != classes.Keys.Last()) {
        // middle classes: try every distinct positive estimate as lower bound, keep the best AUC
        List<double> thresholds = GetDescendingThresholds(negatives, positives.Max());
        // recorded by the first call, replayed by the following ones
        List<KeyValuePair<double, double>> rocCharacteristics = null;
        bestROC = new List<KeyValuePair<double, double>>();
        bestAUC = double.MinValue;
        // positives is sorted ascending, so the first call (which records the characteristics)
        // uses the smallest lower bound
        foreach (double minThreshold in positives.Distinct()) {
          List<KeyValuePair<double, double>> actROC;
          double actAUC;
          CalculateROCValuesAndAUC(positives, thresholds, negatives.Count, minThreshold, ref rocCharacteristics, out actROC, out actAUC);
          if (actAUC > bestAUC) {
            bestAUC = actAUC;
            bestROC = actROC;
          }
        }
      } else {
        // last class: thresholds are the negatives above the smallest positive estimate plus
        // that estimate itself, in ascending order
        double minPositive = positives.Min();
        List<double> thresholds = negatives.Where(value => value > minPositive).ToList();
        thresholds.Add(minPositive);
        thresholds.Sort();
        CalculateROCValuesAndAUCForLastClass(positives, thresholds, negatives.Count, out bestROC, out bestAUC);
      }

      myAucValues.Add(new DoubleData(bestAUC));
      myRocValues.Add(Convert(bestROC));
    }

    // Returns the negatives below maxPositive plus maxPositive itself, sorted in descending order.
    private static List<double> GetDescendingThresholds(List<double> negatives, double maxPositive) {
      List<double> thresholds = negatives.Where(value => value < maxPositive).ToList();
      thresholds.Add(maxPositive);
      thresholds.Sort();
      thresholds.Reverse();
      return thresholds;
    }

    /// <summary>
    /// Computes the ROC curve and its AUC (trapezoidal rule) for a classifier that labels an
    /// estimate positive if it lies in [<paramref name="minThreshold"/>, maxThreshold), where
    /// maxThreshold runs over the distinct values of <paramref name="negatives"/>.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Threshold candidates in descending order (must be non-empty); they are also counted as false positives.</param>
    /// <param name="negativesCount">Total number of negatives, used as the FPR denominator.</param>
    /// <param name="minThreshold">Inclusive lower bound of the positive window.</param>
    /// <param name="rocCharacteristics">If <c>null</c>, it is filled with the (TP, FP) decrements between
    /// consecutive thresholds; otherwise these recorded decrements are replayed instead of recounting.</param>
    /// <param name="roc">Resulting (FPR, TPR) points, starting with (1, TPR).</param>
    /// <param name="auc">Resulting area under the ROC curve.</param>
    protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minThreshold,
      ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) {
      // loop invariant: computed once instead of once per positive inside the predicate
      double maxNegative = negatives.Max();
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      double actTP = positives.Count(value => minThreshold <= value && value <= maxNegative);
      double actFP = negatives.Count(value => minThreshold <= value);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      if (rocCharacteristics == null) {
        rocCharacteristics = new List<KeyValuePair<double, double>>();
        foreach (double maxThreshold in negatives.Distinct()) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = positives.Count(value => minThreshold <= value && value < maxThreshold);
          actFP = negatives.Count(value => minThreshold <= value && value < maxThreshold);
          rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP));
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

          //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
          if ((actTP == 0) || (actFP == 0))
            break;
        }
      } else { //characteristics of ROCs calculated
        // Replaying decrements recorded with a smaller lower bound matches a fresh count while
        // maxThreshold >= minThreshold; the TP/FP == 0 check ends the replay before that changes.
        foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = oldTP - rocCharac.Key;
          actFP = oldFP - rocCharac.Value;
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
          if ((actTP == 0) || (actFP == 0))
            break;
        }
      }
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    /// <summary>
    /// Computes the ROC curve and its AUC (trapezoidal rule) for the class with the largest class
    /// value: an estimate is labelled positive if it is greater than the threshold, which runs over
    /// the distinct values of <paramref name="negatives"/>.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Threshold candidates in ascending order (must be non-empty); they are also counted as false positives.</param>
    /// <param name="negativesCount">Total number of negatives, used as the FPR denominator.</param>
    /// <param name="roc">Resulting (FPR, TPR) points, starting with (1, TPR).</param>
    /// <param name="auc">Resulting area under the ROC curve.</param>
    protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount,
      out List<KeyValuePair<double, double>> roc, out double auc) {
      // loop invariant: computed once instead of once per element inside the predicates
      double minNegative = negatives.Min();
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      double actTP = positives.Count(value => value >= minNegative);
      double actFP = negatives.Count(value => value >= minNegative);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      foreach (double minThreshold in negatives.Distinct()) {
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
        oldTP = actTP;
        oldFP = actFP;
        actTP = positives.Count(value => minThreshold < value);
        actFP = negatives.Count(value => minThreshold < value);
        roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

        //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
        if (actTP == 0 || actFP == 0)
          break;
      }
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    // Area of the trapezoid between two consecutive ROC points given as raw (TP, FP) counts.
    private static double TrapezoidArea(double oldTP, double actTP, double oldFP, double actFP, int positivesCount, int negativesCount) {
      return ((oldTP + actTP) / positivesCount) * ((oldFP - actFP) / negativesCount) / 2;
    }

    /// <summary>
    /// Converts a list of (x, y) points into an <see cref="ItemList"/> of two-element
    /// <see cref="ItemList"/> rows holding the coordinates as <see cref="DoubleData"/>.
    /// </summary>
    private ItemList Convert(List<KeyValuePair<double, double>> data) {
      ItemList list = new ItemList();
      foreach (KeyValuePair<double, double> dataPoint in data) {
        ItemList row = new ItemList();
        row.Add(new DoubleData(dataPoint.Key));
        row.Add(new DoubleData(dataPoint.Value));
        list.Add(row);
      }
      return list;
    }

  }

}
Note: See TracBrowser for help on using the repository browser.