Free cookie consent management tool by TermsFeed Policy Generator

source: branches/Persistence Test/LibSVM/PerformanceEvaluator.cs @ 3380

Last change on this file since 3380 was 2415, checked in by gkronber, 15 years ago

Updated LibSVM project to latest version. #774

File size: 12.6 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22using System.IO;
23using System.Globalization;
24
25namespace SVM
26{
27    /// <summary>
28    /// Class encoding a member of a ranked set of labels.
29    /// </summary>
30    public class RankPair : IComparable<RankPair>
31    {
32        private double _score, _label;
33
34        /// <summary>
35        /// Constructor.
36        /// </summary>
37        /// <param name="score">Score for this pair</param>
38        /// <param name="label">Label associated with the given score</param>
39        public RankPair(double score, double label)
40        {
41            _score = score;
42            _label = label;
43        }
44
45        /// <summary>
46        /// The score for this pair.
47        /// </summary>
48        public double Score
49        {
50            get
51            {
52                return _score;
53            }
54        }
55
56        /// <summary>
57        /// The Label for this pair.
58        /// </summary>
59        public double Label
60        {
61            get
62            {
63                return _label;
64            }
65        }
66
67        #region IComparable<RankPair> Members
68
69        /// <summary>
70        /// Compares this pair to another.  It will end up in a sorted list in decending score order.
71        /// </summary>
72        /// <param name="other">The pair to compare to</param>
73        /// <returns>Whether this should come before or after the argument</returns>
74        public int CompareTo(RankPair other)
75        {
76            return other.Score.CompareTo(Score);
77        }
78
79        #endregion
80
81        /// <summary>
82        /// Returns a string representation of this pair.
83        /// </summary>
84        /// <returns>A string in the for Score:Label</returns>
85        public override string ToString()
86        {
87            return string.Format("{0}:{1}", Score, Label);
88        }
89    }
90
91    /// <summary>
92    /// Class encoding the point on a 2D curve.
93    /// </summary>
94    public class CurvePoint
95    {
96        private float _x, _y;
97
98        /// <summary>
99        /// Constructor.
100        /// </summary>
101        /// <param name="x">X coordinate</param>
102        /// <param name="y">Y coordinate</param>
103        public CurvePoint(float x, float y)
104        {
105            _x = x;
106            _y = y;
107        }
108
109        /// <summary>
110        /// X coordinate
111        /// </summary>
112        public float X
113        {
114            get
115            {
116                return _x;
117            }
118        }
119
120        /// <summary>
121        /// Y coordinate
122        /// </summary>
123        public float Y
124        {
125            get
126            {
127                return _y;
128            }
129        }
130
131        /// <summary>
132        /// Creates a string representation of this point.
133        /// </summary>
134        /// <returns>string in the form (x, y)</returns>
135        public override string ToString()
136        {
137            return string.Format("({0}, {1})", _x, _y);
138        }
139    }
140
141    /// <summary>
142    /// Class which evaluates an SVM model using several standard techniques.
143    /// </summary>
144    public class PerformanceEvaluator
145    {
146        private class ChangePoint
147        {
148            public ChangePoint(int tp, int fp, int tn, int fn)
149            {
150                TP = tp;
151                FP = fp;
152                TN = tn;
153                FN = fn;
154            }
155
156            public int TP, FP, TN, FN;
157
158            public override string ToString()
159            {
160                return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN);
161            }
162        }
163
164        private List<CurvePoint> _prCurve;
165        private double _ap;
166
167        private List<CurvePoint> _rocCurve;
168        private double _auc;
169
170        private List<RankPair> _data;
171        private List<ChangePoint> _changes;
172
173        /// <summary>
174        /// Constructor.
175        /// </summary>
176        /// <param name="set">A pre-computed ranked pair set</param>
177        public PerformanceEvaluator(List<RankPair> set)
178        {
179            _data = set;
180            computeStatistics();
181        }
182
183        /// <summary>
184        /// Constructor.
185        /// </summary>
186        /// <param name="model">Model to evaluate</param>
187        /// <param name="problem">Problem to evaluate</param>
188        /// <param name="category">Label to be evaluate for</param>
189        public PerformanceEvaluator(Model model, Problem problem, double category) : this(model, problem, category, "tmp.results") { }
190        /// <summary>
191        /// Constructor.
192        /// </summary>
193        /// <param name="model">Model to evaluate</param>
194        /// <param name="problem">Problem to evaluate</param>
195        /// <param name="resultsFile">Results file for output</param>
196        /// <param name="category">Category to evaluate for</param>
197        public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile)
198        {
199            Prediction.Predict(problem, resultsFile, model, true);
200            parseResultsFile(resultsFile, problem.Y, category);
201
202            computeStatistics();
203        }
204
205        /// <summary>
206        /// Constructor.
207        /// </summary>
208        /// <param name="resultsFile">Results file</param>
209        /// <param name="correctLabels">The correct labels of each data item</param>
210        /// <param name="category">The category to evaluate for</param>
211        public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category)
212        {
213            parseResultsFile(resultsFile, correctLabels, category);
214            computeStatistics();
215        }
216
217        private void parseResultsFile(string resultsFile, double[] labels, double category)
218        {
219            StreamReader input = new StreamReader(resultsFile);
220            string[] parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
221            int confidenceIndex = -1;
222            for (int i = 1; i < parts.Length; i++)
223                if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category)
224                {
225                    confidenceIndex = i;
226                    break;
227                }
228            _data = new List<RankPair>();
229            for (int i = 0; i < labels.Length; i++)
230            {
231                parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
232                double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture);
233                _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0));
234            }
235            input.Close();
236        }
237
238        private void computeStatistics()
239        {
240            _data.Sort();
241
242            findChanges();
243            computePR();
244            computeRoC();
245        }
246
247        private void findChanges()
248        {
249            int tp, fp, tn, fn;
250            tp = fp = tn = fn = 0;
251            for (int i = 0; i < _data.Count; i++)
252            {
253                if (_data[i].Label == 1)
254                    fn++;
255                else tn++;
256            }
257            _changes = new List<ChangePoint>();
258            for (int i = 0; i < _data.Count; i++)
259            {
260                if (_data[i].Label == 1)
261                {
262                    tp++;
263                    fn--;
264                }
265                else
266                {
267                    fp++;
268                    tn--;
269                }
270                _changes.Add(new ChangePoint(tp, fp, tn, fn));
271            }
272        }
273
274        private float computePrecision(ChangePoint p)
275        {
276            return (float)p.TP / (p.TP + p.FP);
277        }
278
279        private float computeRecall(ChangePoint p)
280        {
281            return (float)p.TP / (p.TP + p.FN);
282        }
283
284        private void computePR()
285        {
286            _prCurve = new List<CurvePoint>();
287            _prCurve.Add(new CurvePoint(0, 1));
288            float precision = computePrecision(_changes[0]);
289            float recall = computeRecall(_changes[0]);
290            float precisionSum = 0;
291            if (_changes[0].TP > 0)
292            {
293                precisionSum += precision;
294                _prCurve.Add(new CurvePoint(recall, precision));
295            }
296            for (int i = 1; i < _changes.Count; i++)
297            {
298                precision = computePrecision(_changes[i]);
299                recall = computeRecall(_changes[i]);
300                if (_changes[i].TP > _changes[i - 1].TP)
301                {
302                    precisionSum += precision;
303                    _prCurve.Add(new CurvePoint(recall, precision));
304                }
305            }
306            _prCurve.Add(new CurvePoint(1, (float)(_changes[0].TP + _changes[0].FN) / (_changes[0].FP + _changes[0].TN)));
307            _ap = precisionSum / (_changes[0].FN + _changes[0].TP);
308        }
309
310        /// <summary>
311        /// Writes the Precision-Recall curve to a tab-delimited file.
312        /// </summary>
313        /// <param name="filename">Filename for output</param>
314        public void WritePRCurve(string filename)
315        {
316            StreamWriter output = new StreamWriter(filename);
317            output.WriteLine(_ap);
318            for (int i = 0; i < _prCurve.Count; i++)
319                output.WriteLine("{0}\t{1}", _prCurve[i].X, _prCurve[i].Y);
320            output.Close();
321        }
322
323        /// <summary>
324        /// Writes the Receiver Operating Characteristic curve to a tab-delimited file.
325        /// </summary>
326        /// <param name="filename">Filename for output</param>
327        public void WriteROCCurve(string filename)
328        {
329            StreamWriter output = new StreamWriter(filename);
330            output.WriteLine(_auc);
331            for (int i = 0; i < _rocCurve.Count; i++)
332                output.WriteLine("{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y);
333            output.Close();
334        }
335
336        /// <summary>
337        /// Receiver Operating Characteristic curve
338        /// </summary>
339        public List<CurvePoint> ROCCurve
340        {
341            get
342            {
343                return _rocCurve;
344            }
345        }
346
347        /// <summary>
348        /// Returns the area under the ROC Curve
349        /// </summary>
350        public double AuC
351        {
352            get
353            {
354                return _auc;
355            }
356        }
357
358        /// <summary>
359        /// Precision-Recall curve
360        /// </summary>
361        public List<CurvePoint> PRCurve
362        {
363            get
364            {
365                return _prCurve;
366            }
367        }
368
369        /// <summary>
370        /// The average precision
371        /// </summary>
372        public double AP
373        {
374            get
375            {
376                return _ap;
377            }
378        }
379
380        private float computeTPR(ChangePoint cp)
381        {
382            return computeRecall(cp);
383        }
384
385        private float computeFPR(ChangePoint cp)
386        {
387            return (float)cp.FP / (cp.FP + cp.TN);
388        }
389
390        private void computeRoC()
391        {
392            _rocCurve = new List<CurvePoint>();
393            _rocCurve.Add(new CurvePoint(0, 0));
394            float tpr = computeTPR(_changes[0]);
395            float fpr = computeFPR(_changes[0]);
396            _rocCurve.Add(new CurvePoint(fpr, tpr));
397            _auc = 0;
398            for (int i = 1; i < _changes.Count; i++)
399            {
400                float newTPR = computeTPR(_changes[i]);
401                float newFPR = computeFPR(_changes[i]);
402                if (_changes[i].TP > _changes[i - 1].TP)
403                {
404                    _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr);
405                    tpr = newTPR;
406                    fpr = newFPR;
407                    _rocCurve.Add(new CurvePoint(fpr, tpr));
408                }
409            }
410            _rocCurve.Add(new CurvePoint(1, 1));
411            _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr);
412        }
413
414    }
415}
Note: See TracBrowser for help on using the repository browser.