Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ROCAnalyzer.cs @ 681

Last change on this file since 681 was 678, checked in by mkommend, 16 years ago

support for multiclass classification added; AUC calculated (ticket #308)

File size: 10.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.DataAnalysis;
29
30
namespace HeuristicLab.GP.StructureIdentification.Classification {
  /// <summary>
  /// Computes one ROC curve (FPR/TPR pairs) and its area under the curve (AUC) per class of a
  /// classification problem, treating each class in turn as the positive class and the samples of
  /// all other classes as negatives (one-vs-rest).
  /// </summary>
  public class ROCAnalyzer : OperatorBase {
    // Output lists of the current Apply call; appended to by CalculateBestROC.
    private ItemList myRocValues;
    private ItemList<DoubleData> myAucValues;

    /// <summary>Short description of the operator.</summary>
    public override string Description {
      get { return @"Calculate TPR & FPR for various tresholds on dataset"; }
    }

    /// <summary>
    /// Registers the input variable "Values" and the output variables "ROCValues" and "AUCValues".
    /// </summary>
    public ROCAnalyzer()
      : base() {
      AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and orignial values for the ROCAnalyzer", typeof(ItemList), VariableKind.In));
      AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out));
      AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out));
    }

    /// <summary>
    /// Groups the estimated values by their original class value and writes one ROC curve into
    /// "ROCValues" and one AUC value into "AUCValues" per class, in ascending order of the class value.
    /// Existing output lists are cleared; missing ones are created (as operator-local variables or in
    /// <paramref name="scope"/>, depending on the variable info).
    /// </summary>
    /// <param name="scope">Scope from which "Values" is read and into which the outputs are written.
    /// Each entry of "Values" is an <see cref="ItemList"/> whose element 0 is the estimated value and
    /// element 1 the original value, both as <see cref="DoubleData"/>.</param>
    /// <returns>Always <c>null</c>.</returns>
    public override IOperation Apply(IScope scope) {
      #region initialize HL-variables
      ItemList values = GetVariableValue<ItemList>("Values", scope, true);
      myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false);
      if (myRocValues == null) {
        myRocValues = new ItemList();
        IVariableInfo info = GetVariableInfo("ROCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues));
      } else {
        myRocValues.Clear();
      }

      myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false);
      if (myAucValues == null) {
        myAucValues = new ItemList<DoubleData>();
        IVariableInfo info = GetVariableInfo("AUCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues));
      } else {
        myAucValues.Clear();
      }
      #endregion

      // Group the estimated values by original class value. The SortedDictionary keeps the classes in
      // ascending order, which CalculateBestROC relies on to tell the first/middle/last class apart.
      SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>();
      foreach (ItemList value in values) {
        double estimated = ((DoubleData)value[0]).Data;
        double original = ((DoubleData)value[1]).Data;
        List<double> estimatesOfClass;
        if (!classes.TryGetValue(original, out estimatesOfClass)) {
          estimatesOfClass = new List<double>();
          classes[original] = estimatesOfClass;
        }
        estimatesOfClass.Add(estimated);
      }
      foreach (List<double> classEstimates in classes.Values)
        classEstimates.Sort();

      // one ROC curve / AUC value per class
      foreach (double key in classes.Keys) {
        CalculateBestROC(key, classes);
      }

      return null;
    }

    /// <summary>
    /// Computes the ROC curve and AUC of <paramref name="positiveClassKey"/> against all other classes
    /// and appends them to the "ROCValues" / "AUCValues" output lists.
    /// <list type="bullet">
    /// <item>First (smallest) class: a single sweep of upper thresholds with no lower bound.</item>
    /// <item>Middle classes: every distinct positive estimate is tried as lower bound; the curve with
    /// the largest AUC is kept.</item>
    /// <item>Last (largest) class: a sweep of lower thresholds.</item>
    /// </list>
    /// </summary>
    /// <param name="positiveClassKey">Original class value treated as positive.</param>
    /// <param name="classes">Estimated values per class, each list sorted ascending.</param>
    /// <remarks>
    /// NOTE(review): if <paramref name="classes"/> contains only one class there are no negatives, so
    /// the FPR denominator is 0 and the resulting rates/AUC are not finite.
    /// </remarks>
    protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) {
      List<double> positives = classes[positiveClassKey];

      // all estimates of every other class are negatives
      List<double> negatives = new List<double>();
      foreach (KeyValuePair<double, List<double>> pair in classes) {
        if (pair.Key != positiveClassKey)
          negatives.AddRange(pair.Value);
      }

      // Negatives below the largest positive estimate plus that estimate itself as a sentinel, so the
      // threshold sweep starts at the top of the positive range; sorted descending.
      // Max() is hoisted: evaluating it inside the lambda made this O(|negatives| * |positives|).
      double maxPositive = positives.Max();
      List<double> actNegatives = negatives.Where(value => value < maxPositive).ToList();
      actNegatives.Add(maxPositive);
      actNegatives.Sort();
      actNegatives.Reverse();

      List<KeyValuePair<double, double>> rocCharacteristics = null;
      List<KeyValuePair<double, double>> actROC;
      double actAUC;

      if (classes.Keys.First() == positiveClassKey) {
        // first class: no lower bound
        CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, double.MinValue, ref rocCharacteristics, out actROC, out actAUC);
        myAucValues.Add(new DoubleData(actAUC));
        myRocValues.Add(Convert(actROC));
      } else if (classes.Keys.Last() != positiveClassKey) {
        // Middle class: try each distinct positive estimate (ascending) as lower bound. The first call
        // fills rocCharacteristics, later calls reuse it.
        double bestAUC = double.MinValue;
        List<KeyValuePair<double, double>> bestROC = new List<KeyValuePair<double, double>>();
        foreach (double minTreshold in positives.Distinct()) {
          CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, minTreshold, ref rocCharacteristics, out actROC, out actAUC);
          if (actAUC > bestAUC) {
            bestAUC = actAUC;
            bestROC = actROC;
          }
        }
        myAucValues.Add(new DoubleData(bestAUC));
        myRocValues.Add(Convert(bestROC));
      } else {
        // Last class: negatives above the smallest positive estimate plus that estimate as sentinel,
        // sorted ascending. Min() is hoisted for the same reason as Max() above.
        double minPositive = positives.Min();
        List<double> lastClassNegatives = negatives.Where(value => value > minPositive).ToList();
        lastClassNegatives.Add(minPositive);
        lastClassNegatives.Sort();
        CalculateROCValuesAndAUCForLastClass(positives, lastClassNegatives, negatives.Count, out actROC, out actAUC);
        myAucValues.Add(new DoubleData(actAUC));
        myRocValues.Add(Convert(actROC));
      }
    }

    /// <summary>
    /// Sweeps upper thresholds over the distinct values of <paramref name="negatives"/> (in list order)
    /// and builds the ROC curve of the band classifier "positive iff minTreshold &lt;= value &lt; threshold".
    /// </summary>
    /// <param name="positives">Estimated values of the positive class (must not be empty).</param>
    /// <param name="negatives">Negative estimates used as upper thresholds (must not be empty);
    /// CalculateBestROC passes them sorted descending, including a sentinel equal to the largest
    /// positive estimate.</param>
    /// <param name="negativesCount">Total number of negatives; denominator of the FPR.</param>
    /// <param name="minTreshold">Inclusive lower bound of the band.</param>
    /// <param name="rocCharacteristics">If <c>null</c>, it is created and filled with the per-step
    /// (ΔTP, ΔFP) counts of this sweep. Otherwise those deltas are replayed instead of recounting.</param>
    /// <param name="roc">Receives the curve as (FPR, TPR) pairs, starting with the point (1, TPR).</param>
    /// <param name="auc">Receives the trapezoidal area under <paramref name="roc"/>.</param>
    protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minTreshold,
      ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // loop-invariant; previously re-evaluated for every element inside the lambdas
      double upperBound = negatives.Max();
      double actTP = positives.Count(value => minTreshold <= value && value <= upperBound);
      // NOTE(review): the sentinel appended by CalculateBestROC also satisfies this predicate and is
      // counted as a false positive here, so this initial FPR can exceed 1 — confirm whether intended.
      double actFP = negatives.Count(value => minTreshold <= value && value <= upperBound);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      if (rocCharacteristics == null) {
        rocCharacteristics = new List<KeyValuePair<double, double>>();
        foreach (double maxTreshold in negatives.Distinct()) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = positives.Count(value => minTreshold <= value && value < maxTreshold);
          actFP = negatives.Count(value => minTreshold <= value && value < maxTreshold);
          rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP));
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

          //stop calculation if TP or FP reached 0 => rest of the curve adds no area & save runtime
          if (actTP == 0 || actFP == 0)
            break;
        }
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
      } else { //characteristics of ROCs calculated
        // The deltas were recorded for the smallest lower bound. For a larger minTreshold, replaying them
        // drives TP/FP below zero once the upper threshold falls to or below minTreshold, where the true
        // counts are 0. Clamp at 0 and stop: this reproduces a full recount exactly (previously the sweep
        // went on with negative counts, emitting negative rates and spurious AUC contributions).
        foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = Math.Max(0.0, oldTP - rocCharac.Key);
          actFP = Math.Max(0.0, oldFP - rocCharac.Value);
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
          if (actTP == 0)
            break;
        }
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
      }
    }

    /// <summary>
    /// Sweeps lower thresholds over the distinct values of <paramref name="negatives"/> (in list order)
    /// and builds the ROC curve of the classifier "positive iff threshold &lt;= value".
    /// </summary>
    /// <param name="positives">Estimated values of the positive (last) class.</param>
    /// <param name="negatives">Negative estimates used as lower thresholds (must not be empty);
    /// CalculateBestROC passes them sorted ascending, including a sentinel equal to the smallest
    /// positive estimate.</param>
    /// <param name="negativesCount">Total number of negatives; denominator of the FPR.</param>
    /// <param name="roc">Receives the curve as (FPR, TPR) pairs, starting with the point (1, TPR).</param>
    /// <param name="auc">Receives the trapezoidal area under <paramref name="roc"/>.</param>
    protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount,
      out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // loop-invariant; previously re-evaluated for every element inside the lambdas
      double lowerBound = negatives.Min();
      double actTP = positives.Count(value => value >= lowerBound);
      double actFP = negatives.Count(value => value >= lowerBound);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      foreach (double minTreshold in negatives.Distinct()) {
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
        oldTP = actTP;
        oldFP = actFP;
        actTP = positives.Count(value => minTreshold <= value);
        actFP = negatives.Count(value => minTreshold <= value);
        roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

        //stop calculation if TP or FP reached 0 => rest of the curve adds no area & save runtime
        if (actTP == 0 || actFP == 0)
          break;
      }
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    // Area of the trapezoid between the ROC points (oldFP, oldTP) and (actFP, actTP), given as raw counts
    // and normalized by the class sizes. Same operation order as the former inline formula.
    private static double TrapezoidArea(double oldTP, double actTP, double oldFP, double actFP, int positivesCount, int negativesCount) {
      return ((oldTP + actTP) / positivesCount) * ((oldFP - actFP) / negativesCount) / 2;
    }

    // Converts (FPR, TPR) pairs into an ItemList of two-element ItemLists of DoubleData.
    private ItemList Convert(List<KeyValuePair<double, double>> data) {
      ItemList list = new ItemList();
      foreach (KeyValuePair<double, double> dataPoint in data) {
        ItemList row = new ItemList();
        row.Add(new DoubleData(dataPoint.Key));
        row.Add(new DoubleData(dataPoint.Value));
        list.Add(row);
      }
      return list;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.