#region License Information
/* HeuristicLab
* Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Windows.Forms;
using System.Windows.Forms.DataVisualization.Charting;
using HeuristicLab.Common;
using HeuristicLab.MainForm;
using HeuristicLab.MainForm.WindowsForms;
namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
  /// <summary>
  /// View that plots one ROC curve (true positive rate over false positive rate) per class of an
  /// <see cref="IDiscriminantFunctionClassificationSolution"/>, evaluated on the training or the test
  /// partition as selected in the samples combo box. Clicking a legend item hides or shows its curve.
  /// </summary>
  [View("Discriminant function classification solution ROC curves view")]
  [Content(typeof(IDiscriminantFunctionClassificationSolution))]
  public partial class DiscriminantFunctionClassificationRocCurvesView : AsynchronousContentView {
    private const string xAxisTitle = "False Positive Rate";
    private const string yAxisTitle = "True Positive Rate";
    private const string TrainingSamples = "Training";
    private const string TestSamples = "Test";
    // Number of steps the range of estimated values is divided into when sweeping thresholds.
    private const int Slices = 100;

    // ROC points per class name (sorted by false positive rate), kept so that a hidden series can be refilled.
    private readonly Dictionary<string, List<ROCPoint>> cachedRocPoints = new Dictionary<string, List<ROCPoint>>();

    public DiscriminantFunctionClassificationRocCurvesView() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      ChartArea chartArea = chart.ChartAreas[0];
      chartArea.AxisX.Minimum = 0.0;
      chartArea.AxisX.Maximum = 1.0;
      chartArea.AxisX.MajorGrid.Interval = 0.2;
      chartArea.AxisY.Minimum = 0.0;
      chartArea.AxisY.Maximum = 1.0;
      chartArea.AxisY.MajorGrid.Interval = 0.2;
      chartArea.AxisX.Title = xAxisTitle;
      chartArea.AxisY.Title = yAxisTitle;
    }

    public new IDiscriminantFunctionClassificationSolution Content {
      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }

    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    private void Content_ModelChanged(object sender, EventArgs e) {
      UpdateChart();
    }

    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      UpdateChart();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      chart.Series.Clear();
      if (Content != null) UpdateChart();
    }

    /// <summary>
    /// Recomputes the ROC curves of all classes on the selected partition and redraws the chart.
    /// Re-invokes itself on the UI thread when called from another thread.
    /// </summary>
    private void UpdateChart() {
      if (InvokeRequired) {
        Invoke((Action)UpdateChart);
        return;
      }

      chart.Series.Clear();
      chart.Annotations.Clear();
      cachedRocPoints.Clear();

      IEnumerable<int> rows = GetSelectedRows();
      double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
      double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
      // An empty partition (e.g. no test samples) yields no curve; Min()/Max() below would throw on it.
      if (estimatedValues.Length == 0) return;

      double minThreshold = estimatedValues.Min();
      double maxThreshold = estimatedValues.Max();
      double thresholdIncrement = (maxThreshold - minThreshold) / Slices;
      // Extend the sweep by one step on each side; the interval test in CalculateRocPoints is strict,
      // so the lowest threshold must lie below every estimated value.
      minThreshold -= thresholdIncrement;
      maxThreshold += thresholdIncrement;

      List<double> classValues = Content.ProblemData.ClassValues.OrderBy(x => x).ToList();
      // NOTE(review): ClassNames is assumed to be aligned with ProblemData.ClassValues in its original order,
      // so names are looked up by that order rather than by the position in the sorted list above.
      List<double> problemClassValues = Content.ProblemData.ClassValues.ToList();
      bool isBinaryProblem = classValues.Count == 2;

      foreach (double classValue in classValues) {
        int positives = targetClassValues.Count(c => c.IsAlmost(classValue));
        int negatives = targetClassValues.Length - positives;
        // The rates would be 0/0 (NaN) if the partition lacks positive or negative samples for this class.
        if (positives == 0 || negatives == 0) continue;

        // In binary problems the upper class is predicted above a single (lower) threshold and
        // the lower class below a single (upper) threshold, so only one threshold is swept.
        bool sweepLowerThresholdOnly = isBinaryProblem && classValue == classValues[1];
        bool sweepUpperThresholdOnly = isBinaryProblem && classValue == classValues[0];
        List<ROCPoint> rocPoints = CalculateRocPoints(estimatedValues, targetClassValues, classValue, positives, negatives,
                                                      minThreshold, maxThreshold, thresholdIncrement,
                                                      sweepLowerThresholdOnly, sweepUpperThresholdOnly);

        string className = Content.ProblemData.ClassNames.ElementAt(problemClassValues.IndexOf(classValue));
        cachedRocPoints[className] = rocPoints.OrderBy(x => x.FalsePositiveRate).ToList();

        Series series = new Series(className);
        series.ChartType = SeriesChartType.Line;
        series.MarkerStyle = MarkerStyle.Diamond;
        series.MarkerSize = 5;
        chart.Series.Add(series);
        FillSeriesWithDataPoints(series, cachedRocPoints[className]);

        double auc = CalculateAreaUnderCurve(series);
        series.LegendToolTip = "AUC: " + auc;
      }
    }

    /// <summary>Returns the row indices of the partition currently selected in the samples combo box.</summary>
    /// <exception cref="InvalidOperationException">The selected item is neither the training nor the test entry.</exception>
    private IEnumerable<int> GetSelectedRows() {
      string selection = cmbSamples.SelectedItem.ToString();
      if (selection == TrainingSamples) return Content.ProblemData.TrainingIndizes;
      if (selection == TestSamples) return Content.ProblemData.TestIndizes;
      throw new InvalidOperationException();
    }

    /// <summary>
    /// Sweeps a lower and an upper threshold from <paramref name="minThreshold"/> towards <paramref name="maxThreshold"/>
    /// in steps of <paramref name="thresholdIncrement"/>. A sample is predicted as <paramref name="classValue"/> when
    /// lowerThreshold &lt; estimated value &lt; upperThreshold. Only points not dominated by another point
    /// (higher or equal TPR at lower or equal FPR) are returned.
    /// </summary>
    /// <param name="positives">Number of samples whose target equals <paramref name="classValue"/>; must be &gt; 0.</param>
    /// <param name="negatives">Number of remaining samples; must be &gt; 0.</param>
    /// <param name="sweepLowerThresholdOnly">Use an unbounded upper threshold and vary only the lower one.</param>
    /// <param name="sweepUpperThresholdOnly">Keep the lower threshold at <paramref name="minThreshold"/> and vary only the upper one.</param>
    private static List<ROCPoint> CalculateRocPoints(double[] estimatedValues, double[] targetClassValues, double classValue,
                                                     int positives, int negatives,
                                                     double minThreshold, double maxThreshold, double thresholdIncrement,
                                                     bool sweepLowerThresholdOnly, bool sweepUpperThresholdOnly) {
      List<ROCPoint> rocPoints = new List<ROCPoint>();
      for (double lowerThreshold = minThreshold; lowerThreshold < maxThreshold; lowerThreshold += thresholdIncrement) {
        for (double upperThreshold = lowerThreshold + thresholdIncrement; upperThreshold < maxThreshold; upperThreshold += thresholdIncrement) {
          // Setting the loop variable to +infinity also ends the inner loop after this pass.
          if (sweepLowerThresholdOnly) upperThreshold = double.PositiveInfinity;

          int truePositives = 0;
          int falsePositives = 0;
          for (int row = 0; row < estimatedValues.Length; row++) {
            if (lowerThreshold < estimatedValues[row] && estimatedValues[row] < upperThreshold) {
              if (targetClassValues[row].IsAlmost(classValue)) truePositives++;
              else falsePositives++;
            }
          }

          double truePositiveRate = ((double)truePositives) / positives;
          double falsePositiveRate = ((double)falsePositives) / negatives;
          ROCPoint rocPoint = new ROCPoint(truePositiveRate, falsePositiveRate, lowerThreshold, upperThreshold);
          // Discard the new point if an existing one is at least as good in both rates;
          // otherwise drop every existing point the new one dominates and keep it.
          if (!rocPoints.Any(x => x.TruePositiveRate >= rocPoint.TruePositiveRate && x.FalsePositiveRate <= rocPoint.FalsePositiveRate)) {
            rocPoints.RemoveAll(x => x.FalsePositiveRate >= rocPoint.FalsePositiveRate && x.TruePositiveRate <= rocPoint.TruePositiveRate);
            rocPoints.Add(rocPoint);
          }
        }
        // Setting the loop variable to +infinity ends the outer loop after its first pass,
        // so the lower threshold stays at minThreshold.
        if (sweepUpperThresholdOnly) lowerThreshold = double.PositiveInfinity;
      }
      return rocPoints;
    }

    /// <summary>
    /// Fills <paramref name="series"/> with the given ROC points, framed by the fixed end points (0,0) and (1,1).
    /// Each point carries its rates and thresholds as tooltip text.
    /// </summary>
    private void FillSeriesWithDataPoints(Series series, IEnumerable<ROCPoint> rocPoints) {
      series.Points.Add(new DataPoint(0, 0));
      foreach (ROCPoint rocPoint in rocPoints) {
        DataPoint point = new DataPoint();
        point.XValue = rocPoint.FalsePositiveRate;
        point.YValues[0] = rocPoint.TruePositiveRate;
        point.Tag = rocPoint;

        StringBuilder sb = new StringBuilder();
        sb.AppendLine("True Positive Rate: " + rocPoint.TruePositiveRate);
        sb.AppendLine("False Positive Rate: " + rocPoint.FalsePositiveRate);
        sb.AppendLine("Upper Threshold: " + rocPoint.UpperThreshold);
        sb.AppendLine("Lower Threshold: " + rocPoint.LowerThreshold);
        point.ToolTip = sb.ToString();

        series.Points.Add(point);
      }
      series.Points.Add(new DataPoint(1, 1));
    }

    /// <summary>Computes the area under the polyline of <paramref name="series"/> with the trapezoidal rule.</summary>
    /// <exception cref="ArgumentException">The series contains no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];
        auc += (y1 + y2) * width / 2;
      }
      return auc;
    }

    private void cmbSamples_SelectedIndexChanged(object sender, System.EventArgs e) {
      if (Content != null)
        UpdateChart();
    }

    #region show / hide series
    // A series is "hidden" by removing its points; showing it again refills it from the cache.
    private void ToggleSeries(Series series) {
      if (series.Points.Count == 0)
        FillSeriesWithDataPoints(series, cachedRocPoints[series.Name]);
      else
        series.Points.Clear();
    }

    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        if (result.Series != null) ToggleSeries(result.Series);
      }
    }

    // Grays out the legend entries of hidden (empty) series.
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legendItem in e.LegendItems) {
        // FindByName returns null for legend items without a series; the indexer would throw instead.
        Series series = chart.Series.FindByName(legendItem.SeriesName);
        if (series != null) {
          bool seriesIsInvisible = series.Points.Count == 0;
          foreach (LegendCell cell in legendItem.Cells)
            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
        }
      }
    }

    // Shows a hand cursor over legend items and the data point tooltip over data points.
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        this.Cursor = Cursors.Hand;
      else
        this.Cursor = Cursors.Default;

      string newTooltipText = string.Empty;
      if (result.ChartElementType == ChartElementType.DataPoint)
        newTooltipText = ((DataPoint)result.Object).ToolTip;

      string oldTooltipText = this.toolTip.GetToolTip(chart);
      if (newTooltipText != oldTooltipText)
        this.toolTip.SetToolTip(chart, newTooltipText);
    }
    #endregion

    /// <summary>One point of a ROC curve together with the thresholds that produced it.</summary>
    private sealed class ROCPoint {
      public ROCPoint(double truePositiveRate, double falsePositiveRate, double lowerThreshold, double upperThreshold) {
        TruePositiveRate = truePositiveRate;
        FalsePositiveRate = falsePositiveRate;
        LowerThreshold = lowerThreshold;
        UpperThreshold = upperThreshold;
      }

      public double TruePositiveRate { get; private set; }
      public double FalsePositiveRate { get; private set; }
      public double LowerThreshold { get; private set; }
      public double UpperThreshold { get; private set; }
    }
  }
}