Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/LibSVM/PerformanceEvaluator.cs @ 2273

Last change on this file since 2273 was 1819, checked in by mkommend, 16 years ago

created new project for LibSVM source files (ticket #619)

File size: 12.5 KB
Line 
1/*
2 * SVM.NET Library
3 * Copyright (C) 2008 Matthew Johnson
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation, either version 3 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
17 */
18
19
20using System;
21using System.Collections.Generic;
22using System.IO;
23
24namespace SVM
25{
26    /// <remarks>
27    /// Class encoding a member of a ranked set of labels.
28    /// </remarks>
29    public class RankPair : IComparable<RankPair>
30    {
31        private double _score, _label;
32
33        /// <summary>
34        /// Constructor.
35        /// </summary>
36        /// <param name="score">Score for this pair</param>
37        /// <param name="label">Label associated with the given score</param>
38        public RankPair(double score, double label)
39        {
40            _score = score;
41            _label = label;
42        }
43
44        /// <summary>
45        /// The score for this pair.
46        /// </summary>
47        public double Score
48        {
49            get
50            {
51                return _score;
52            }
53        }
54
55        /// <summary>
56        /// The Label for this pair.
57        /// </summary>
58        public double Label
59        {
60            get
61            {
62                return _label;
63            }
64        }
65
66        #region IComparable<RankPair> Members
67
68        /// <summary>
69        /// Compares this pair to another.  It will end up in a sorted list in decending score order.
70        /// </summary>
71        /// <param name="other">The pair to compare to</param>
72        /// <returns>Whether this should come before or after the argument</returns>
73        public int CompareTo(RankPair other)
74        {
75            return other.Score.CompareTo(Score);
76        }
77
78        #endregion
79
80        /// <summary>
81        /// Returns a string representation of this pair.
82        /// </summary>
83        /// <returns>A string in the for Score:Label</returns>
84        public override string ToString()
85        {
86            return string.Format("{0}:{1}", Score, Label);
87        }
88    }
89
90    /// <summary>
91    /// Class encoding the point on a 2D curve.
92    /// </summary>
93    public class CurvePoint
94    {
95        private float _x, _y;
96
97        /// <summary>
98        /// Constructor.
99        /// </summary>
100        /// <param name="x">X coordinate</param>
101        /// <param name="y">Y coordinate</param>
102        public CurvePoint(float x, float y)
103        {
104            _x = x;
105            _y = y;
106        }
107
108        /// <summary>
109        /// X coordinate
110        /// </summary>
111        public float X
112        {
113            get
114            {
115                return _x;
116            }
117        }
118
119        /// <summary>
120        /// Y coordinate
121        /// </summary>
122        public float Y
123        {
124            get
125            {
126                return _y;
127            }
128        }
129
130        /// <summary>
131        /// Creates a string representation of this point.
132        /// </summary>
133        /// <returns>string in the form (x, y)</returns>
134        public override string ToString()
135        {
136            return string.Format("({0}, {1})", _x, _y);
137        }
138    }
139
140    /// <remarks>
141    /// Class which evaluates an SVM model using several standard techniques.
142    /// </remarks>
143    public class PerformanceEvaluator
144    {
145        private class ChangePoint
146        {
147            public ChangePoint(int tp, int fp, int tn, int fn)
148            {
149                TP = tp;
150                FP = fp;
151                TN = tn;
152                FN = fn;
153            }
154
155            public int TP, FP, TN, FN;
156
157            public override string ToString()
158            {
159                return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN);
160            }
161        }
162
163        private List<CurvePoint> _prCurve;
164        private double _ap;
165
166        private List<CurvePoint> _rocCurve;
167        private double _auc;
168
169        private List<RankPair> _data;
170        private List<ChangePoint> _changes;
171
172        /// <summary>
173        /// Constructor.
174        /// </summary>
175        /// <param name="set">A pre-computed ranked pair set</param>
176        public PerformanceEvaluator(List<RankPair> set)
177        {
178            _data = set;
179            computeStatistics();
180        }
181
182        /// <summary>
183        /// Constructor.
184        /// </summary>
185        /// <param name="model">Model to evaluate</param>
186        /// <param name="problem">Problem to evaluate</param>
187        /// <param name="label">Label to be evaluate for</param>
188        public PerformanceEvaluator(Model model, Problem problem, double label) : this(model, problem, label, "tmp.results") { }
189        /// <summary>
190        /// Constructor.
191        /// </summary>
192        /// <param name="model">Model to evaluate</param>
193        /// <param name="problem">Problem to evaluate</param>
194        /// <param name="resultsFile">Results file for output</param>
195        /// <param name="category">Category to evaluate for</param>
196        public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile)
197        {
198            Prediction.Predict(problem, resultsFile, model, true);
199            parseResultsFile(resultsFile, problem.Y, category);
200
201            computeStatistics();
202        }
203
204        /// <summary>
205        /// Constructor.
206        /// </summary>
207        /// <param name="resultsFile">Results file</param>
208        /// <param name="correctLabels">The correct labels of each data item</param>
209        /// <param name="category">The category to evaluate for</param>
210        public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category)
211        {
212            parseResultsFile(resultsFile, correctLabels, category);
213            computeStatistics();
214        }
215
216        private void parseResultsFile(string resultsFile, double[] labels, double category)
217        {
218            StreamReader input = new StreamReader(resultsFile);
219            string[] parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
220            int confidenceIndex = -1;
221            for (int i = 1; i < parts.Length; i++)
222                if (double.Parse(parts[i]) == category)
223                {
224                    confidenceIndex = i;
225                    break;
226                }
227            _data = new List<RankPair>();
228            for (int i = 0; i < labels.Length; i++)
229            {
230                parts = input.ReadLine().Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
231                double confidence = double.Parse(parts[confidenceIndex]);
232                _data.Add(new RankPair(confidence, labels[i]));
233            }
234            input.Close();
235        }
236
237        private void computeStatistics()
238        {
239            _data.Sort();
240
241            findChanges();
242            computePR();
243            computeRoC();
244        }
245
246        private void findChanges()
247        {
248            int tp, fp, tn, fn;
249            tp = fp = tn = fn = 0;
250            for (int i = 0; i < _data.Count; i++)
251            {
252                if (_data[i].Label == 1)
253                    fn++;
254                else tn++;
255            }
256            _changes = new List<ChangePoint>();
257            for (int i = 0; i < _data.Count; i++)
258            {
259                if (_data[i].Label == 1)
260                {
261                    tp++;
262                    fn--;
263                }
264                else
265                {
266                    fp++;
267                    tn--;
268                }
269                _changes.Add(new ChangePoint(tp, fp, tn, fn));
270            }
271        }
272
273        private float computePrecision(ChangePoint p)
274        {
275            return (float)p.TP / (p.TP + p.FP);
276        }
277
278        private float computeRecall(ChangePoint p)
279        {
280            return (float)p.TP / (p.TP + p.FN);
281        }
282
283        private void computePR()
284        {
285            _prCurve = new List<CurvePoint>();
286            _prCurve.Add(new CurvePoint(0, 1));
287            float precision = computePrecision(_changes[0]);
288            float recall = computeRecall(_changes[0]);
289            float precisionSum = 0;
290            if (_changes[0].TP > 0)
291            {
292                precisionSum += precision;
293                _prCurve.Add(new CurvePoint(recall, precision));
294            }
295            for (int i = 1; i < _changes.Count; i++)
296            {
297                precision = computePrecision(_changes[i]);
298                recall = computeRecall(_changes[i]);
299                if (_changes[i].TP > _changes[i - 1].TP)
300                {
301                    precisionSum += precision;
302                    _prCurve.Add(new CurvePoint(recall, precision));
303                }
304            }
305            _prCurve.Add(new CurvePoint(1, (float)(_changes[0].TP + _changes[0].FN) / (_changes[0].FP + _changes[0].TN)));
306            _ap = precisionSum / (_changes[0].FN + _changes[0].TP);
307        }
308
309        /// <summary>
310        /// Writes the Precision-Recall curve to a tab-delimited file.
311        /// </summary>
312        /// <param name="filename">Filename for output</param>
313        public void WritePRCurve(string filename)
314        {
315            StreamWriter output = new StreamWriter(filename);
316            output.WriteLine(_ap);
317            for (int i = 0; i < _prCurve.Count; i++)
318                output.WriteLine("{0}\t{1}", _prCurve[i].X, _prCurve[i].Y);
319            output.Close();
320        }
321
322        /// <summary>
323        /// Writes the Receiver Operating Characteristic curve to a tab-delimited file.
324        /// </summary>
325        /// <param name="filename">Filename for output</param>
326        public void WriteROCCurve(string filename)
327        {
328            StreamWriter output = new StreamWriter(filename);
329            output.WriteLine(_auc);
330            for (int i = 0; i < _rocCurve.Count; i++)
331                output.WriteLine("{0}\t{1}", _rocCurve[i].X, _rocCurve[i].Y);
332            output.Close();
333        }
334
335        /// <summary>
336        /// Receiver Operating Characteristic curve
337        /// </summary>
338        public List<CurvePoint> ROCCurve
339        {
340            get
341            {
342                return _rocCurve;
343            }
344        }
345
346        /// <summary>
347        /// Returns the area under the ROC Curve
348        /// </summary>
349        public double AuC
350        {
351            get
352            {
353                return _auc;
354            }
355        }
356
357        /// <summary>
358        /// Precision-Recall curve
359        /// </summary>
360        public List<CurvePoint> PRCurve
361        {
362            get
363            {
364                return _prCurve;
365            }
366        }
367
368        /// <summary>
369        /// The average precision
370        /// </summary>
371        public double AP
372        {
373            get
374            {
375                return _ap;
376            }
377        }
378
379        private float computeTPR(ChangePoint cp)
380        {
381            return computeRecall(cp);
382        }
383
384        private float computeFPR(ChangePoint cp)
385        {
386            return (float)cp.FP / (cp.FP + cp.TN);
387        }
388
389        private void computeRoC()
390        {
391            _rocCurve = new List<CurvePoint>();
392            _rocCurve.Add(new CurvePoint(0, 0));
393            float tpr = computeTPR(_changes[0]);
394            float fpr = computeFPR(_changes[0]);
395            _rocCurve.Add(new CurvePoint(fpr, tpr));
396            _auc = 0;
397            for (int i = 1; i < _changes.Count; i++)
398            {
399                float newTPR = computeTPR(_changes[i]);
400                float newFPR = computeFPR(_changes[i]);
401                if (_changes[i].TP > _changes[i - 1].TP)
402                {
403                    _auc += tpr * (newFPR - fpr) + .5 * (newTPR - tpr) * (newFPR - fpr);
404                    tpr = newTPR;
405                    fpr = newFPR;
406                    _rocCurve.Add(new CurvePoint(fpr, tpr));
407                }
408            }
409            _rocCurve.Add(new CurvePoint(1, 1));
410            _auc += tpr * (1 - fpr) + .5 * (1 - tpr) * (1 - fpr);
411        }
412
413    }
414}
Note: See TracBrowser for help on using the repository browser.