Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs @ 8510

Last change on this file since 8510 was 8510, checked in by sforsten, 12 years ago

#1776: added AUC to ClassificationEnsembleSolutionAccuracyToCoveredSamples

File size: 9.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.Data;
29using HeuristicLab.MainForm;
30using HeuristicLab.Problems.DataAnalysis.Interfaces;
31
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Chart view for a classification ensemble solution. For a sequence of confidence
  /// thresholds between 0 and 1 it counts only the samples whose confidence reaches the
  /// threshold, and plots the accuracy on those samples against the fraction of samples
  /// that were counted ("covered"). The area under the plotted curve is shown as AUC.
  /// </summary>
  [View("Accuracy Covered Dependence")]
  [Content(typeof(IClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionAccuracyToCoveredSamples : DataAnalysisSolutionEvaluationView {
    // Names of the two chart series.
    private const string ACCURACYCOVERED = "Accuracy to Covered percentage";
    private const string AREA = "Area";

    // Entries offered by the samples combo box.
    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    // Number of confidence thresholds evaluated; thresholds are i / (maxPoints - 1).
    private const int maxPoints = 101;

    /// <summary>
    /// The displayed solution, cast to the concrete ensemble solution type.
    /// </summary>
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Creates the view, fills the samples combo box and configures both chart axes
    /// to the fixed range [0, 1] with zooming enabled.
    /// </summary>
    public ClassificationEnsembleSolutionAccuracyToCoveredSamples()
      : base() {
      InitializeComponent();

      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;
      //configure axis
      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisX.Minimum = 0;
      this.chart.ChartAreas[0].AxisX.Maximum = 1;
      this.chart.ChartAreas[0].AxisX.Title = "Covered Samples in %";

      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisY.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisY.Minimum = 0;
      this.chart.ChartAreas[0].AxisY.Maximum = 1;
      this.chart.ChartAreas[0].AxisY.Title = "Accuracy";

      // Draw the AUC label on top of the chart surface.
      AUCLabel.Parent = chart;
      AUCLabel.BackColor = Color.Transparent;
    }

    /// <summary>
    /// Clears the chart and, if content and a non-empty sample selection are available,
    /// recomputes the accuracy/coverage curve, the area series and the AUC label.
    /// </summary>
    private void RedrawChart() {
      this.chart.Series.Clear();
      // The chart is empty from here on; do not keep showing the AUC of an earlier curve.
      AUCLabel.Text = string.Empty;

      if (Content == null) {
        return;
      }

      object selectedItem = SamplesComboBox.SelectedItem;
      if (selectedItem == null) {
        // Nothing selected (e.g. selection cleared): there is no partition to evaluate.
        return;
      }
      // Read the selection once instead of once per row.
      string samplesSelection = selectedItem.ToString();

      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
      var solutions = Content.ClassificationSolutions;
      double[] estimatedClassValues;
      double[] classValues;
      double[] confidences;
      int rows;

      if (samplesSelection.Equals(SamplesComboBoxAllSamples)) {
        rows = Content.ProblemData.Dataset.Rows;
        if (rows <= 0) {
          // No samples: coverage would be 0/0 (NaN), so leave the chart empty.
          return;
        }
        estimatedClassValues = Content.EstimatedClassValues.ToArray();
        classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();
        confidences = weightCalc.GetConfidence(solutions, Enumerable.Range(0, rows), estimatedClassValues).ToArray();
      } else {
        IntRange range;
        if (samplesSelection.Equals(SamplesComboBoxTrainingSamples)) {
          range = Content.ProblemData.TrainingPartition;
          estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
        } else if (samplesSelection.Equals(SamplesComboBoxTestSamples)) {
          range = Content.ProblemData.TestPartition;
          estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
        } else {
          return;
        }

        rows = range.End - range.Start;
        if (rows <= 0) {
          // Empty (or inverted) partition: nothing to evaluate, and a negative
          // length must not reach the array allocation below.
          return;
        }

        classValues = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable)
                                    .Skip(range.Start).Take(rows).ToArray();
        confidences = new double[rows];
        for (int i = 0; i < rows; i++) {
          // confidences/estimatedClassValues are indexed relative to the partition,
          // the weight calculator is asked with the absolute dataset row.
          int index = range.Start + i;
          confidences[i] = weightCalc.GetConfidence(GetRelevantSolutions(samplesSelection, solutions, index),
                                                    index, estimatedClassValues[i]);
        }
      }

      // One extra slot in each array: accuracy[0] and covered[maxPoints] are filled after the loop.
      double[] accuracy = new double[maxPoints + 1];
      double[] covered = new double[maxPoints + 1];
      OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();

      for (int i = 0; i < maxPoints; i++) {
        double confidenceValue = (1.0 / (maxPoints - 1)) * i;
        int notCovered = 0;

        for (int j = 0; j < rows; j++) {
          if (confidences[j] >= confidenceValue) {
            accuracyCalc.Add(classValues[j], estimatedClassValues[j]);
          } else {
            notCovered++;
          }
        }

        // NOTE(review): accuracy is stored one slot ahead of covered, so after the
        // reversal below each coverage value (except the last) is paired with the
        // accuracy of the next lower threshold. Presumably intentional; confirm.
        accuracy[i + 1] = accuracyCalc.Accuracy;
        covered[i] = 1.0 - (double)notCovered / (double)rows;
        accuracyCalc.Reset();
      }

      accuracy[0] = accuracy[1];
      covered[maxPoints] = 0.0;

      // Reverse so that the arrays start at the padded coverage 0.0 and end at threshold 0.
      accuracy = accuracy.Reverse().ToArray();
      covered = covered.Reverse().ToArray();

      Series area = this.chart.Series.Add(AREA);
      area.ChartType = SeriesChartType.Area;
      area.Color = Color.LightBlue;
      IEnumerable<IEnumerable<double>> areaPoints = CalculateAreaPoints(covered, accuracy);
      area.Points.DataBindXY(areaPoints.ElementAt(0), areaPoints.ElementAt(1));

      Series series = this.chart.Series.Add(ACCURACYCOVERED);
      series.Color = Color.Red;
      series.ChartType = SeriesChartType.FastPoint;
      series.MarkerStyle = MarkerStyle.Diamond;
      series.MarkerSize = 5;
      series.Points.DataBindXY(covered, accuracy);

      double auc = CalculateAreaUnderCurve(series);
      area.LegendToolTip = "AUC: " + auc;

      AUCLabel.Text = "AUC: " + auc;
    }

    /// <summary>
    /// Builds the points of the filled area series: between two consecutive curve points
    /// an intermediate point carrying the smaller of the two accuracies is inserted, so
    /// the area follows the lower of each pair of neighbouring accuracies.
    /// </summary>
    /// <param name="covered">X values of the curve; must contain at least one element.</param>
    /// <param name="accuracy">Y values of the curve; must be at least as long as <paramref name="covered"/>.</param>
    /// <returns>Two sequences: the x values (element 0) and the y values (element 1).</returns>
    private IEnumerable<IEnumerable<double>> CalculateAreaPoints(double[] covered, double[] accuracy) {
      List<double> newCovered = new List<double>();
      List<double> worseAccuracy = new List<double>();
      newCovered.Add(covered[0]);
      worseAccuracy.Add(accuracy[0]);
      for (int i = 1; i < covered.Length; i++) {
        // NOTE(review): Double.Epsilon is the smallest positive double, so adding or
        // subtracting it leaves any value of ordinary magnitude unchanged; the
        // intermediate point then shares its x value with a neighbour. Kept as is.
        if (accuracy[i] > accuracy[i - 1]) {
          worseAccuracy.Add(accuracy[i - 1]);
          newCovered.Add(covered[i] - Double.Epsilon);
        } else {
          worseAccuracy.Add(accuracy[i]);
          newCovered.Add(covered[i - 1] + Double.Epsilon);
        }
        worseAccuracy.Add(accuracy[i]);
        newCovered.Add(covered[i]);
      }
      return new List<IEnumerable<double>>() { newCovered, worseAccuracy };
    }

    /// <summary>
    /// Selects the ensemble members to be asked for a confidence on the given row.
    /// </summary>
    /// <param name="samplesSelection">One of the samples combo box entries.</param>
    /// <param name="solutions">All solutions of the ensemble.</param>
    /// <param name="curRow">Dataset row that is evaluated.</param>
    /// <returns>
    /// All solutions for "All Samples"; for the training/test entries only the solutions whose
    /// own problem data reports the row as a training/test sample; otherwise an empty list.
    /// </returns>
    protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
      if (samplesSelection == SamplesComboBoxAllSamples)
        return solutions;
      else if (samplesSelection == SamplesComboBoxTrainingSamples)
        return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
      else if (samplesSelection == SamplesComboBoxTestSamples)
        return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
      else
        return new List<IClassificationSolution>();
    }

    /// <summary>
    /// Integrates the series with the trapezoidal rule over consecutive points.
    /// </summary>
    /// <param name="series">Series whose points are ordered along the x axis.</param>
    /// <returns>The summed trapezoid areas; 0 for a series with a single point.</returns>
    /// <exception cref="ArgumentException">The series contains no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      var points = series.Points;
      if (points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < points.Count; i++) {
        double width = points[i].XValue - points[i - 1].XValue;
        double y1 = points[i - 1].YValues[0];
        double y2 = points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    #region events
    /// <summary>Subscribes to the content events that require a redraw.</summary>
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    /// <summary>Removes the subscriptions added in <see cref="RegisterContentEvents"/>.</summary>
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    /// <summary>Redraws the chart for the newly assigned content.</summary>
    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.