#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.DataAnalysis;

namespace HeuristicLab.GP.StructureIdentification.Classification {
  /// <summary>
  /// Operator that groups (estimated, original) value pairs by their original class value and,
  /// treating each class in turn as the positive class, computes one ROC curve (list of FPR/TPR points)
  /// and one AUC value per class.
  /// </summary>
  public class ROCAnalyzer : OperatorBase {
    // Output lists of the current Apply call; CalculateBestROC appends one entry per class to each.
    private ItemList myRocValues;
    private ItemList<DoubleData> myAucValues;

    public override string Description {
      get { return @"Calculate TPR & FPR for various thresholds on dataset"; }
    }

    /// <summary>
    /// Registers the input variable "Values" and the output variables "ROCValues" and "AUCValues".
    /// </summary>
    public ROCAnalyzer()
      : base() {
      AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and original values for the ROCAnalyzer", typeof(ItemList), VariableKind.In));
      AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out));
      AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out));
    }

    /// <summary>
    /// Reads "Values", (re)initializes the "ROCValues" and "AUCValues" output lists and fills them
    /// with one ROC curve and one AUC value per distinct original class value, in ascending class order.
    /// </summary>
    /// <param name="scope">Scope used to look up the input variable and to store the outputs.</param>
    /// <returns>Always <c>null</c>.</returns>
    public override IOperation Apply(IScope scope) {
      #region initialize HL-variables
      ItemList values = GetVariableValue<ItemList>("Values", scope, true);

      // Reuse an existing output list (emptied first) or create and register a new one.
      myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false);
      if (myRocValues == null) {
        myRocValues = new ItemList();
        IVariableInfo info = GetVariableInfo("ROCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues));
      } else {
        myRocValues.Clear();
      }

      myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false);
      if (myAucValues == null) {
        myAucValues = new ItemList<DoubleData>();
        IVariableInfo info = GetVariableInfo("AUCValues");
        if (info.Local)
          AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues));
        else
          scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues));
      } else {
        myAucValues.Clear();
      }
      #endregion

      // Group the estimated values by their original (class) value; each row is read as
      // [0] = estimated value, [1] = original value.
      SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>();
      foreach (ItemList value in values) {
        double estimated = ((DoubleData)value[0]).Data;
        double original = ((DoubleData)value[1]).Data;

        List<double> estimatesOfClass;
        if (!classes.TryGetValue(original, out estimatesOfClass)) {
          estimatesOfClass = new List<double>();
          classes.Add(original, estimatesOfClass);
        }
        estimatesOfClass.Add(estimated);
      }

      foreach (List<double> estimatesOfClass in classes.Values)
        estimatesOfClass.Sort();

      // Calculate one ROC curve per class, each class acting once as the positive class.
      foreach (double key in classes.Keys)
        CalculateBestROC(key, classes);

      return null;
    }

    /// <summary>
    /// Computes the ROC curve and AUC for the class <paramref name="positiveClassKey"/> against all other
    /// classes and appends them to the output lists.
    /// The smallest key uses a single run with <c>double.MinValue</c> as lower threshold, the largest key
    /// uses <see cref="CalculateROCValuesAndAUCForLastClass"/>, and every key in between tries each distinct
    /// positive value as lower threshold and keeps the curve with the largest AUC.
    /// </summary>
    /// <param name="positiveClassKey">Key in <paramref name="classes"/> of the class treated as positive.</param>
    /// <param name="classes">Estimated values grouped by original class value.</param>
    protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) {
      List<KeyValuePair<double, double>> rocCharacteristics;
      List<KeyValuePair<double, double>> bestROC;
      List<KeyValuePair<double, double>> actROC;

      List<double> positives = classes[positiveClassKey];
      List<double> negatives = new List<double>();
      foreach (KeyValuePair<double, List<double>> entry in classes) {
        if (entry.Key != positiveClassKey)
          negatives.AddRange(entry.Value);
      }
      // NOTE(review): with a single class, negatives.Count is 0 and the rate divisions below
      // produce non-finite values (division of doubles by zero).

      double bestAUC = double.MinValue;
      double actAUC = 0;

      bool isFirstClass = classes.Keys.First() == positiveClassKey;
      bool isLastClass = classes.Keys.Last() == positiveClassKey;

      List<double> actNegatives;
      if (!isFirstClass && isLastClass) {
        //last class: thresholds are the negatives above the smallest positive plus that positive, ascending
        double minPositive = positives.Min();
        actNegatives = negatives.Where(value => value > minPositive).ToList();
        actNegatives.Add(minPositive);
        actNegatives.Sort();
        CalculateROCValuesAndAUCForLastClass(positives, actNegatives, negatives.Count, out bestROC, out bestAUC);
        myAucValues.Add(new DoubleData(bestAUC));
        myRocValues.Add(Convert(bestROC));
        return;
      }

      // first and middle classes: thresholds are the negatives below the largest positive plus that positive, descending
      double maxPositive = positives.Max();
      actNegatives = negatives.Where(value => value < maxPositive).ToList();
      actNegatives.Add(maxPositive);
      actNegatives.Sort();
      actNegatives.Reverse();

      rocCharacteristics = null;
      if (isFirstClass) {
        //first class
        CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, double.MinValue, ref rocCharacteristics, out actROC, out actAUC);
        myAucValues.Add(new DoubleData(actAUC));
        myRocValues.Add(Convert(actROC));
      } else {
        //middle classes
        bestROC = new List<KeyValuePair<double, double>>();
        foreach (double minThreshold in positives.Distinct()) {
          CalculateROCValuesAndAUC(positives, actNegatives, negatives.Count, minThreshold, ref rocCharacteristics, out actROC, out actAUC);
          if (actAUC > bestAUC) {
            bestAUC = actAUC;
            bestROC = actROC;
          }
        }
        myAucValues.Add(new DoubleData(bestAUC));
        myRocValues.Add(Convert(bestROC));
      }
    }

    /// <summary>
    /// Builds a ROC curve and its AUC for a fixed lower threshold while the upper threshold walks over the
    /// distinct values of <paramref name="negatives"/> in list order.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Threshold candidates, also counted as false positives; must not be empty.</param>
    /// <param name="negativesCount">Denominator used for the false positive rate.</param>
    /// <param name="minThreshold">Inclusive lower bound a value must reach to be counted.</param>
    /// <param name="rocCharacteristics">
    /// When <c>null</c> on entry, a new list is created and filled with the (TP decrease, FP decrease) pair of
    /// every threshold step; otherwise the given pairs are replayed instead of recounting the values.
    /// </param>
    /// <param name="roc">Resulting curve; key is the false positive rate, value the true positive rate.</param>
    /// <param name="auc">Resulting area, accumulated from trapezoids between consecutive points.</param>
    /// <exception cref="InvalidOperationException"><paramref name="negatives"/> is empty.</exception>
    protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minThreshold,
      ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // Hoisted out of the counting predicate; the value does not change while counting.
      double maxNegative = negatives.Max();
      double actTP = positives.Count(value => minThreshold <= value && value <= maxNegative);
      double actFP = negatives.Count(value => minThreshold <= value);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      if (rocCharacteristics == null) {
        rocCharacteristics = new List<KeyValuePair<double, double>>();
        foreach (double maxThreshold in negatives.Distinct()) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = positives.Count(value => minThreshold <= value && value < maxThreshold);
          actFP = negatives.Count(value => minThreshold <= value && value < maxThreshold);
          rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP));
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

          //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
          if ((actTP == 0) || (actFP == 0))
            break;
        }
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
      } else { //characteristics of ROCs calculated
        // NOTE(review): the replayed decreases were counted with the minThreshold of the call that
        // created the list; confirm that reusing them for a different minThreshold is intended.
        foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) {
          auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
          oldTP = actTP;
          oldFP = actFP;
          actTP = oldTP - rocCharac.Key;
          actFP = oldFP - rocCharac.Value;
          roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));
          if ((actTP == 0) || (actFP == 0))
            break;
        }
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
      }
    }

    /// <summary>
    /// Builds a ROC curve and its AUC for the class with the largest key: values strictly above the current
    /// threshold are counted, and the threshold walks over the distinct values of
    /// <paramref name="negatives"/> in list order.
    /// </summary>
    /// <param name="positives">Estimated values of the positive class.</param>
    /// <param name="negatives">Threshold candidates, also counted as false positives; must not be empty.</param>
    /// <param name="negativesCount">Denominator used for the false positive rate.</param>
    /// <param name="roc">Resulting curve; key is the false positive rate, value the true positive rate.</param>
    /// <param name="auc">Resulting area, accumulated from trapezoids between consecutive points.</param>
    /// <exception cref="InvalidOperationException"><paramref name="negatives"/> is empty.</exception>
    protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount,
      out List<KeyValuePair<double, double>> roc, out double auc) {
      auc = 0;
      roc = new List<KeyValuePair<double, double>>();

      // Hoisted out of the counting predicates; the value does not change while counting.
      double minNegative = negatives.Min();
      double actTP = positives.Count(value => value >= minNegative);
      double actFP = negatives.Count(value => value >= minNegative);
      //add point (1,TPR) for AUC 'correct' calculation
      roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count));
      double oldTP = actTP;
      double oldFP = negativesCount;
      roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

      foreach (double minThreshold in negatives.Distinct()) {
        auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
        oldTP = actTP;
        oldFP = actFP;
        actTP = positives.Count(value => minThreshold < value);
        actFP = negatives.Count(value => minThreshold < value);
        roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count));

        //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime
        if (actTP == 0 || actFP == 0)
          break;
      }
      auc += TrapezoidArea(oldTP, actTP, oldFP, actFP, positives.Count, negativesCount);
    }

    /// <summary>
    /// Area of the trapezoid between two consecutive ROC points, given as absolute TP/FP counts
    /// that are normalized by the respective totals.
    /// </summary>
    private static double TrapezoidArea(double oldTP, double actTP, double oldFP, double actFP, int positivesCount, int negativesCount) {
      return ((oldTP + actTP) / positivesCount) * ((oldFP - actFP) / negativesCount) / 2;
    }

    /// <summary>
    /// Converts a list of points into an <c>ItemList</c> of two-element rows (key first, value second).
    /// </summary>
    /// <param name="data">Points to convert.</param>
    /// <returns>A new list holding one row per point, in the same order.</returns>
    private ItemList Convert(List<KeyValuePair<double, double>> data) {
      ItemList list = new ItemList();
      foreach (KeyValuePair<double, double> dataPoint in data) {
        ItemList row = new ItemList();
        row.Add(new DoubleData(dataPoint.Key));
        row.Add(new DoubleData(dataPoint.Value));
        list.Add(row);
      }
      return list;
    }
  }
}