Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Classification/HeuristicLab.Problems.DataAnalysis.Classification.Views/3.3/RocCurvesView.cs @ 4542

Last change on this file since 4542 was 4469, checked in by mkommend, 14 years ago

Added logic to remove the test samples from the training samples (ticket #939).

File size: 10.4 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Text;
27using System.Windows.Forms;
28using System.Windows.Forms.DataVisualization.Charting;
29using HeuristicLab.Common;
30using HeuristicLab.MainForm;
31using HeuristicLab.MainForm.WindowsForms;
namespace HeuristicLab.Problems.DataAnalysis.Classification.Views {
  /// <summary>
  /// Shows one ROC (receiver operating characteristic) curve per class of a
  /// <see cref="SymbolicClassificationSolution"/>, evaluated either on the training rows or on the test rows
  /// (chosen in <c>cmbSamples</c>).
  /// </summary>
  /// <remarks>
  /// Every class is evaluated one-vs-rest: a row counts as predicted positive when its estimated value lies
  /// strictly between a lower and an upper threshold. All threshold pairs on a regular grid spanning the range
  /// of estimated values are tried, and only the non-dominated (Pareto-optimal) ROC points are plotted.
  /// The legend tooltip of each curve shows its area under the curve (AUC).
  /// </remarks>
  [View("ROC Curves View")]
  [Content(typeof(SymbolicClassificationSolution))]
  public partial class RocCurvesView : AsynchronousContentView {
    private const string xAxisTitle = "False Positive Rate";
    private const string yAxisTitle = "True Positive Rate";
    private const string TrainingSamples = "Training";
    private const string TestSamples = "Test";
    // Number of grid steps between the smallest and the largest estimated value.
    private const int Slices = 100;

    // Pareto-optimal ROC points per class name, sorted by false positive rate. Used to re-fill a series
    // that was hidden via its legend entry.
    private readonly Dictionary<string, List<ROCPoint>> cachedRocPoints;

    public RocCurvesView() {
      InitializeComponent();

      cachedRocPoints = new Dictionary<string, List<ROCPoint>>();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.SelectedIndex = 0;

      ChartArea area = chart.ChartAreas[0];
      area.AxisX.Minimum = 0.0;
      area.AxisX.Maximum = 1.0;
      area.AxisX.MajorGrid.Interval = 0.2;
      area.AxisY.Minimum = 0.0;
      area.AxisY.Maximum = 1.0;
      area.AxisY.MajorGrid.Interval = 0.2;

      area.AxisX.Title = xAxisTitle;
      area.AxisY.Title = yAxisTitle;
    }

    public new SymbolicClassificationSolution Content {
      get { return (SymbolicClassificationSolution)base.Content; }
      set { base.Content = value; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.EstimatedValuesChanged += new EventHandler(Content_EstimatedValuesChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.EstimatedValuesChanged -= new EventHandler(Content_EstimatedValuesChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    private void Content_EstimatedValuesChanged(object sender, EventArgs e) {
      UpdateChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      UpdateChart();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      chart.Series.Clear();
      // Drop points computed for the previous content so they cannot be re-displayed by ToggleSeries.
      cachedRocPoints.Clear();
      if (Content != null) UpdateChart();
    }

    /// <summary>
    /// Recomputes the ROC curve of every class on the currently selected sample set and redraws the chart.
    /// Re-invokes itself on the UI thread when called from another thread.
    /// </summary>
    private void UpdateChart() {
      if (InvokeRequired) {
        Invoke((Action)UpdateChart);
        return;
      }

      chart.Series.Clear();
      chart.Annotations.Clear();
      cachedRocPoints.Clear();

      // Materialize once: the rows are consumed twice below and the source enumerable may be lazily computed.
      int[] rows = GetSelectedRows().ToArray();
      double[] estimatedValues = Content.GetEstimatedValues(rows).ToArray();
      // Min()/Max() throw on an empty sequence (e.g. an empty test partition); show an empty chart instead.
      if (estimatedValues.Length == 0) return;
      double[] targetClassValues = Content.ProblemData.Dataset.GetEnumeratedVariableValues(Content.ProblemData.TargetVariable.Value, rows).ToArray();

      double[] thresholds = CalculateThresholds(estimatedValues, Slices);
      List<double> classValues = Content.ProblemData.SortedClassValues.ToList();

      for (int classIndex = 0; classIndex < classValues.Count; classIndex++) {
        // Class names are assumed to be ordered like SortedClassValues (as in the original IndexOf lookup).
        string className = Content.ProblemData.ClassNames.ElementAt(classIndex);
        List<ROCPoint> rocPoints = CalculateRocPoints(thresholds, estimatedValues, targetClassValues, classValues[classIndex]);
        cachedRocPoints[className] = rocPoints;

        Series series = new Series(className);
        series.ChartType = SeriesChartType.Line;
        series.MarkerStyle = MarkerStyle.Diamond;
        series.MarkerSize = 5;
        chart.Series.Add(series);
        FillSeriesWithDataPoints(series, rocPoints);

        double auc = CalculateAreaUnderCurve(series);
        series.LegendToolTip = "AUC: " + auc;
      }
    }

    /// <summary>Returns the row indices of the sample set currently selected in <c>cmbSamples</c>.</summary>
    /// <exception cref="InvalidOperationException">The selected item is neither the training nor the test set.</exception>
    private IEnumerable<int> GetSelectedRows() {
      string selectedSamples = cmbSamples.SelectedItem.ToString();
      if (selectedSamples == TrainingSamples) return Content.ProblemData.TrainingIndizes;
      if (selectedSamples == TestSamples) return Content.ProblemData.TestIndizes;
      throw new InvalidOperationException("Unknown sample set selected: " + selectedSamples);
    }

    /// <summary>
    /// Builds the threshold grid: <paramref name="slices"/> + 3 equally spaced values, from one step below
    /// the smallest estimated value up to and including one step above the largest one.
    /// </summary>
    /// <param name="estimatedValues">Non-empty array of estimated values.</param>
    /// <param name="slices">Number of steps between the smallest and the largest estimated value.</param>
    private static double[] CalculateThresholds(double[] estimatedValues, int slices) {
      double minEstimatedValue = estimatedValues.Min();
      double maxEstimatedValue = estimatedValues.Max();
      double increment = (maxEstimatedValue - minEstimatedValue) / slices;
      // All estimates identical: a zero step would make every interval empty (and, with the previous
      // accumulating loop, never terminate). Use a unit step so the single value lies inside the intervals around it.
      if (increment <= 0.0) increment = 1.0;

      // Computed by index rather than by repeated addition, so rounding errors cannot add or drop grid points.
      double[] thresholds = new double[slices + 3];
      for (int i = 0; i < thresholds.Length; i++)
        thresholds[i] = minEstimatedValue + (i - 1) * increment;
      return thresholds;
    }

    /// <summary>
    /// Computes the Pareto-optimal ROC points of the one-vs-rest classifier for <paramref name="classValue"/>.
    /// For every pair of grid thresholds (lower &lt; upper), a row is predicted positive when
    /// lower &lt; estimated value &lt; upper.
    /// </summary>
    /// <param name="thresholds">Ascending threshold grid (see <see cref="CalculateThresholds"/>).</param>
    /// <param name="estimatedValues">Estimated value per row.</param>
    /// <param name="targetClassValues">Actual class value per row, aligned with <paramref name="estimatedValues"/>.</param>
    /// <param name="classValue">The class treated as positive.</param>
    /// <returns>The non-dominated ROC points, ordered by ascending false positive rate.</returns>
    private static List<ROCPoint> CalculateRocPoints(double[] thresholds, double[] estimatedValues, double[] targetClassValues, double classValue) {
      int positives = targetClassValues.Count(c => c.IsAlmost(classValue));
      int negatives = targetClassValues.Length - positives;
      List<ROCPoint> rocPoints = new List<ROCPoint>();

      for (int lower = 0; lower < thresholds.Length - 1; lower++) {
        double lowerThreshold = thresholds[lower];
        // The upper threshold may reach the last grid point (max + one step), so rows carrying the maximum
        // estimated value can be predicted positive as well.
        for (int upper = lower + 1; upper < thresholds.Length; upper++) {
          double upperThreshold = thresholds[upper];
          int truePositives = 0;
          int falsePositives = 0;

          for (int row = 0; row < estimatedValues.Length; row++) {
            if (lowerThreshold < estimatedValues[row] && estimatedValues[row] < upperThreshold) {
              if (targetClassValues[row].IsAlmost(classValue)) truePositives++;
              else falsePositives++;
            }
          }

          ROCPoint rocPoint = new ROCPoint(Rate(truePositives, positives), Rate(falsePositives, negatives), lowerThreshold, upperThreshold);
          // Keep only non-dominated points: skip the new point if an existing one is at least as good,
          // otherwise remove every existing point it dominates.
          if (!rocPoints.Any(x => x.truePositiveRate >= rocPoint.truePositiveRate && x.falsePositiveRate <= rocPoint.falsePositiveRate)) {
            rocPoints.RemoveAll(x => x.falsePositiveRate >= rocPoint.falsePositiveRate && x.truePositiveRate <= rocPoint.truePositiveRate);
            rocPoints.Add(rocPoint);
          }
        }
      }

      return rocPoints.OrderBy(x => x.falsePositiveRate).ToList();
    }

    /// <summary>
    /// Returns <paramref name="count"/> / <paramref name="total"/>, or 0 when <paramref name="total"/> is 0
    /// (e.g. a class absent from the selected rows). Without this guard the rate would be NaN or infinite,
    /// which defeats the Pareto filter and poisons the AUC.
    /// </summary>
    private static double Rate(int count, int total) {
      return total == 0 ? 0.0 : ((double)count) / total;
    }

    /// <summary>
    /// Adds the ROC points to <paramref name="series"/>, framed by the fixed end points (0,0) and (1,1).
    /// Each point carries its <see cref="ROCPoint"/> as Tag and a tooltip with its rates and thresholds.
    /// </summary>
    private void FillSeriesWithDataPoints(Series series, IEnumerable<ROCPoint> rocPoints) {
      series.Points.Add(new DataPoint(0, 0));
      foreach (ROCPoint rocPoint in rocPoints) {
        DataPoint point = new DataPoint();
        point.XValue = rocPoint.falsePositiveRate;
        point.YValues[0] = rocPoint.truePositiveRate;
        point.Tag = rocPoint;

        StringBuilder sb = new StringBuilder();
        sb.AppendLine("True Positive Rate: " + rocPoint.truePositiveRate);
        sb.AppendLine("False Positive Rate: " + rocPoint.falsePositiveRate);
        sb.AppendLine("Upper Threshold: " + rocPoint.upperThreshold);
        sb.AppendLine("Lower Threshold: " + rocPoint.lowerThreshold);
        point.ToolTip = sb.ToString();

        series.Points.Add(point);
      }
      series.Points.Add(new DataPoint(1, 1));
    }

    /// <summary>
    /// Integrates the polyline of <paramref name="series"/> with the trapezoidal rule. Assumes the points are
    /// ordered by ascending x value, which holds for series filled by <see cref="FillSeriesWithDataPoints"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The series has no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    private void cmbSamples_SelectedIndexChanged(object sender, System.EventArgs e) {
      if (Content != null)
        UpdateChart();
    }


    #region show / hide series
    // A series counts as hidden when it has no points; toggling refills it from cachedRocPoints or clears it.
    private void ToggleSeries(Series series) {
      if (series.Points.Count == 0)
        FillSeriesWithDataPoints(series, cachedRocPoints[series.Name]);
      else
        series.Points.Clear();
    }
    // Clicking a legend entry hides/shows the corresponding series.
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        if (result.Series != null) ToggleSeries(result.Series);
      }
    }
    // Greys out the legend entries of hidden (empty) series.
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legendItem in e.LegendItems) {
        // FindByName returns null for unknown names, whereas the string indexer throws.
        Series series = chart.Series.FindByName(legendItem.SeriesName);
        if (series != null) {
          bool seriesIsInvisible = series.Points.Count == 0;
          foreach (LegendCell cell in legendItem.Cells)
            cell.ForeColor = seriesIsInvisible ? Color.Gray : Color.Black;
        }
      }
    }
    // Shows a hand cursor over legend entries and the tooltip of the data point under the mouse.
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        this.Cursor = Cursors.Hand;
      else
        this.Cursor = Cursors.Default;

      string newTooltipText = string.Empty;
      if (result.ChartElementType == ChartElementType.DataPoint)
        newTooltipText = ((DataPoint)result.Object).ToolTip;

      string oldTooltipText = this.toolTip.GetToolTip(chart);
      if (newTooltipText != oldTooltipText)
        this.toolTip.SetToolTip(chart, newTooltipText);
    }
    #endregion


    /// <summary>
    /// One point of a ROC curve: the rates achieved when rows with lowerThreshold &lt; estimate &lt; upperThreshold
    /// are predicted positive.
    /// </summary>
    private class ROCPoint {
      public ROCPoint(double truePositiveRate, double falsePositiveRate, double lowerThreshold, double upperThreshold) {
        this.truePositiveRate = truePositiveRate;
        this.falsePositiveRate = falsePositiveRate;
        this.lowerThreshold = lowerThreshold;
        this.upperThreshold = upperThreshold;
      }
      public double truePositiveRate { get; private set; }
      public double falsePositiveRate { get; private set; }
      public double lowerThreshold { get; private set; }
      public double upperThreshold { get; private set; }
    }

  }
}
Note: See TracBrowser for help on using the repository browser.