Changeset 678
- Timestamp: 10/17/08 10:56:51 (16 years ago)
- File: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.GP.StructureIdentification.Classification/ROCAnalyzer.cs
r672 r678 31 31 namespace HeuristicLab.GP.StructureIdentification.Classification { 32 32 public class ROCAnalyzer : OperatorBase { 33 private ItemList myRocValues; 34 private ItemList<DoubleData> myAucValues; 35 33 36 34 37 public override string Description { … … 39 42 : base() { 40 43 AddVariableInfo(new VariableInfo("Values", "Item list holding the estimated and orignial values for the ROCAnalyzer", typeof(ItemList), VariableKind.In)); 41 AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList<ItemList<DoubleArrayData>>), VariableKind.New | VariableKind.Out)); 44 AddVariableInfo(new VariableInfo("ROCValues", "The values of the ROCAnalyzer, namely TPR & FPR", typeof(ItemList), VariableKind.New | VariableKind.Out)); 45 AddVariableInfo(new VariableInfo("AUCValues", "The AUC Values for each ROC", typeof(ItemList<DoubleData>), VariableKind.New | VariableKind.Out)); 42 46 } 43 47 44 48 public override IOperation Apply(IScope scope) { 49 #region initialize HL-variables 45 50 ItemList values = GetVariableValue<ItemList>("Values", scope, true); 46 ItemList<ItemList<DoubleArrayData>> rocValues = GetVariableValue<ItemList<ItemList<DoubleArrayData>>>("ROCValues", scope, false, false);47 if ( rocValues == null) {48 rocValues = new ItemList<ItemList<DoubleArrayData>>();51 myRocValues = GetVariableValue<ItemList>("ROCValues", scope, false, false); 52 if (myRocValues == null) { 53 myRocValues = new ItemList(); 49 54 IVariableInfo info = GetVariableInfo("ROCValues"); 50 55 if (info.Local) 51 AddVariable(new HeuristicLab.Core.Variable(info.ActualName, rocValues));56 AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myRocValues)); 52 57 else 53 scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), rocValues));58 scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myRocValues)); 54 59 } else { 55 rocValues.Clear(); 56 } 57 58 rocValues.Add(new 
ItemList<DoubleArrayData>()); 59 //ROC Curve starts at 0,0 60 DoubleArrayData point = new DoubleArrayData(); 61 point.Data = new double[2] { 0, 0 }; 62 rocValues[0].Add(point); 60 myRocValues.Clear(); 61 } 62 63 myAucValues = GetVariableValue<ItemList<DoubleData>>("AUCValues", scope, false, false); 64 if (myAucValues == null) { 65 myAucValues = new ItemList<DoubleData>(); 66 IVariableInfo info = GetVariableInfo("AUCValues"); 67 if (info.Local) 68 AddVariable(new HeuristicLab.Core.Variable(info.ActualName, myAucValues)); 69 else 70 scope.AddVariable(new HeuristicLab.Core.Variable(scope.TranslateName(info.FormalName), myAucValues)); 71 } else { 72 myAucValues.Clear(); 73 } 74 #endregion 63 75 64 76 //calculate new ROC Values 65 77 double estimated = 0.0; 66 78 double original = 0.0; 67 double positiveClassKey;68 double negativeClassKey;69 double truePositiveRate = 0.0;70 double falsePositiveRate = 0.0;71 79 72 80 //initialize classes dictionary 73 Dictionary<double, List<double>> classes = newDictionary<double, List<double>>();81 SortedDictionary<double, List<double>> classes = new SortedDictionary<double, List<double>>(); 74 82 foreach (ItemList value in values) { 75 83 estimated = ((DoubleData)value[0]).Data; … … 79 87 classes[original].Add(estimated); 80 88 } 89 foreach (double key in classes.Keys) 90 classes[key].Sort(); 81 91 82 92 //check for 2 classes classification problem 83 if (classes.Keys.Count != 2) 84 throw new Exception("ROCAnalyser only handles 2 class classification problems"); 85 86 //sort estimated values in classes dictionary 87 foreach (List<double> estimatedValues in classes.Values) 88 estimatedValues.Sort(); 89 90 //calculate truePosivite- & falsePositiveRate 91 positiveClassKey = classes.Keys.Min<double>(); 92 negativeClassKey = classes.Keys.Max<double>(); 93 foreach (double treshold in classes[negativeClassKey].Distinct<double>()) { 94 truePositiveRate = ((double)classes[positiveClassKey].Count<double>(value => value < treshold)) / 
classes[positiveClassKey].Count; 95 falsePositiveRate = ((double)classes[negativeClassKey].Count<double>(value => value < treshold)) / classes[negativeClassKey].Count; 96 point = new DoubleArrayData(new double[2] { falsePositiveRate, truePositiveRate }); 97 rocValues[0].Add(point); 98 99 //stop calculation if truePositiveRate = 1; save runtime 100 if (truePositiveRate == 1) 93 //if (classes.Keys.Count != 2) 94 // throw new Exception("ROCAnalyser only handles 2 class classification problems"); 95 96 //calculate ROC Curve 97 foreach (double key in classes.Keys) { 98 CalculateBestROC(key, classes); 99 } 100 101 return null; 102 } 103 104 protected void CalculateBestROC(double positiveClassKey, SortedDictionary<double, List<double>> classes) { 105 106 int rocIndex = myRocValues.Count - 1; 107 List<KeyValuePair<double, double>> rocCharacteristics; 108 List<KeyValuePair<double, double>> bestROC; 109 List<KeyValuePair<double, double>> actROC; 110 111 List<double> negatives = new List<double>(); 112 foreach (double key in classes.Keys) { 113 if (key != positiveClassKey) 114 negatives.AddRange(classes[key]); 115 } 116 List<double> actNegatives = negatives.Where<double>(value => value < classes[positiveClassKey].Max<double>()).ToList<double>(); 117 actNegatives.Add(classes[positiveClassKey].Max<double>()); 118 actNegatives.Sort(); 119 actNegatives = actNegatives.Reverse<double>().ToList<double>(); 120 121 double bestAUC = double.MinValue; 122 double actAUC = 0; 123 //first class 124 if (classes.Keys.ElementAt<double>(0) == positiveClassKey) { 125 rocCharacteristics = null; 126 CalculateROCValuesAndAUC(classes[positiveClassKey], actNegatives, negatives.Count, double.MinValue, ref rocCharacteristics, out actROC, out actAUC); 127 myAucValues.Add(new DoubleData(actAUC)); 128 myRocValues.Add(Convert(actROC)); 129 } 130 //middle classes 131 else if (classes.Keys.ElementAt<double>(classes.Keys.Count - 1) != positiveClassKey) { 132 rocCharacteristics = null; 133 bestROC = new 
List<KeyValuePair<double, double>>(); 134 foreach (double minTreshold in classes[positiveClassKey].Distinct<double>()) { 135 CalculateROCValuesAndAUC(classes[positiveClassKey], actNegatives, negatives.Count, minTreshold, ref rocCharacteristics, out actROC, out actAUC); 136 if (actAUC > bestAUC) { 137 bestAUC = actAUC; 138 bestROC = actROC; 139 } 140 } 141 myAucValues.Add(new DoubleData(bestAUC)); 142 myRocValues.Add(Convert(bestROC)); 143 144 } else { //last class 145 actNegatives = negatives.Where<double>(value => value > classes[positiveClassKey].Min<double>()).ToList<double>(); 146 actNegatives.Add(classes[positiveClassKey].Min<double>()); 147 actNegatives.Sort(); 148 CalculateROCValuesAndAUCForLastClass(classes[positiveClassKey], actNegatives, negatives.Count, out bestROC, out bestAUC); 149 myAucValues.Add(new DoubleData(bestAUC)); 150 myRocValues.Add(Convert(bestROC)); 151 152 } 153 154 } 155 156 protected void CalculateROCValuesAndAUC(List<double> positives, List<double> negatives, int negativesCount, double minTreshold, 157 ref List<KeyValuePair<double, double>> rocCharacteristics, out List<KeyValuePair<double, double>> roc, out double auc) { 158 double actTP = -1; 159 double actFP = -1; 160 double oldTP = -1; 161 double oldFP = -1; 162 auc = 0; 163 roc = new List<KeyValuePair<double, double>>(); 164 165 actTP = positives.Count<double>(value => minTreshold <= value && value <= negatives.Max<double>()); 166 actFP = negatives.Count<double>(value => minTreshold <= value && value <= negatives.Max<double>()); 167 //add point (1,TPR) for AUC 'correct' calculation 168 roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count)); 169 oldTP = actTP; 170 oldFP = negativesCount; 171 roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count)); 172 173 if (rocCharacteristics == null) { 174 rocCharacteristics = new List<KeyValuePair<double, double>>(); 175 foreach (double maxTreshold in negatives.Distinct<double>()) { 176 auc += 
((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 177 oldTP = actTP; 178 oldFP = actFP; 179 actTP = positives.Count<double>(value => minTreshold <= value && value < maxTreshold); 180 actFP = negatives.Count<double>(value => minTreshold <= value && value < maxTreshold); 181 rocCharacteristics.Add(new KeyValuePair<double, double>(oldTP - actTP, oldFP - actFP)); 182 roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count)); 183 184 //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime 185 if ((actTP / positives.Count == 0) || (actFP / negatives.Count == 0)) 186 break; 187 } 188 auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 189 } else { //characteristics of ROCs calculated 190 foreach (KeyValuePair<double, double> rocCharac in rocCharacteristics) { 191 auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 192 oldTP = actTP; 193 oldFP = actFP; 194 actTP = oldTP - rocCharac.Key; 195 actFP = oldFP - rocCharac.Value; 196 roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count)); 197 if (actTP / positives.Count == 0) 198 break; 199 } 200 auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 201 } 202 } 203 204 protected void CalculateROCValuesAndAUCForLastClass(List<double> positives, List<double> negatives, int negativesCount, 205 out List<KeyValuePair<double, double>> roc, out double auc) { 206 double actTP = -1; 207 double actFP = -1; 208 double oldTP = -1; 209 double oldFP = -1; 210 auc = 0; 211 roc = new List<KeyValuePair<double, double>>(); 212 213 actTP = positives.Count<double>(value => value >= negatives.Min<double>()); 214 actFP = negatives.Count<double>(value => value >= negatives.Min<double>()); 215 //add point (1,TPR) for AUC 'correct' calculation 216 roc.Add(new KeyValuePair<double, double>(1, actTP / positives.Count)); 217 oldTP = actTP; 
218 oldFP = negativesCount; 219 roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count)); 220 221 foreach (double minTreshold in negatives.Distinct<double>()) { 222 auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 223 oldTP = actTP; 224 oldFP = actFP; 225 actTP = positives.Count<double>(value => minTreshold <= value); 226 actFP = negatives.Count<double>(value => minTreshold <= value); 227 roc.Add(new KeyValuePair<double, double>(actFP / negativesCount, actTP / positives.Count)); 228 229 //stop calculation if truePositiveRate == 0 => straight line with y=0 & save runtime 230 if ((actTP / positives.Count == 0) || (actFP / negatives.Count == 0)) 101 231 break; 102 232 } 103 104 //add case when treshold == max negative class value => falsePositiveRate ==1 105 if (truePositiveRate != 1.0) { 106 truePositiveRate = ((double)classes[positiveClassKey].Count<double>(value => value <= classes[negativeClassKey][classes[negativeClassKey].Count - 1])) / classes[positiveClassKey].Count; 107 falsePositiveRate = 1; 108 point = new DoubleArrayData(new double[2] { falsePositiveRate, truePositiveRate }); 109 rocValues[0].Add(point); 110 } else { 111 //ROC ends at 1,1 112 point = new DoubleArrayData(new double[2] { 1, 1 }); 113 rocValues[0].Add(point); 114 } 115 116 return null; 117 } 233 auc += ((oldTP + actTP) / positives.Count) * ((oldFP - actFP) / negativesCount) / 2; 234 235 } 236 237 private ItemList Convert(List<KeyValuePair<double, double>> data) { 238 ItemList list = new ItemList(); 239 ItemList row; 240 foreach (KeyValuePair<double, double> dataPoint in data) { 241 row = new ItemList(); 242 row.Add(new DoubleData(dataPoint.Key)); 243 row.Add(new DoubleData(dataPoint.Value)); 244 list.Add(row); 245 } 246 return list; 247 } 248 118 249 } 250 119 251 }
Note: See TracChangeset for help on using the changeset viewer.