Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.ExtLibs/HeuristicLab.LibSVM/1.6.3/LibSVM-1.6.3/PerformanceEvaluator.cs @ 6349

Last change on this file since 6349 was 4068, checked in by swagner, 14 years ago

Sorted usings and removed unused usings in entire solution (#1094)

File size: 10.6 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22using System.Globalization;
23using System.IO;
24
25namespace SVM {
26  /// <summary>
27  /// Class encoding a member of a ranked set of labels.
28  /// </summary>
29  public class RankPair : IComparable<RankPair> {
30    private double _score, _label;
31
32    /// <summary>
33    /// Constructor.
34    /// </summary>
35    /// <param name="score">Score for this pair</param>
36    /// <param name="label">Label associated with the given score</param>
37    public RankPair(double score, double label) {
38      _score = score;
39      _label = label;
40    }
41
42    /// <summary>
43    /// The score for this pair.
44    /// </summary>
45    public double Score {
46      get {
47        return _score;
48      }
49    }
50
51    /// <summary>
52    /// The Label for this pair.
53    /// </summary>
54    public double Label {
55      get {
56        return _label;
57      }
58    }
59
60    #region IComparable<RankPair> Members
61
62    /// <summary>
63    /// Compares this pair to another.  It will end up in a sorted list in decending score order.
64    /// </summary>
65    /// <param name="other">The pair to compare to</param>
66    /// <returns>Whether this should come before or after the argument</returns>
67    public int CompareTo(RankPair other) {
68      return other.Score.CompareTo(Score);
69    }
70
71    #endregion
72
73    /// <summary>
74    /// Returns a string representation of this pair.
75    /// </summary>
76    /// <returns>A string in the for Score:Label</returns>
77    public override string ToString() {
78      return string.Format("{0}:{1}", Score, Label);
79    }
80  }
81
82  /// <summary>
83  /// Class encoding the point on a 2D curve.
84  /// </summary>
85  public class CurvePoint {
86    private float _x, _y;
87
88    /// <summary>
89    /// Constructor.
90    /// </summary>
91    /// <param name="x">X coordinate</param>
92    /// <param name="y">Y coordinate</param>
93    public CurvePoint(float x, float y) {
94      _x = x;
95      _y = y;
96    }
97
98    /// <summary>
99    /// X coordinate
100    /// </summary>
101    public float X {
102      get {
103        return _x;
104      }
105    }
106
107    /// <summary>
108    /// Y coordinate
109    /// </summary>
110    public float Y {
111      get {
112        return _y;
113      }
114    }
115
116    /// <summary>
117    /// Creates a string representation of this point.
118    /// </summary>
119    /// <returns>string in the form (x, y)</returns>
120    public override string ToString() {
121      return string.Format("({0}, {1})", _x, _y);
122    }
123  }
124
125  /// <summary>
126  /// Class which evaluates an SVM model using several standard techniques.
127  /// </summary>
128  public class PerformanceEvaluator {
129    private class ChangePoint {
130      public ChangePoint(int tp, int fp, int tn, int fn) {
131        TP = tp;
132        FP = fp;
133        TN = tn;
134        FN = fn;
135      }
136
137      public int TP, FP, TN, FN;
138
139      public override string ToString() {
140        return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN);
141      }
142    }
143
144    private List<CurvePoint> _prCurve;
145    private double _ap;
146
147    private List<CurvePoint> _rocCurve;
148    private double _auc;
149
150    private List<RankPair> _data;
151    private List<ChangePoint> _changes;
152
153    /// <summary>
154    /// Constructor.
155    /// </summary>
156    /// <param name="set">A pre-computed ranked pair set</param>
157    public PerformanceEvaluator(List<RankPair> set) {
158      _data = set;
159      computeStatistics();
160    }
161
162    /// <summary>
163    /// Constructor.
164    /// </summary>
165    /// <param name="model">Model to evaluate</param>
166    /// <param name="problem">Problem to evaluate</param>
167    /// <param name="category">Label to be evaluate for</param>
168    public PerformanceEvaluator(Model model, Problem problem, double category) : this(model, problem, category, "tmp.results") { }
169    /// <summary>
170    /// Constructor.
171    /// </summary>
172    /// <param name="model">Model to evaluate</param>
173    /// <param name="problem">Problem to evaluate</param>
174    /// <param name="resultsFile">Results file for output</param>
175    /// <param name="category">Category to evaluate for</param>
176    public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile) {
177      Prediction.Predict(problem, resultsFile, model, true);
178      parseResultsFile(resultsFile, problem.Y, category);
179
180      computeStatistics();
181    }
182
183    /// <summary>
184    /// Constructor.
185    /// </summary>
186    /// <param name="resultsFile">Results file</param>
187    /// <param name="correctLabels">The correct labels of each data item</param>
188    /// <param name="category">The category to evaluate for</param>
189    public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category) {
190      parseResultsFile(resultsFile, correctLabels, category);
191      computeStatistics();
192    }
193
194    private void parseResultsFile(string resultsFile, double[] labels, double category) {
195      StreamReader input = new StreamReader(resultsFile);
196      string[] parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
197      int confidenceIndex = -1;
198      for (int i = 1; i < parts.Length; i++)
199        if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category) {
200          confidenceIndex = i;
201          break;
202        }
203      _data = new List<RankPair>();
204      for (int i = 0; i < labels.Length; i++) {
205        parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
206        double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture);
207        _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0));
208      }
209      input.Close();
210    }
211
212    private void computeStatistics() {
213      _data.Sort();
214
215      findChanges();
216      computePR();
217      computeRoC();
218    }
219
220    private void findChanges() {
221      int tp, fp, tn, fn;
222      tp = fp = tn = fn = 0;
223      for (int i = 0; i < _data.Count; i++) {
224        if (_data[i].Label == 1)
225          fn++;
226        else tn++;
227      }
228      _changes = new List<ChangePoint>();
229      for (int i = 0; i < _data.Count; i++) {
230        if (_data[i].Label == 1) {
231          tp++;
232          fn--;
233        } else {
234          fp++;
235          tn--;
236        }
237        _changes.Add(new ChangePoint(tp, fp, tn, fn));
238      }
239    }
240
241    private float computePrecision(ChangePoint p) {
242      return (float)p.TP / (p.TP + p.FP);
243    }
244
245    private float computeRecall(ChangePoint p) {
246      return (float)p.TP / (p.TP + p.FN);
247    }
248
249    private void computePR() {
250      _prCurve = new List<CurvePoint>();
251      _prCurve.Add(new CurvePoint(0, 1));
252      float precision = computePrecision(_changes[0]);
253      float recall = computeRecall(_changes[0]);
254      float precisionSum = 0;
255      if (_changes[0].TP > 0) {
256        precisionSum += precision;
257        _prCurve.Add(new CurvePoint(recall, precision));
258      }
259      for (int i = 1; i < _changes.Count; i++) {
260        precision = computePrecision(_changes[i]);
261        recall = computeRecall(_changes[i]);
262        if (_changes[i].TP > _changes[i - 1].TP) {
263          precisionSum += precision;
264          _prCurve.Add(new CurvePoint(recall, precision));
265        }
266      }
267      _prCurve.Add(new CurvePoint(1, (float)(_changes[0].TP + _changes[0].FN) / (_changes[0].FP + _changes[0].TN)));
268      _ap = precisionSum / (_changes[0].FN + _changes[0].TP);
269    }
270
271    /// <summary>
272    /// Writes the Precision-Recall curve to a tab-delimited file.
273    /// </summary>
274    /// <param name="filename">Filename for output</param>
275    public void WritePRCurve(string filename) {
276      StreamWriter output = new StreamWriter(filename);
277      output.WriteLine(_ap);
278      for (int i = 0; i < _prCurve.Count; i++)
279        output.WriteLine("{0}\t{1}", _prCurve[i].X, _prCurve[i].Y);
280      output.Close();
281    }
282
283    /// <summary>
284    /// Writes the Receiver Operating Characteristic curve to a tab-delimited file.
285    /// </summary>
286    /// <param name="filename">Filename for output</param>
287    public void WriteROCCurve(string filename) {
288      StreamWriter output = new StreamWriter(filename);
289      output.WriteLine(_auc);
290      for (int i = 0; i < _rocCurve.Count; i++)
291        output.WriteLine("{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y);
292      output.Close();
293    }
294
295    /// <summary>
296    /// Receiver Operating Characteristic curve
297    /// </summary>
298    public List<CurvePoint> ROCCurve {
299      get {
300        return _rocCurve;
301      }
302    }
303
304    /// <summary>
305    /// Returns the area under the ROC Curve
306    /// </summary>
307    public double AuC {
308      get {
309        return _auc;
310      }
311    }
312
313    /// <summary>
314    /// Precision-Recall curve
315    /// </summary>
316    public List<CurvePoint> PRCurve {
317      get {
318        return _prCurve;
319      }
320    }
321
322    /// <summary>
323    /// The average precision
324    /// </summary>
325    public double AP {
326      get {
327        return _ap;
328      }
329    }
330
331    private float computeTPR(ChangePoint cp) {
332      return computeRecall(cp);
333    }
334
335    private float computeFPR(ChangePoint cp) {
336      return (float)cp.FP / (cp.FP + cp.TN);
337    }
338
339    private void computeRoC() {
340      _rocCurve = new List<CurvePoint>();
341      _rocCurve.Add(new CurvePoint(0, 0));
342      float tpr = computeTPR(_changes[0]);
343      float fpr = computeFPR(_changes[0]);
344      _rocCurve.Add(new CurvePoint(fpr, tpr));
345      _auc = 0;
346      for (int i = 1; i < _changes.Count; i++) {
347        float newTPR = computeTPR(_changes[i]);
348        float newFPR = computeFPR(_changes[i]);
349        if (_changes[i].TP > _changes[i - 1].TP) {
350          _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr);
351          tpr = newTPR;
352          fpr = newFPR;
353          _rocCurve.Add(new CurvePoint(fpr, tpr));
354        }
355      }
356      _rocCurve.Add(new CurvePoint(1, 1));
357      _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr);
358    }
359
360  }
361}
Note: See TracBrowser for help on using the repository browser.