Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationRocCurvesView.cs @ 6614

Last change on this file since 6614 was 5975, checked in by mkommend, 14 years ago

#1313: Updated view names to use spaces instead of camelCasing.

File size: 11.0 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Text;
27using System.Windows.Forms;
28using System.Windows.Forms.DataVisualization.Charting;
29using HeuristicLab.Common;
30using HeuristicLab.Core.Views;
31using HeuristicLab.MainForm;
32using HeuristicLab.MainForm.WindowsForms;
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Displays one ROC curve per class of a discriminant function classification solution,
  /// evaluated on either the training or the test partition of the problem data.
  /// </summary>
  /// <remarks>
  /// A sample is counted as predicted for a class when its estimated value lies strictly between a
  /// lower and an upper threshold. Thresholds are taken from an equidistant grid spanning the range of
  /// estimated values. Only ROC points not dominated by another point (higher or equal TPR at lower or
  /// equal FPR) are kept.
  /// </remarks>
  [View("ROC Curves")]
  [Content(typeof(IDiscriminantFunctionClassificationSolution))]
  public partial class DiscriminantFunctionClassificationRocCurvesView : ItemView, IDiscriminantFunctionClassificationSolutionEvaluationView {
    private const string xAxisTitle = "False Positive Rate";
    private const string yAxisTitle = "True Positive Rate";
    private const string TrainingSamples = "Training";
    private const string TestSamples = "Test";
    // Number of equidistant intervals the range of estimated values is divided into.
    private const int Slices = 100;
    // ROC points per class name; used to re-fill a series after it was hidden via its legend item.
    private readonly Dictionary<string, List<ROCPoint>> cachedRocPoints;

    public DiscriminantFunctionClassificationRocCurvesView() {
      InitializeComponent();

      cachedRocPoints = new Dictionary<string, List<ROCPoint>>();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      chart.ChartAreas[0].AxisX.Maximum = 1.0;
      chart.ChartAreas[0].AxisX.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;

      chart.ChartAreas[0].AxisX.Title = xAxisTitle;
      chart.ChartAreas[0].AxisY.Title = yAxisTitle;
    }

    public new IDiscriminantFunctionClassificationSolution Content {
      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    private void Content_ModelChanged(object sender, EventArgs e) {
      UpdateChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      UpdateChart();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      chart.Series.Clear();
      // Drop points of a previous solution so they cannot be toggled back in.
      cachedRocPoints.Clear();
      if (Content != null) UpdateChart();
    }

    /// <summary>
    /// Recomputes the ROC curve of every class for the selected partition and redraws the chart.
    /// Re-dispatches itself via Invoke when called from a thread other than the control's.
    /// </summary>
    private void UpdateChart() {
      if (InvokeRequired) {
        Invoke((Action)UpdateChart);
        return;
      }

      chart.Series.Clear();
      chart.Annotations.Clear();
      cachedRocPoints.Clear();

      IEnumerable<int> rows = GetSelectedRows();
      double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
      double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable, rows).ToArray();
      // An empty partition (e.g. no test samples) has no ROC curve; Min()/Max() would throw on it.
      if (estimatedValues.Length == 0) return;

      double minEstimatedValue = estimatedValues.Min();
      double maxEstimatedValue = estimatedValues.Max();
      double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / Slices;
      // The threshold grid starts one increment below the smallest and ends one increment above the
      // largest estimate, so the strict comparisons can include or exclude every sample.
      double minThreshold = minEstimatedValue - thresholdIncrement;
      int thresholdSteps = Slices + 2;

      List<double> classValues = Content.ProblemData.ClassValues.OrderBy(x => x).ToList();
      // Class names are looked up by the position of the class value in ProblemData.ClassValues, to
      // which ClassNames is assumed to be parallel, rather than relying on ClassValues being sorted.
      List<double> problemDataClassValues = Content.ProblemData.ClassValues.ToList();
      List<string> classNames = Content.ProblemData.ClassNames.ToList();
      bool isBinaryProblem = classValues.Count == 2;

      for (int classIndex = 0; classIndex < classValues.Count; classIndex++) {
        double classValue = classValues[classIndex];
        int positives = targetClassValues.Count(c => c.IsAlmost(classValue));
        int negatives = targetClassValues.Length - positives;
        // TPR/FPR are undefined (0/0) without positive or negative samples, so no curve is drawn.
        if (positives == 0 || negatives == 0) continue;

        // Binary problems: the lower class is predicted below a single upper threshold (lower threshold
        // fixed), the upper class above a single lower threshold (upper threshold = +infinity).
        // Multi-class problems: both thresholds are varied.
        bool varyLowerThreshold = !(isBinaryProblem && classIndex == 0);
        bool varyUpperThreshold = !(isBinaryProblem && classIndex == 1);
        List<ROCPoint> rocPoints = CalculateRocPoints(estimatedValues, targetClassValues, classValue, positives, negatives,
                                                      minThreshold, thresholdIncrement, thresholdSteps,
                                                      varyLowerThreshold, varyUpperThreshold);

        string className = classNames[problemDataClassValues.IndexOf(classValue)];
        cachedRocPoints[className] = rocPoints.OrderBy(x => x.FalsePositiveRate).ToList();

        Series series = new Series(className);
        series.ChartType = SeriesChartType.Line;
        series.MarkerStyle = MarkerStyle.Diamond;
        series.MarkerSize = 5;
        chart.Series.Add(series);
        FillSeriesWithDataPoints(series, cachedRocPoints[className]);

        double auc = CalculateAreaUnderCurve(series);
        series.LegendToolTip = "AUC: " + auc;
      }
    }

    /// <summary>Returns the row indices of the partition selected in the samples combo box.</summary>
    /// <exception cref="InvalidOperationException">The selected item is neither training nor test.</exception>
    private IEnumerable<int> GetSelectedRows() {
      string selectedSamples = cmbSamples.SelectedItem.ToString();
      if (selectedSamples == TrainingSamples) return Content.ProblemData.TrainingIndizes;
      if (selectedSamples == TestSamples) return Content.ProblemData.TestIndizes;
      throw new InvalidOperationException();
    }

    /// <summary>
    /// Computes the non-dominated ROC points of one class by sweeping the lower and/or upper threshold
    /// over the grid minThreshold + k * thresholdIncrement, k = 0 .. thresholdSteps.
    /// </summary>
    /// <param name="varyLowerThreshold">If false, the lower threshold stays at the first grid value.</param>
    /// <param name="varyUpperThreshold">If false, the upper threshold is +infinity.</param>
    /// <returns>The non-dominated ROC points in insertion order; empty for a degenerate grid.</returns>
    private static List<ROCPoint> CalculateRocPoints(double[] estimatedValues, double[] targetClassValues, double classValue,
                                                     int positives, int negatives,
                                                     double minThreshold, double thresholdIncrement, int thresholdSteps,
                                                     bool varyLowerThreshold, bool varyUpperThreshold) {
      List<ROCPoint> rocPoints = new List<ROCPoint>();
      // All estimates equal (increment 0) or non-finite estimates (NaN/infinite increment): no
      // threshold pairs are evaluated and the curve consists only of its end points.
      if (!(thresholdIncrement > 0.0) || double.IsInfinity(thresholdIncrement)) return rocPoints;

      int lowerThresholdCount = varyLowerThreshold ? thresholdSteps : 1;
      for (int lowerIndex = 0; lowerIndex < lowerThresholdCount; lowerIndex++) {
        // Thresholds are derived from integer indices instead of accumulating the increment; this
        // avoids drift and cannot stall when the increment is tiny relative to the estimates.
        double lowerThreshold = minThreshold + lowerIndex * thresholdIncrement;
        if (!varyUpperThreshold) {
          AddIfNotDominated(rocPoints, CalculateRocPoint(estimatedValues, targetClassValues, classValue, positives, negatives,
                                                         lowerThreshold, double.PositiveInfinity));
          continue;
        }
        // The upper threshold includes the last grid value so samples at the maximum estimate can be
        // predicted positive.
        for (int upperIndex = lowerIndex + 1; upperIndex <= thresholdSteps; upperIndex++) {
          double upperThreshold = minThreshold + upperIndex * thresholdIncrement;
          AddIfNotDominated(rocPoints, CalculateRocPoint(estimatedValues, targetClassValues, classValue, positives, negatives,
                                                         lowerThreshold, upperThreshold));
        }
      }
      return rocPoints;
    }

    /// <summary>
    /// Evaluates one threshold pair: samples whose estimate lies strictly between the thresholds are
    /// predicted as <paramref name="classValue"/>. Callers guarantee positives and negatives are non-zero.
    /// </summary>
    private static ROCPoint CalculateRocPoint(double[] estimatedValues, double[] targetClassValues, double classValue,
                                              int positives, int negatives, double lowerThreshold, double upperThreshold) {
      int truePositives = 0;
      int falsePositives = 0;
      for (int row = 0; row < estimatedValues.Length; row++) {
        if (lowerThreshold < estimatedValues[row] && estimatedValues[row] < upperThreshold) {
          if (targetClassValues[row].IsAlmost(classValue)) truePositives++;
          else falsePositives++;
        }
      }

      double truePositiveRate = ((double)truePositives) / positives;
      double falsePositiveRate = ((double)falsePositives) / negatives;
      return new ROCPoint(truePositiveRate, falsePositiveRate, lowerThreshold, upperThreshold);
    }

    /// <summary>
    /// Adds <paramref name="rocPoint"/> unless an existing point has a TPR at least as high at an FPR at
    /// most as high; points dominated by the new one are removed first.
    /// </summary>
    private static void AddIfNotDominated(List<ROCPoint> rocPoints, ROCPoint rocPoint) {
      if (rocPoints.Any(x => x.TruePositiveRate >= rocPoint.TruePositiveRate && x.FalsePositiveRate <= rocPoint.FalsePositiveRate))
        return;
      rocPoints.RemoveAll(x => x.FalsePositiveRate >= rocPoint.FalsePositiveRate && x.TruePositiveRate <= rocPoint.TruePositiveRate);
      rocPoints.Add(rocPoint);
    }

    /// <summary>
    /// Adds the ROC points to the series framed by the end points (0,0) and (1,1); each ROC point gets
    /// a tooltip with its rates and thresholds.
    /// </summary>
    private void FillSeriesWithDataPoints(Series series, IEnumerable<ROCPoint> rocPoints) {
      series.Points.Add(new DataPoint(0, 0));
      foreach (ROCPoint rocPoint in rocPoints) {
        DataPoint point = new DataPoint();
        point.XValue = rocPoint.FalsePositiveRate;
        point.YValues[0] = rocPoint.TruePositiveRate;
        point.Tag = rocPoint;

        StringBuilder sb = new StringBuilder();
        sb.AppendLine("True Positive Rate: " + rocPoint.TruePositiveRate);
        sb.AppendLine("False Positive Rate: " + rocPoint.FalsePositiveRate);
        sb.AppendLine("Upper Threshold: " + rocPoint.UpperThreshold);
        sb.AppendLine("Lower Threshold: " + rocPoint.LowerThreshold);
        point.ToolTip = sb.ToString();

        series.Points.Add(point);
      }
      series.Points.Add(new DataPoint(1, 1));
    }

    /// <summary>
    /// Integrates the series' polyline with the trapezoidal rule over its points in series order.
    /// </summary>
    /// <exception cref="ArgumentException">The series contains no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    private void cmbSamples_SelectedIndexChanged(object sender, System.EventArgs e) {
      if (Content != null)
        UpdateChart();
    }


    #region show / hide series
    // A series counts as hidden when it has no points; toggling re-fills it from the cache.
    private void ToggleSeries(Series series) {
      if (series.Points.Count == 0)
        FillSeriesWithDataPoints(series, cachedRocPoints[series.Name]);
      else
        series.Points.Clear();
    }
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        if (result.Series != null) ToggleSeries(result.Series);
      }
    }
    // Greys out the legend entries of hidden (empty) series.
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legendItem in e.LegendItems) {
        var series = chart.Series[legendItem.SeriesName];
        if (series != null) {
          bool seriesIsInvisible = series.Points.Count == 0;
          foreach (LegendCell cell in legendItem.Cells)
            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
        }
      }
    }
    // Shows a hand cursor over legend items and the data point tooltip over data points.
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        this.Cursor = Cursors.Hand;
      else
        this.Cursor = Cursors.Default;

      string newTooltipText = string.Empty;
      if (result.ChartElementType == ChartElementType.DataPoint)
        newTooltipText = ((DataPoint)result.Object).ToolTip;

      string oldTooltipText = this.toolTip.GetToolTip(chart);
      if (newTooltipText != oldTooltipText)
        this.toolTip.SetToolTip(chart, newTooltipText);
    }
    #endregion


    /// <summary>One point of a ROC curve together with the thresholds that produced it.</summary>
    private class ROCPoint {
      public ROCPoint(double truePositiveRate, double falsePositiveRate, double lowerThreshold, double upperThreshold) {
        TruePositiveRate = truePositiveRate;
        FalsePositiveRate = falsePositiveRate;
        LowerThreshold = lowerThreshold;
        UpperThreshold = upperThreshold;
      }
      public double TruePositiveRate { get; private set; }
      public double FalsePositiveRate { get; private set; }
      public double LowerThreshold { get; private set; }
      public double UpperThreshold { get; private set; }
    }

  }
}
Note: See TracBrowser for help on using the repository browser.