source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs @ 8814

Last change on this file since 8814 was 8814, checked in by sforsten, 7 years ago

#1776:

  • improved performance of confidence calculation
  • fixed bug in median confidence calculation
  • fixed bug in average confidence calculation
  • confidence calculation is now easier for training and test
  • removed obsolete view ClassificationEnsembleSolutionConfidenceAccuracyDependence
File size: 9.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.MainForm;
29
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View for a classification ensemble solution that sweeps a confidence threshold over [0, 1]
  /// and plots, for every threshold, the accuracy on the samples whose confidence reaches the
  /// threshold against the fraction of samples that reach it. The area under the plotted curve
  /// is shown in a label and in the legend tooltip.
  /// </summary>
  [View("Accuracy Covered Dependence")]
  [Content(typeof(IClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionAccuracyToCoveredSamples : DataAnalysisSolutionEvaluationView {
    // Names of the two chart series.
    private const string ACCURACYCOVERED = "Accuracy to Covered percentage";
    private const string AREA = "Area";

    // Entries of the samples combo box; the selected entry decides which rows are evaluated.
    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    // Number of confidence thresholds that are evaluated (evenly spaced, 0 and 1 included).
    private const int maxPoints = 101;

    /// <summary>
    /// The displayed solution. Hides the base property and casts its value to
    /// <see cref="ClassificationEnsembleSolution"/>.
    /// </summary>
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Creates the view, fills the samples combo box and configures both chart axes
    /// to the fixed range [0, 1] with zooming and cursor selection enabled.
    /// </summary>
    public ClassificationEnsembleSolutionAccuracyToCoveredSamples()
      : base() {
      InitializeComponent();

      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;

      //configure axis
      this.chart.CustomizeAllChartAreas();
      ChartArea chartArea = this.chart.ChartAreas[0];

      chartArea.CursorX.IsUserSelectionEnabled = true;
      chartArea.AxisX.ScaleView.Zoomable = true;
      chartArea.AxisX.IsStartedFromZero = true;
      chartArea.AxisX.Minimum = 0;
      chartArea.AxisX.Maximum = 1;
      chartArea.AxisX.Title = "Covered Samples in %";

      chartArea.CursorY.IsUserSelectionEnabled = true;
      chartArea.AxisY.ScaleView.Zoomable = true;
      chartArea.AxisY.IsStartedFromZero = true;
      chartArea.AxisY.Minimum = 0;
      chartArea.AxisY.Maximum = 1;
      chartArea.AxisY.Title = "Accuracy";

      // The label is drawn on top of the chart, so it must not hide what is below it.
      AUCLabel.Parent = chart;
      AUCLabel.BackColor = Color.Transparent;
    }

    /// <summary>
    /// Clears the chart and, if content is set, rebuilds both series and the AUC label
    /// for the partition that is currently selected in the samples combo box.
    /// </summary>
    private void RedrawChart() {
      this.chart.Series.Clear();
      if (Content == null) return;

      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
      var solutions = Content.ClassificationSolutions;
      var problemData = Content.ProblemData;

      // Defaults describe "no rows"; they stay in place if no known entry is selected.
      int[] rowIndices = new int[0];
      double[] estimatedClassValues = new double[0];
      double[] confidences = new double[0];

      // Convert.ToString returns an empty string for null, so a missing selection
      // falls through all branches instead of throwing.
      string selection = Convert.ToString(SamplesComboBox.SelectedItem);
      if (selection == SamplesComboBoxAllSamples) {
        rowIndices = Enumerable.Range(0, problemData.Dataset.Rows).ToArray();
        estimatedClassValues = Content.EstimatedClassValues.ToArray();
        confidences = weightCalc.GetConfidence(solutions,
                                               rowIndices,
                                               estimatedClassValues,
                                               weightCalc.GetAllClassDelegate()).ToArray();
      } else if (selection == SamplesComboBoxTrainingSamples) {
        rowIndices = problemData.TrainingIndices.ToArray();
        estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
        confidences = weightCalc.GetConfidence(solutions,
                                               rowIndices,
                                               estimatedClassValues,
                                               weightCalc.GetTrainingClassDelegate()).ToArray();
      } else if (selection == SamplesComboBoxTestSamples) {
        rowIndices = problemData.TestIndices.ToArray();
        estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
        confidences = weightCalc.GetConfidence(solutions,
                                               rowIndices,
                                               estimatedClassValues,
                                               weightCalc.GetTestClassDelegate()).ToArray();
      }

      // Pick the actual class values through the same row indices that were used for the
      // estimated values and confidences. Indexing the whole-dataset values with 0..rows-1
      // would compare different rows whenever the partition does not start at row 0.
      double[] targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToArray();
      double[] classValues = rowIndices.Select(row => targetValues[row]).ToArray();

      double[] covered;
      double[] accuracy;
      CalculateCurve(classValues, estimatedClassValues, confidences, out covered, out accuracy);

      Series area = this.chart.Series.Add(AREA);
      area.ChartType = SeriesChartType.Area;
      area.Color = Color.LightBlue;
      IEnumerable<IEnumerable<double>> areaPoints = CalculateAreaPoints(covered, accuracy);
      area.Points.DataBindXY(areaPoints.ElementAt(0), areaPoints.ElementAt(1));

      Series series = this.chart.Series.Add(ACCURACYCOVERED);
      series.Color = Color.Red;
      series.ChartType = SeriesChartType.FastPoint;
      series.MarkerStyle = MarkerStyle.Diamond;
      series.MarkerSize = 5;
      series.Points.DataBindXY(covered, accuracy);

      double auc = CalculateAreaUnderCurve(series);
      area.LegendToolTip = "AUC: " + auc;

      AUCLabel.Text = "AUC: " + auc;
    }

    /// <summary>
    /// Evaluates <c>maxPoints</c> evenly spaced confidence thresholds. For each threshold the
    /// samples with a confidence greater than or equal to it are fed into an accuracy calculator,
    /// the others are counted as not covered.
    /// </summary>
    /// <param name="classValues">Actual class value per evaluated sample; its length is the sample count.</param>
    /// <param name="estimatedClassValues">Estimated class value per sample, in the same order.</param>
    /// <param name="confidences">Confidence per sample, in the same order.</param>
    /// <param name="covered">Receives <c>maxPoints + 1</c> covered fractions, ordered from the highest threshold to the lowest, starting with 0.</param>
    /// <param name="accuracy">Receives <c>maxPoints + 1</c> accuracies in the matching order.</param>
    private static void CalculateCurve(double[] classValues, double[] estimatedClassValues, double[] confidences,
                                       out double[] covered, out double[] accuracy) {
      // NOTE(review): assumes estimatedClassValues and confidences hold at least as many
      // entries as classValues; a shorter array raises IndexOutOfRangeException below.
      int rows = classValues.Length;
      double[] accuracyValues = new double[maxPoints + 1];
      double[] coveredValues = new double[maxPoints + 1];
      double thresholdStep = 1.0 / (maxPoints - 1);
      OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();

      for (int i = 0; i < maxPoints; i++) {
        double confidenceValue = thresholdStep * i;
        int notCovered = 0;

        for (int j = 0; j < rows; j++) {
          if (confidences[j] >= confidenceValue) {
            accuracyCalc.Add(classValues[j], estimatedClassValues[j]);
          } else {
            notCovered++;
          }
        }

        // The accuracy of threshold i is stored one slot after the coverage of threshold i, so
        // every plotted point pairs the coverage of one threshold with the accuracy of the next
        // lower one. Kept exactly as before.
        // NOTE(review): confirm that this offset is intended.
        accuracyValues[i + 1] = accuracyCalc.Accuracy;
        if (rows > 0) {
          coveredValues[i] = 1.0 - (double)notCovered / (double)rows;
        }
        accuracyCalc.Reset();
      }

      // Fill the two slots the loop does not write: the accuracy at index 0 repeats its
      // neighbour and the last coverage entry is the "nothing covered" end point.
      accuracyValues[0] = accuracyValues[1];
      coveredValues[maxPoints] = 0.0;

      // Flip both arrays so that the x values start at 0 covered samples.
      Array.Reverse(accuracyValues);
      Array.Reverse(coveredValues);

      accuracy = accuracyValues;
      covered = coveredValues;
    }

    /// <summary>
    /// Builds the points of the area series: between two neighbouring curve points an extra
    /// point carrying the lower of both accuracies is inserted, which turns the curve into steps.
    /// </summary>
    /// <param name="covered">X values of the curve; must contain at least one element.</param>
    /// <param name="accuracy">Y values of the curve, same length as <paramref name="covered"/>.</param>
    /// <returns>Two sequences: the x values first, the y values second.</returns>
    private IEnumerable<IEnumerable<double>> CalculateAreaPoints(double[] covered, double[] accuracy) {
      int pointCount = 2 * covered.Length - 1;
      List<double> newCovered = new List<double>(pointCount);
      List<double> worseAccuracy = new List<double>(pointCount);
      newCovered.Add(covered[0]);
      worseAccuracy.Add(accuracy[0]);
      for (int i = 1; i < covered.Length; i++) {
        // Double.Epsilon is the smallest positive double, so for x values of ordinary
        // magnitude the sum rounds back to the neighbouring x value itself and the step
        // edge becomes vertical.
        if (accuracy[i] > accuracy[i - 1]) {
          // Rising accuracy: keep the previous (lower) value up to the current x.
          worseAccuracy.Add(accuracy[i - 1]);
          newCovered.Add(covered[i] - Double.Epsilon);
        } else {
          // Falling or equal accuracy: drop to the current value right after the previous x.
          worseAccuracy.Add(accuracy[i]);
          newCovered.Add(covered[i - 1] + Double.Epsilon);
        }
        worseAccuracy.Add(accuracy[i]);
        newCovered.Add(covered[i]);
      }
      return new List<IEnumerable<double>>() { newCovered, worseAccuracy };
    }

    /// <summary>
    /// Sums the trapezoids between consecutive points of the series.
    /// </summary>
    /// <param name="series">Series whose first y value per point is integrated over the x values.</param>
    /// <returns>The accumulated area; 0 if the series holds a single point.</returns>
    /// <exception cref="ArgumentException">The series contains no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        DataPoint previous = series.Points[i - 1];
        DataPoint current = series.Points[i];

        double width = current.XValue - previous.XValue;
        double y1 = previous.YValues[0];
        double y2 = current.YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    #region events
    /// <summary>
    /// Subscribes to the model and problem data change events of the content.
    /// Dereferences Content without a null check — presumably the base view only calls
    /// this while content is set; confirm against the base class.
    /// </summary>
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += Content_ModelChanged;
      Content.ProblemDataChanged += Content_ProblemDataChanged;
    }

    /// <summary>
    /// Removes the subscriptions added by <see cref="RegisterContentEvents"/>.
    /// </summary>
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= Content_ModelChanged;
      Content.ProblemDataChanged -= Content_ProblemDataChanged;
    }

    /// <summary>Redraws the chart after the content has been replaced.</summary>
    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }

    // Each of the following handlers only triggers a redraw.
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.