Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2701_MemPRAlgorithm/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/DiscriminantFunctionClassificationRocCurvesView.cs @ 16752

Last change on this file since 16752 was 14185, checked in by swagner, 8 years ago

#2526: Updated year of copyrights in license headers

File size: 10.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Text;
27using System.Windows.Forms;
28using System.Windows.Forms.DataVisualization.Charting;
29using HeuristicLab.Common;
30using HeuristicLab.MainForm;
31using HeuristicLab.MainForm.WindowsForms;
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Shows one ROC curve per class of a discriminant function classification solution,
  /// evaluated on the training or the test partition as selected in <c>cmbSamples</c>.
  /// </summary>
  /// <remarks>
  /// A row is predicted as the class of a curve when its estimated value lies strictly between a lower
  /// and an upper threshold. Both thresholds are swept over the range of the estimated values and only
  /// the non-dominated (false positive rate, true positive rate) pairs are kept as curve points.
  /// </remarks>
  [View("ROC Curves")]
  [Content(typeof(IDiscriminantFunctionClassificationSolution))]
  public partial class DiscriminantFunctionClassificationRocCurvesView : DataAnalysisSolutionEvaluationView {
    private const string xAxisTitle = "False Positive Rate";
    private const string yAxisTitle = "True Positive Rate";
    private const string TrainingSamples = "Training";
    private const string TestSamples = "Test";
    // Number of threshold steps spanning the range between the smallest and the largest estimated value.
    private const int Slices = 100;
    // Curve points per class name; used to re-fill a series that was hidden by clicking its legend entry.
    private readonly Dictionary<string, List<ROCPoint>> cachedRocPoints;

    public DiscriminantFunctionClassificationRocCurvesView() {
      InitializeComponent();

      cachedRocPoints = new Dictionary<string, List<ROCPoint>>();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      chart.ChartAreas[0].AxisX.Maximum = 1.0;
      chart.ChartAreas[0].AxisX.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;

      chart.ChartAreas[0].AxisX.Title = xAxisTitle;
      chart.ChartAreas[0].AxisY.Title = yAxisTitle;
    }

    public new IDiscriminantFunctionClassificationSolution Content {
      get { return (IDiscriminantFunctionClassificationSolution)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    private void Content_ModelChanged(object sender, EventArgs e) {
      UpdateChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      UpdateChart();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      if (Content == null) ClearChart();
      else UpdateChart();
    }

    // Removes all series, annotations and cached curve points.
    private void ClearChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();
      cachedRocPoints.Clear();
    }

    // Returns the row indices of the partition currently selected in cmbSamples.
    private IEnumerable<int> GetSelectedRows() {
      string selectedSamples = cmbSamples.SelectedItem as string;
      if (selectedSamples == TrainingSamples) return Content.ProblemData.TrainingIndices;
      if (selectedSamples == TestSamples) return Content.ProblemData.TestIndices;
      throw new InvalidOperationException();
    }

    /// <summary>
    /// Recomputes the ROC curves of all classes for the selected partition and redraws the chart.
    /// Re-dispatches itself via <c>Invoke</c> when called from a thread other than the UI thread.
    /// </summary>
    private void UpdateChart() {
      if (InvokeRequired) {
        Invoke((Action)UpdateChart);
        return;
      }

      ClearChart();

      // Materialize once; the rows are enumerated for both the estimates and the targets.
      int[] rows = GetSelectedRows().ToArray();
      double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
      // An empty partition (e.g. no test rows) has nothing to plot; Min()/Max() would throw on it.
      if (estimatedValues.Length == 0) return;
      double[] targetClassValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable, rows).ToArray();

      List<double> classValues = Content.ProblemData.ClassValues.ToList();
      List<string> classNames = Content.ProblemData.ClassNames.ToList();
      bool isBinaryProblem = classValues.Count == 2;

      for (int classIndex = 0; classIndex < classValues.Count; classIndex++) {
        double classValue = classValues[classIndex];
        string className = classNames[classIndex];

        // Binary problems: the first class is predicted below a single upper threshold (only the upper threshold
        // is adapted), the second class above a single lower threshold (upper threshold = +infinity).
        // With more classes both thresholds are adapted.
        // NOTE(review): this assumes ClassValues lists the lower class first - confirm.
        bool adaptLowerThreshold = !(isBinaryProblem && classIndex == 0);
        bool adaptUpperThreshold = !(isBinaryProblem && classIndex == 1);

        bool[] isPositive = targetClassValues.Select(c => c.IsAlmost(classValue)).ToArray();
        int positives = isPositive.Count(p => p);
        int negatives = isPositive.Length - positives;
        // TPR/FPR would be 0/0 = NaN if the partition has no positives or no negatives for this class.
        bool ratesDefined = positives > 0 && negatives > 0;

        List<ROCPoint> rocPoints = ratesDefined
          ? CalculateRocPoints(estimatedValues, isPositive, positives, negatives, adaptLowerThreshold, adaptUpperThreshold)
          : new List<ROCPoint>();
        cachedRocPoints[className] = rocPoints.OrderBy(p => p.FalsePositiveRate).ToList();

        Series series = new Series(className);
        series.ChartType = SeriesChartType.Line;
        series.MarkerStyle = MarkerStyle.Diamond;
        series.MarkerSize = 5;
        chart.Series.Add(series);
        FillSeriesWithDataPoints(series, cachedRocPoints[className]);

        series.LegendToolTip = ratesDefined
          ? "AUC: " + CalculateAreaUnderCurve(series)
          : "AUC: undefined (no positive or no negative samples)";
      }
    }

    /// <summary>
    /// Sweeps the lower/upper thresholds over the range of <paramref name="estimatedValues"/> and returns
    /// the non-dominated ROC points of one class (in no particular order).
    /// </summary>
    /// <param name="estimatedValues">Estimated value per row; must not be empty.</param>
    /// <param name="isPositive">Per row, whether its target value equals the class of the curve.</param>
    /// <param name="positives">Number of <c>true</c> entries in <paramref name="isPositive"/>; must be &gt; 0.</param>
    /// <param name="negatives">Number of <c>false</c> entries in <paramref name="isPositive"/>; must be &gt; 0.</param>
    /// <param name="adaptLowerThreshold">If false, the lower threshold stays at the smallest threshold.</param>
    /// <param name="adaptUpperThreshold">If false, the upper threshold is +infinity.</param>
    private static List<ROCPoint> CalculateRocPoints(double[] estimatedValues, bool[] isPositive, int positives, int negatives,
                                                     bool adaptLowerThreshold, bool adaptUpperThreshold) {
      var rocPoints = new List<ROCPoint>();
      double minEstimatedValue = estimatedValues.Min();
      double maxEstimatedValue = estimatedValues.Max();
      double thresholdIncrement = (maxEstimatedValue - minEstimatedValue) / Slices;
      // All estimates equal, or NaN/infinite estimates: there is no range to sweep.
      if (!(thresholdIncrement > 0.0) || double.IsInfinity(thresholdIncrement)) return rocPoints;

      // Thresholds are minThreshold + i * thresholdIncrement for i = 0 .. Slices + 2, i.e. the range of the
      // estimates widened by one step on each side, so the open intervals can contain the extreme estimates.
      // Integer indices avoid accumulated rounding errors, which could stall a "threshold += increment" loop.
      double minThreshold = minEstimatedValue - thresholdIncrement;
      const int lastThresholdIndex = Slices + 2;
      int lastLowerIndex = adaptLowerThreshold ? lastThresholdIndex - 1 : 0;

      for (int lowerIndex = 0; lowerIndex <= lastLowerIndex; lowerIndex++) {
        double lowerThreshold = minThreshold + lowerIndex * thresholdIncrement;
        if (!adaptUpperThreshold) {
          AddIfNotDominated(rocPoints, CreateRocPoint(estimatedValues, isPositive, positives, negatives, lowerThreshold, double.PositiveInfinity));
          continue;
        }
        for (int upperIndex = lowerIndex + 1; upperIndex <= lastThresholdIndex; upperIndex++) {
          double upperThreshold = minThreshold + upperIndex * thresholdIncrement;
          AddIfNotDominated(rocPoints, CreateRocPoint(estimatedValues, isPositive, positives, negatives, lowerThreshold, upperThreshold));
        }
      }
      return rocPoints;
    }

    // Builds the ROC point for predicting the class on all rows whose estimate lies strictly inside (lower, upper).
    private static ROCPoint CreateRocPoint(double[] estimatedValues, bool[] isPositive, int positives, int negatives,
                                           double lowerThreshold, double upperThreshold) {
      int truePositives = 0;
      int falsePositives = 0;
      for (int row = 0; row < estimatedValues.Length; row++) {
        double estimatedValue = estimatedValues[row];
        if (lowerThreshold < estimatedValue && estimatedValue < upperThreshold) {
          if (isPositive[row]) truePositives++;
          else falsePositives++;
        }
      }
      double truePositiveRate = (double)truePositives / positives;
      double falsePositiveRate = (double)falsePositives / negatives;
      return new ROCPoint(truePositiveRate, falsePositiveRate, lowerThreshold, upperThreshold);
    }

    // Adds the candidate unless an existing point is at least as good in both rates; removes points it dominates.
    private static void AddIfNotDominated(List<ROCPoint> rocPoints, ROCPoint candidate) {
      if (rocPoints.Any(p => p.TruePositiveRate >= candidate.TruePositiveRate && p.FalsePositiveRate <= candidate.FalsePositiveRate))
        return;
      rocPoints.RemoveAll(p => p.FalsePositiveRate >= candidate.FalsePositiveRate && p.TruePositiveRate <= candidate.TruePositiveRate);
      rocPoints.Add(candidate);
    }

    // Adds (0,0), one tooltip-annotated data point per ROC point (in the given order), and (1,1) to the series.
    private void FillSeriesWithDataPoints(Series series, IEnumerable<ROCPoint> rocPoints) {
      series.Points.Add(new DataPoint(0, 0));
      foreach (ROCPoint rocPoint in rocPoints) {
        DataPoint point = new DataPoint();
        point.XValue = rocPoint.FalsePositiveRate;
        point.YValues[0] = rocPoint.TruePositiveRate;
        point.Tag = rocPoint;

        StringBuilder sb = new StringBuilder();
        sb.AppendLine("True Positive Rate: " + rocPoint.TruePositiveRate);
        sb.AppendLine("False Positive Rate: " + rocPoint.FalsePositiveRate);
        sb.AppendLine("Upper Threshold: " + rocPoint.UpperThreshold);
        sb.AppendLine("Lower Threshold: " + rocPoint.LowerThreshold);
        point.ToolTip = sb.ToString();

        series.Points.Add(point);
      }
      series.Points.Add(new DataPoint(1, 1));
    }

    // Trapezoidal area under the polyline formed by the series points (in point order).
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    private void cmbSamples_SelectedIndexChanged(object sender, System.EventArgs e) {
      if (Content != null)
        UpdateChart();
    }


    #region show / hide series
    // A series without points counts as hidden; toggling re-fills it from the cache or clears it.
    private void ToggleSeries(Series series) {
      if (series.Points.Count > 0) {
        series.Points.Clear();
        return;
      }
      List<ROCPoint> rocPoints;
      if (cachedRocPoints.TryGetValue(series.Name, out rocPoints))
        FillSeriesWithDataPoints(series, rocPoints);
    }
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        if (result.Series != null) ToggleSeries(result.Series);
      }
    }
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legendItem in e.LegendItems) {
        // FindByName returns null for unknown names, whereas the string indexer throws.
        Series series = chart.Series.FindByName(legendItem.SeriesName);
        if (series == null) continue;
        bool seriesIsInvisible = series.Points.Count == 0;
        foreach (LegendCell cell in legendItem.Cells)
          cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
      }
    }
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        this.Cursor = Cursors.Hand;
      else
        this.Cursor = Cursors.Default;

      string newTooltipText = string.Empty;
      if (result.ChartElementType == ChartElementType.DataPoint)
        newTooltipText = ((DataPoint)result.Object).ToolTip;

      string oldTooltipText = this.toolTip.GetToolTip(chart);
      if (newTooltipText != oldTooltipText)
        this.toolTip.SetToolTip(chart, newTooltipText);
    }
    #endregion


    // One point of a ROC curve together with the threshold interval (lower, upper) that produced it.
    private class ROCPoint {
      public ROCPoint(double truePositiveRate, double falsePositiveRate, double lowerThreshold, double upperThreshold) {
        this.TruePositiveRate = truePositiveRate;
        this.FalsePositiveRate = falsePositiveRate;
        this.LowerThreshold = lowerThreshold;
        this.UpperThreshold = upperThreshold;
      }
      public double TruePositiveRate { get; private set; }
      public double FalsePositiveRate { get; private set; }
      public double LowerThreshold { get; private set; }
      public double UpperThreshold { get; private set; }
    }

  }
}
Note: See TracBrowser for help on using the repository browser.