/*
* SVM.NET Library
* Copyright (C) 2008 Matthew Johnson
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
using System;
using System.Collections.Generic;
using System.IO;
using System.Globalization;
namespace SVM
{
/// <summary>
/// Class encoding a member of a ranked set of labels: a score paired with the
/// label it belongs to.
/// </summary>
public class RankPair : IComparable<RankPair>
{
    private readonly double _score, _label;

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="score">Score for this pair</param>
    /// <param name="label">Label associated with the given score</param>
    public RankPair(double score, double label)
    {
        _score = score;
        _label = label;
    }

    /// <summary>
    /// The score for this pair.
    /// </summary>
    public double Score
    {
        get { return _score; }
    }

    /// <summary>
    /// The label for this pair.
    /// </summary>
    public double Label
    {
        get { return _label; }
    }

    #region IComparable Members
    /// <summary>
    /// Compares this pair to another so that a sorted list ends up in descending
    /// score order.
    /// </summary>
    /// <param name="other">The pair to compare to; may be null</param>
    /// <returns>
    /// A negative value if this pair has the higher score (and so comes first),
    /// a positive value if it has the lower score, and zero for equal scores.
    /// Every pair compares greater than null, so nulls sort to the front.
    /// </returns>
    public int CompareTo(RankPair other)
    {
        // IComparable<T> convention: any instance compares greater than null.
        if (ReferenceEquals(other, null))
            return 1;
        // Operands are deliberately swapped to produce a descending order.
        return other.Score.CompareTo(Score);
    }
    #endregion

    /// <summary>
    /// Returns a string representation of this pair.
    /// </summary>
    /// <returns>A string in the form Score:Label</returns>
    public override string ToString()
    {
        return string.Format("{0}:{1}", Score, Label);
    }
}
/// <summary>
/// An immutable (x, y) point on a two-dimensional curve, such as a
/// Precision-Recall or ROC curve.
/// </summary>
public class CurvePoint
{
    private readonly float _xValue;
    private readonly float _yValue;

    /// <summary>
    /// Creates a point at the given coordinates.
    /// </summary>
    /// <param name="x">X coordinate</param>
    /// <param name="y">Y coordinate</param>
    public CurvePoint(float x, float y)
    {
        _xValue = x;
        _yValue = y;
    }

    /// <summary>
    /// Horizontal coordinate of the point.
    /// </summary>
    public float X
    {
        get { return _xValue; }
    }

    /// <summary>
    /// Vertical coordinate of the point.
    /// </summary>
    public float Y
    {
        get { return _yValue; }
    }

    /// <summary>
    /// Formats the point using the current culture.
    /// </summary>
    /// <returns>A string in the form (x, y)</returns>
    public override string ToString()
    {
        float horizontal = X;
        float vertical = Y;
        return string.Format("({0}, {1})", horizontal, vertical);
    }
}
/// <summary>
/// Class which evaluates an SVM model using several standard techniques:
/// a Precision-Recall curve with its average precision, and a Receiver
/// Operating Characteristic (ROC) curve with the area under it.
/// </summary>
/// <remarks>
/// Items are ranked by descending score. An item whose label is exactly 1 is a
/// positive; every other label is a negative.
/// NOTE(review): tied scores are not grouped, so their relative order (and thus
/// the curves) depends on the unstable List.Sort.
/// </remarks>
public class PerformanceEvaluator
{
    /// <summary>
    /// Confusion-matrix counts after the top k items of the ranking have been
    /// labelled positive.
    /// </summary>
    private class ChangePoint
    {
        public ChangePoint(int tp, int fp, int tn, int fn)
        {
            TP = tp;
            FP = fp;
            TN = tn;
            FN = fn;
        }

        public int TP, FP, TN, FN;

        public override string ToString()
        {
            return string.Format("{0}:{1}:{2}:{3}", TP, FP, TN, FN);
        }
    }

    private List<CurvePoint> _prCurve;
    private double _ap;
    private List<CurvePoint> _rocCurve;
    private double _auc;
    private List<RankPair> _data;
    private List<ChangePoint> _changes;

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="set">A pre-computed ranked pair set. The list is copied, so the
    /// caller's list is not reordered.</param>
    /// <exception cref="ArgumentNullException"><paramref name="set"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="set"/> is empty.</exception>
    public PerformanceEvaluator(List<RankPair> set)
    {
        if (set == null)
            throw new ArgumentNullException("set");
        // Copy so that the sort in computeStatistics does not mutate the caller's list.
        _data = new List<RankPair>(set);
        computeStatistics();
    }

    /// <summary>
    /// Constructor. Predictions are written to "tmp.results" in the current directory.
    /// </summary>
    /// <param name="model">Model to evaluate</param>
    /// <param name="problem">Problem to evaluate</param>
    /// <param name="category">Label to be evaluated for</param>
    public PerformanceEvaluator(Model model, Problem problem, double category) : this(model, problem, category, "tmp.results") { }

    /// <summary>
    /// Constructor. Runs the model on the problem, writing the predictions to
    /// <paramref name="resultsFile"/>, then evaluates the confidence column for
    /// <paramref name="category"/> against <c>problem.Y</c>.
    /// </summary>
    /// <param name="model">Model to evaluate</param>
    /// <param name="problem">Problem to evaluate</param>
    /// <param name="category">Category to evaluate for</param>
    /// <param name="resultsFile">Results file for output</param>
    public PerformanceEvaluator(Model model, Problem problem, double category, string resultsFile)
    {
        // NOTE(review): the final 'true' presumably requests per-class confidence
        // columns, which parseResultsFile relies on — confirm in Prediction.Predict.
        Prediction.Predict(problem, resultsFile, model, true);
        parseResultsFile(resultsFile, problem.Y, category);
        computeStatistics();
    }

    /// <summary>
    /// Constructor.
    /// </summary>
    /// <param name="resultsFile">Results file</param>
    /// <param name="correctLabels">The correct labels of each data item</param>
    /// <param name="category">The category to evaluate for</param>
    /// <exception cref="InvalidDataException">The file does not list
    /// <paramref name="category"/> in its header, or has too few lines or columns.</exception>
    public PerformanceEvaluator(string resultsFile, double[] correctLabels, double category)
    {
        parseResultsFile(resultsFile, correctLabels, category);
        computeStatistics();
    }

    /// <summary>
    /// Loads one ranked pair per label from a results file. The first line is a
    /// header token followed by class labels (assumed to be a libsvm-style
    /// "labels ..." line); each following line holds the confidence for each label
    /// in the same column positions.
    /// </summary>
    private void parseResultsFile(string resultsFile, double[] labels, double category)
    {
        using (StreamReader input = new StreamReader(resultsFile))
        {
            string[] parts = readFields(input);

            // Column 0 of the header is a leading token, so the search starts at 1.
            int confidenceIndex = -1;
            for (int i = 1; i < parts.Length; i++)
            {
                if (double.Parse(parts[i], CultureInfo.InvariantCulture) == category)
                {
                    confidenceIndex = i;
                    break;
                }
            }
            if (confidenceIndex < 0)
                throw new InvalidDataException(string.Format(CultureInfo.InvariantCulture,
                    "Category {0} does not appear in the header of results file '{1}'.", category, resultsFile));

            _data = new List<RankPair>(labels.Length);
            for (int i = 0; i < labels.Length; i++)
            {
                parts = readFields(input);
                if (confidenceIndex >= parts.Length)
                    throw new InvalidDataException(string.Format(CultureInfo.InvariantCulture,
                        "Line {0} of results file '{1}' has no confidence column for category {2}.", i + 2, resultsFile, category));
                double confidence = double.Parse(parts[confidenceIndex], CultureInfo.InvariantCulture);
                // Binarise: the evaluated category becomes 1, everything else 0.
                _data.Add(new RankPair(confidence, labels[i] == category ? 1 : 0));
            }
        }
    }

    /// <summary>
    /// Reads the next line and splits it on spaces, failing clearly at end of file.
    /// </summary>
    private static string[] readFields(StreamReader input)
    {
        string line = input.ReadLine();
        if (line == null)
            throw new InvalidDataException("Results file ended before all expected lines were read.");
        return line.Split(new char[] { ' ' }, StringSplitOptions.RemoveEmptyEntries);
    }

    /// <summary>
    /// Sorts the data by descending score and derives every statistic from it.
    /// </summary>
    private void computeStatistics()
    {
        if (_data.Count == 0)
            throw new ArgumentException("Cannot evaluate performance on an empty data set.");
        _data.Sort();
        findChanges();
        computePR();
        computeRoC();
    }

    /// <summary>
    /// Records the confusion counts for each cut-off: entry i treats the top
    /// i + 1 ranked items as predicted positive.
    /// </summary>
    private void findChanges()
    {
        int positives = 0;
        foreach (RankPair pair in _data)
        {
            if (pair.Label == 1)
                positives++;
        }
        int negatives = _data.Count - positives;

        int tp = 0, fp = 0;
        _changes = new List<ChangePoint>(_data.Count);
        foreach (RankPair pair in _data)
        {
            if (pair.Label == 1)
                tp++;
            else
                fp++;
            _changes.Add(new ChangePoint(tp, fp, negatives - fp, positives - tp));
        }
    }

    private float computePrecision(ChangePoint p)
    {
        return (float)p.TP / (p.TP + p.FP);
    }

    private float computeRecall(ChangePoint p)
    {
        return (float)p.TP / (p.TP + p.FN);
    }

    /// <summary>
    /// Builds the Precision-Recall curve and the (non-interpolated) average
    /// precision: the mean of the precision values at each positive item.
    /// </summary>
    private void computePR()
    {
        _prCurve = new List<CurvePoint>();
        _prCurve.Add(new CurvePoint(0, 1));

        float precisionSum = 0;
        int previousTP = 0;
        foreach (ChangePoint cp in _changes)
        {
            // Precision and recall only matter where a new positive is retrieved.
            if (cp.TP > previousTP)
            {
                float precision = computePrecision(cp);
                precisionSum += precision;
                _prCurve.Add(new CurvePoint(computeRecall(cp), precision));
            }
            previousTP = cp.TP;
        }

        int positives = _changes[0].TP + _changes[0].FN;
        // Retrieving every item gives recall 1 with precision equal to the
        // fraction of positives (previously positives / negatives, which could exceed 1).
        _prCurve.Add(new CurvePoint(1, (float)positives / _data.Count));
        // NaN when there are no positives, as before.
        _ap = precisionSum / positives;
    }

    /// <summary>
    /// Writes the Precision-Recall curve to a tab-delimited file: the average
    /// precision on the first line, then one "recall\tprecision" line per point.
    /// Numbers use the invariant culture so the file is locale-independent.
    /// </summary>
    /// <param name="filename">Filename for output</param>
    public void WritePRCurve(string filename)
    {
        writeCurve(filename, _ap, _prCurve);
    }

    /// <summary>
    /// Writes the Receiver Operating Characteristic curve to a tab-delimited file:
    /// the AUC on the first line, then one "fpr\ttpr" line per point.
    /// Numbers use the invariant culture so the file is locale-independent.
    /// </summary>
    /// <param name="filename">Filename for output</param>
    public void WriteROCCurve(string filename)
    {
        writeCurve(filename, _auc, _rocCurve);
    }

    /// <summary>
    /// Shared writer for both curve files; the writer is disposed even on failure.
    /// </summary>
    private static void writeCurve(string filename, double summary, List<CurvePoint> curve)
    {
        using (StreamWriter output = new StreamWriter(filename))
        {
            output.WriteLine(summary.ToString(CultureInfo.InvariantCulture));
            foreach (CurvePoint point in curve)
                output.WriteLine(string.Format(CultureInfo.InvariantCulture, "{0}\t{1}", point.X, point.Y));
        }
    }

    /// <summary>
    /// Receiver Operating Characteristic curve, as (false positive rate, true positive rate) points.
    /// </summary>
    public List<CurvePoint> ROCCurve
    {
        get { return _rocCurve; }
    }

    /// <summary>
    /// Returns the area under the ROC Curve.
    /// </summary>
    public double AuC
    {
        get { return _auc; }
    }

    /// <summary>
    /// Precision-Recall curve, as (recall, precision) points.
    /// </summary>
    public List<CurvePoint> PRCurve
    {
        get { return _prCurve; }
    }

    /// <summary>
    /// The average precision.
    /// </summary>
    public double AP
    {
        get { return _ap; }
    }

    private float computeTPR(ChangePoint cp)
    {
        return computeRecall(cp);
    }

    private float computeFPR(ChangePoint cp)
    {
        return (float)cp.FP / (cp.FP + cp.TN);
    }

    /// <summary>
    /// True when the item at rank <paramref name="i"/> is a positive, i.e. the
    /// curve steps up rather than right at that cut-off.
    /// </summary>
    private bool isPositiveStep(int i)
    {
        int previousTP = i == 0 ? 0 : _changes[i - 1].TP;
        return _changes[i].TP > previousTP;
    }

    /// <summary>
    /// Builds the empirical ROC step curve and its area. Every cut-off moves the
    /// curve either straight up (a positive) or straight right (a negative), so
    /// integrating over all cut-offs gives the exact area. The old code joined
    /// positive-to-positive points with diagonals and overestimated the AUC.
    /// </summary>
    private void computeRoC()
    {
        _rocCurve = new List<CurvePoint>();
        _rocCurve.Add(new CurvePoint(0, 0));
        _auc = 0;

        float previousTPR = 0, previousFPR = 0;
        int last = _changes.Count - 1;
        for (int i = 0; i < _changes.Count; i++)
        {
            float tpr = computeTPR(_changes[i]);
            float fpr = computeFPR(_changes[i]);
            // Trapezoid rule; for an upward step the width is zero.
            _auc += .5 * (tpr + previousTPR) * (fpr - previousFPR);
            // Keep only the corners where the curve changes direction; the final
            // cut-off is (1, 1), which is appended below.
            if (i < last && isPositiveStep(i) != isPositiveStep(i + 1))
                _rocCurve.Add(new CurvePoint(fpr, tpr));
            previousTPR = tpr;
            previousFPR = fpr;
        }
        _rocCurve.Add(new CurvePoint(1, 1));
    }
}
}