Changeset 4068 for trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/PerformanceEvaluator.cs
- Timestamp:
- 07/22/10 00:44:01 (14 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/PerformanceEvaluator.cs
r2645 r4068 20 20 using System; 21 21 using System.Collections.Generic; 22 using System.Globalization; 22 23 using System.IO; 23 using System.Globalization; 24 25 namespace SVM 26 { 27 /// <summary> 28 /// Class encoding a member of a ranked set of labels. 29 /// </summary> 30 public class RankPair : IComparable<RankPair> 31 { 32 private double _score, _label; 33 34 /// <summary> 35 /// Constructor. 36 /// </summary> 37 /// <param name="score">Score for this pair</param> 38 /// <param name="label">Label associated with the given score</param> 39 public RankPair(double score, double label) 40 { 41 _score = score; 42 _label = label; 24 25 namespace SVM { 26 /// <summary> 27 /// Class encoding a member of a ranked set of labels. 28 /// </summary> 29 public class RankPair : IComparable<RankPair> { 30 private double _score, _label; 31 32 /// <summary> 33 /// Constructor. 34 /// </summary> 35 /// <param name="score">Score for this pair</param> 36 /// <param name="label">Label associated with the given score</param> 37 public RankPair(double score, double label) { 38 _score = score; 39 _label = label; 40 } 41 42 /// <summary> 43 /// The score for this pair. 44 /// </summary> 45 public double Score { 46 get { 47 return _score; 48 } 49 } 50 51 /// <summary> 52 /// The Label for this pair. 53 /// </summary> 54 public double Label { 55 get { 56 return _label; 57 } 58 } 59 60 #region IComparable<RankPair> Members 61 62 /// <summary> 63 /// Compares this pair to another. It will end up in a sorted list in decending score order. 64 /// </summary> 65 /// <param name="other">The pair to compare to</param> 66 /// <returns>Whether this should come before or after the argument</returns> 67 public int CompareTo(RankPair other) { 68 return other.Score.CompareTo(Score); 69 } 70 71 #endregion 72 73 /// <summary> 74 /// Returns a string representation of this pair. 75 /// </summary> 76 /// <returns>A string in the for Score:Label</returns> 77 public override string ToString() { 78 return string.Format("{0}:{1}", Score, Label); 79 } 80 } 81 82 /// <summary> 83 /// Class encoding the point on a 2D curve. 84 /// </summary> 85 public class CurvePoint { 86 private float _x, _y; 87 88 /// <summary> 89 /// Constructor. 90 /// </summary> 91 /// <param name="x">X coordinate</param> 92 /// <param name="y">Y coordinate</param> 93 public CurvePoint(float x, float y) { 94 _x = x; 95 _y = y; 96 } 97 98 /// <summary> 99 /// X coordinate 100 /// </summary> 101 public float X { 102 get { 103 return _x; 104 } 105 } 106 107 /// <summary> 108 /// Y coordinate 109 /// </summary> 110 public float Y { 111 get { 112 return _y; 113 } 114 } 115 116 /// <summary> 117 /// Creates a string representation of this point. 118 /// </summary> 119 /// <returns>string in the form (x, y)</returns> 120 public override string ToString() { 121 return string.Format("({0}, {1})", _x, _y); 122 } 123 } 124 125 /// <summary> 126 /// Class which evaluates an SVM model using several standard techniques. 127 /// </summary> 128 public class PerformanceEvaluator { 129 private class ChangePoint { 130 public ChangePoint(int tp, int fp, int tn, int fn) { 131 TP = tp; 132 FP = fp; 133 TN = tn; 134 FN = fn; 135 } 136 137 public int TP, FP, TN, FN; 138 139 public override string ToString() { 140 return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN); 141 } 142 } 143 144 private List<CurvePoint> _prCurve; 145 private double _ap; 146 147 private List<CurvePoint> _rocCurve; 148 private double _auc; 149 150 private List<RankPair> _data; 151 private List<ChangePoint> _changes; 152 153 /// <summary> 154 /// Constructor. 155 /// </summary> 156 /// <param name="set">A pre-computed ranked pair set</param> 157 public PerformanceEvaluator(List<RankPair> set) { 158 _data = set; 159 computeStatistics(); 160 } 161 162 /// <summary> 163 /// Constructor. 164 /// </summary> 165 /// <param name="model">Model to evaluate</param> 166 /// <param name="problem">Problem to evaluate</param> 167 /// <param name="category">Label to be evaluate for</param> 168 public PerformanceEvaluator(Model model, Problem problem, double category) : this(model, problem, category, "tmp.results") { } 169 /// <summary> 170 /// Constructor. 171 /// </summary> 172 /// <param name="model">Model to evaluate</param> 173 /// <param name="problem">Problem to evaluate</param> 174 /// <param name="resultsFile">Results file for output</param> 175 /// <param name="category">Category to evaluate for</param> 176 public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile) { 177 Prediction.Predict(problem, resultsFile, model, true); 178 parseResultsFile(resultsFile, problem.Y, category); 179 180 computeStatistics(); 181 } 182 183 /// <summary> 184 /// Constructor. 185 /// </summary> 186 /// <param name="resultsFile">Results file</param> 187 /// <param name="correctLabels">The correct labels of each data item</param> 188 /// <param name="category">The category to evaluate for</param> 189 public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category) { 190 parseResultsFile(resultsFile, correctLabels, category); 191 computeStatistics(); 192 } 193 194 private void parseResultsFile(string resultsFile, double[] labels, double category) { 195 StreamReader input = new StreamReader(resultsFile); 196 string[] parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries); 197 int confidenceIndex = -1; 198 for (int i = 1; i < parts.Length; i++) 199 if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category) { 200 confidenceIndex = i; 201 break; 43 202 } 44 45 /// <summary> 46 /// The score for this pair. 47 /// </summary> 48 public double Score 49 { 50 get 51 { 52 return _score; 53 } 203 _data = new List<RankPair>(); 204 for (int i = 0; i < labels.Length; i++) { 205 parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries); 206 double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture); 207 _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0)); 208 } 209 input.Close(); 210 } 211 212 private void computeStatistics() { 213 _data.Sort(); 214 215 findChanges(); 216 computePR(); 217 computeRoC(); 218 } 219 220 private void findChanges() { 221 int tp, fp, tn, fn; 222 tp = fp = tn = fn = 0; 223 for (int i = 0; i < _data.Count; i++) { 224 if (_data[i].Label == 1) 225 fn++; 226 else tn++; 227 } 228 _changes = new List<ChangePoint>(); 229 for (int i = 0; i < _data.Count; i++) { 230 if (_data[i].Label == 1) { 231 tp++; 232 fn--; 233 } else { 234 fp++; 235 tn--; 54 236 } 55 56 /// <summary> 57 /// The Label for this pair. 58 /// </summary> 59 public double Label 60 { 61 get 62 { 63 return _label; 64 } 237 _changes.Add(new ChangePoint(tp, fp, tn, fn)); 238 } 239 } 240 241 private float computePrecision(ChangePoint p) { 242 return (float)p.TP / (p.TP + p.FP); 243 } 244 245 private float computeRecall(ChangePoint p) { 246 return (float)p.TP / (p.TP + p.FN); 247 } 248 249 private void computePR() { 250 _prCurve = new List<CurvePoint>(); 251 _prCurve.Add(new CurvePoint(0, 1)); 252 float precision = computePrecision(_changes[0]); 253 float recall = computeRecall(_changes[0]); 254 float precisionSum = 0; 255 if (_changes[0].TP > 0) { 256 precisionSum += precision; 257 _prCurve.Add(new CurvePoint(recall, precision)); 258 } 259 for (int i = 1; i < _changes.Count; i++) { 260 precision = computePrecision(_changes[i]); 261 recall = computeRecall(_changes[i]); 262 if (_changes[i].TP > _changes[i - 1].TP) { 263 precisionSum += precision; 264 _prCurve.Add(new CurvePoint(recall, precision)); 65 265 } 66 67 #region IComparable<RankPair> Members 68 69 /// <summary> 70 /// Compares this pair to another. It will end up in a sorted list in decending score order. 71 /// </summary> 72 /// <param name="other">The pair to compare to</param> 73 /// <returns>Whether this should come before or after the argument</returns> 74 public int CompareTo(RankPair other) 75 { 76 return other.Score.CompareTo(Score); 266 } 267 _prCurve.Add(new CurvePoint(1, (float)(_changes[0].TP + _changes[0].FN) / (_changes[0].FP + _changes[0].TN))); 268 _ap = precisionSum / (_changes[0].FN + _changes[0].TP); 269 } 270 271 /// <summary> 272 /// Writes the Precision-Recall curve to a tab-delimited file. 273 /// </summary> 274 /// <param name="filename">Filename for output</param> 275 public void WritePRCurve(string filename) { 276 StreamWriter output = new StreamWriter(filename); 277 output.WriteLine(_ap); 278 for (int i = 0; i < _prCurve.Count; i++) 279 output.WriteLine("{0}\t{1}", _prCurve[i].X, _prCurve[i].Y); 280 output.Close(); 281 } 282 283 /// <summary> 284 /// Writes the Receiver Operating Characteristic curve to a tab-delimited file. 285 /// </summary> 286 /// <param name="filename">Filename for output</param> 287 public void WriteROCCurve(string filename) { 288 StreamWriter output = new StreamWriter(filename); 289 output.WriteLine(_auc); 290 for (int i = 0; i < _rocCurve.Count; i++) 291 output.WriteLine("{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y); 292 output.Close(); 293 } 294 295 /// <summary> 296 /// Receiver Operating Characteristic curve 297 /// </summary> 298 public List<CurvePoint> ROCCurve { 299 get { 300 return _rocCurve; 301 } 302 } 303 304 /// <summary> 305 /// Returns the area under the ROC Curve 306 /// </summary> 307 public double AuC { 308 get { 309 return _auc; 310 } 311 } 312 313 /// <summary> 314 /// Precision-Recall curve 315 /// </summary> 316 public List<CurvePoint> PRCurve { 317 get { 318 return _prCurve; 319 } 320 } 321 322 /// <summary> 323 /// The average precision 324 /// </summary> 325 public double AP { 326 get { 327 return _ap; 328 } 329 } 330 331 private float computeTPR(ChangePoint cp) { 332 return computeRecall(cp); 333 } 334 335 private float computeFPR(ChangePoint cp) { 336 return (float)cp.FP / (cp.FP + cp.TN); 337 } 338 339 private void computeRoC() { 340 _rocCurve = new List<CurvePoint>(); 341 _rocCurve.Add(new CurvePoint(0, 0)); 342 float tpr = computeTPR(_changes[0]); 343 float fpr = computeFPR(_changes[0]); 344 _rocCurve.Add(new CurvePoint(fpr, tpr)); 345 _auc = 0; 346 for (int i = 1; i < _changes.Count; i++) { 347 float newTPR = computeTPR(_changes[i]); 348 float newFPR = computeFPR(_changes[i]); 349 if (_changes[i].TP > _changes[i - 1].TP) { 350 _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr); 351 tpr = newTPR; 352 fpr = newFPR; 353 _rocCurve.Add(new CurvePoint(fpr, tpr)); 77 354 } 78 79 #endregion 80 81 /// <summary> 82 /// Returns a string representation of this pair. 83 /// </summary> 84 /// <returns>A string in the for Score:Label</returns> 85 public override string ToString() 86 { 87 return string.Format("{0}:{1}", Score, Label); 88 } 89 } 90 91 /// <summary> 92 /// Class encoding the point on a 2D curve. 93 /// </summary> 94 public class CurvePoint 95 { 96 private float _x, _y; 97 98 /// <summary> 99 /// Constructor. 100 /// </summary> 101 /// <param name="x">X coordinate</param> 102 /// <param name="y">Y coordinate</param> 103 public CurvePoint(float x, float y) 104 { 105 _x = x; 106 _y = y; 107 } 108 109 /// <summary> 110 /// X coordinate 111 /// </summary> 112 public float X 113 { 114 get 115 { 116 return _x; 117 } 118 } 119 120 /// <summary> 121 /// Y coordinate 122 /// </summary> 123 public float Y 124 { 125 get 126 { 127 return _y; 128 } 129 } 130 131 /// <summary> 132 /// Creates a string representation of this point. 133 /// </summary> 134 /// <returns>string in the form (x, y)</returns> 135 public override string ToString() 136 { 137 return string.Format("({0}, {1})", _x, _y); 138 } 139 } 140 141 /// <summary> 142 /// Class which evaluates an SVM model using several standard techniques. 143 /// </summary> 144 public class PerformanceEvaluator 145 { 146 private class ChangePoint 147 { 148 public ChangePoint(int tp, int fp, int tn, int fn) 149 { 150 TP = tp; 151 FP = fp; 152 TN = tn; 153 FN = fn; 154 } 155 156 public int TP, FP, TN, FN; 157 158 public override string ToString() 159 { 160 return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN); 161 } 162 } 163 164 private List<CurvePoint> _prCurve; 165 private double _ap; 166 167 private List<CurvePoint> _rocCurve; 168 private double _auc; 169 170 private List<RankPair> _data; 171 private List<ChangePoint> _changes; 172 173 /// <summary> 174 /// Constructor. 175 /// </summary> 176 /// <param name="set">A pre-computed ranked pair set</param> 177 public PerformanceEvaluator(List<RankPair> set) 178 { 179 _data = set; 180 computeStatistics(); 181 } 182 183 /// <summary> 184 /// Constructor. 185 /// </summary> 186 /// <param name="model">Model to evaluate</param> 187 /// <param name="problem">Problem to evaluate</param> 188 /// <param name="category">Label to be evaluate for</param> 189 public PerformanceEvaluator(Model model, Problem problem, double category) : this(model, problem, category, "tmp.results") { } 190 /// <summary> 191 /// Constructor. 192 /// </summary> 193 /// <param name="model">Model to evaluate</param> 194 /// <param name="problem">Problem to evaluate</param> 195 /// <param name="resultsFile">Results file for output</param> 196 /// <param name="category">Category to evaluate for</param> 197 public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile) 198 { 199 Prediction.Predict(problem, resultsFile, model, true); 200 parseResultsFile(resultsFile, problem.Y, category); 201 202 computeStatistics(); 203 } 204 205 /// <summary> 206 /// Constructor. 207 /// </summary> 208 /// <param name="resultsFile">Results file</param> 209 /// <param name="correctLabels">The correct labels of each data item</param> 210 /// <param name="category">The category to evaluate for</param> 211 public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category) 212 { 213 parseResultsFile(resultsFile, correctLabels, category); 214 computeStatistics(); 215 } 216 217 private void parseResultsFile(string resultsFile, double[] labels, double category) 218 { 219 StreamReader input = new StreamReader(resultsFile); 220 string[] parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries); 221 int confidenceIndex = -1; 222 for (int i = 1; i < parts.Length; i++) 223 if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category) 224 { 225 confidenceIndex = i; 226 break; 227 } 228 _data = new List<RankPair>(); 229 for (int i = 0; i < labels.Length; i++) 230 { 231 parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries); 232 double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture); 233 _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0)); 234 } 235 input.Close(); 236 } 237 238 private void computeStatistics() 239 { 240 _data.Sort(); 241 242 findChanges(); 243 computePR(); 244 computeRoC(); 245 } 246 247 private void findChanges() 248 { 249 int tp, fp, tn, fn; 250 tp = fp = tn = fn = 0; 251 for (int i = 0; i < _data.Count; i++) 252 { 253 if (_data[i].Label == 1) 254 fn++; 255 else tn++; 256 } 257 _changes = new List<ChangePoint>(); 258 for (int i = 0; i < _data.Count; i++) 259 { 260 if (_data[i].Label == 1) 261 { 262 tp++; 263 fn--; 264 } 265 else 266 { 267 fp++; 268 tn--; 269 } 270 _changes.Add(new ChangePoint(tp, fp, tn, fn)); 271 } 272 } 273 274 private float computePrecision(ChangePoint p) 275 { 276 return (float)p.TP / (p.TP + p.FP); 277 } 278 279 private float computeRecall(ChangePoint p) 280 { 281 return (float)p.TP / (p.TP + p.FN); 282 } 283 284 private void computePR() 285 { 286 _prCurve = new List<CurvePoint>(); 287 _prCurve.Add(new CurvePoint(0, 1)); 288 float precision = computePrecision(_changes[0]); 289 float recall = computeRecall(_changes[0]); 290 float precisionSum = 0; 291 if (_changes[0].TP > 0) 292 { 293 precisionSum += precision; 294 _prCurve.Add(new CurvePoint(recall, precision)); 295 } 296 for (int i = 1; i < _changes.Count; i++) 297 { 298 precision = computePrecision(_changes[i]); 299 recall = computeRecall(_changes[i]); 300 if (_changes[i].TP > _changes[i - 1].TP) 301 { 302 precisionSum += precision; 303 _prCurve.Add(new CurvePoint(recall, precision)); 304 } 305 } 306 _prCurve.Add(new CurvePoint(1, (float)(_changes[0].TP + _changes[0].FN) / (_changes[0].FP + _changes[0].TN))); 307 _ap = precisionSum / (_changes[0].FN + _changes[0].TP); 308 } 309 310 /// <summary> 311 /// Writes the Precision-Recall curve to a tab-delimited file. 312 /// </summary> 313 /// <param name="filename">Filename for output</param> 314 public void WritePRCurve(string filename) 315 { 316 StreamWriter output = new StreamWriter(filename); 317 output.WriteLine(_ap); 318 for (int i = 0; i < _prCurve.Count; i++) 319 output.WriteLine("{0}\t{1}", _prCurve[i].X, _prCurve[i].Y); 320 output.Close(); 321 } 322 323 /// <summary> 324 /// Writes the Receiver Operating Characteristic curve to a tab-delimited file. 325 /// </summary> 326 /// <param name="filename">Filename for output</param> 327 public void WriteROCCurve(string filename) 328 { 329 StreamWriter output = new StreamWriter(filename); 330 output.WriteLine(_auc); 331 for (int i = 0; i < _rocCurve.Count; i++) 332 output.WriteLine("{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y); 333 output.Close(); 334 } 335 336 /// <summary> 337 /// Receiver Operating Characteristic curve 338 /// </summary> 339 public List<CurvePoint> ROCCurve 340 { 341 get 342 { 343 return _rocCurve; 344 } 345 } 346 347 /// <summary> 348 /// Returns the area under the ROC Curve 349 /// </summary> 350 public double AuC 351 { 352 get 353 { 354 return _auc; 355 } 356 } 357 358 /// <summary> 359 /// Precision-Recall curve 360 /// </summary> 361 public List<CurvePoint> PRCurve 362 { 363 get 364 { 365 return _prCurve; 366 } 367 } 368 369 /// <summary> 370 /// The average precision 371 /// </summary> 372 public double AP 373 { 374 get 375 { 376 return _ap; 377 } 378 } 379 380 private float computeTPR(ChangePoint cp) 381 { 382 return computeRecall(cp); 383 } 384 385 private float computeFPR(ChangePoint cp) 386 { 387 return (float)cp.FP / (cp.FP + cp.TN); 388 } 389 390 private void computeRoC() 391 { 392 _rocCurve = new List<CurvePoint>(); 393 _rocCurve.Add(new CurvePoint(0, 0)); 394 float tpr = computeTPR(_changes[0]); 395 float fpr = computeFPR(_changes[0]); 396 _rocCurve.Add(new CurvePoint(fpr, tpr)); 397 _auc = 0; 398 for (int i = 1; i < _changes.Count; i++) 399 { 400 float newTPR = computeTPR(_changes[i]); 401 float newFPR = computeFPR(_changes[i]); 402 if (_changes[i].TP > _changes[i - 1].TP) 403 { 404 _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr); 405 tpr = newTPR; 406 fpr = newFPR; 407 _rocCurve.Add(new CurvePoint(fpr, tpr)); 408 } 409 } 410 _rocCurve.Add(new CurvePoint(1, 1)); 411 _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr); 412 } 413 414 } 355 } 356 _rocCurve.Add(new CurvePoint(1, 1)); 357 _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr); 358 } 359 360 } 415 361 }
Note: See TracChangeset
for help on using the changeset viewer.