Free cookie consent management tool by TermsFeed Policy Generator

source: branches/1776_ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs @ 17912

Last change on this file since 17912 was 8863, checked in by sforsten, 12 years ago

#1776:

  • merged r8810:8862 from trunk into branch
  • fixed an exception in ClassificationEnsembleSolutionAccuracyToCoveredSamples that was thrown when a partition contains no estimated values (in that case there is nothing to show)
File size: 9.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.MainForm;
29
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Shows, for a classification ensemble, how the accuracy over the "covered" samples relates to the
  /// fraction of samples covered. A sample counts as covered for a threshold t when the ensemble's
  /// confidence for it is at least t. The threshold is swept over <c>maxPoints</c> evenly spaced values
  /// in [0, 1]. The trapezoidal area under the resulting curve is shown in <c>AUCLabel</c>.
  /// </summary>
  [View("Accuracy Covered Dependence")]
  [Content(typeof(IClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionAccuracyToCoveredSamples : DataAnalysisSolutionEvaluationView {
    private const string ACCURACYCOVERED = "Accuracy to Covered percentage";
    private const string AREA = "Area";

    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    // Number of confidence thresholds evaluated: 0, 1/(maxPoints-1), ..., 1.
    private const int maxPoints = 101;

    // NOTE(review): [Content] advertises IClassificationEnsembleSolution, but this getter casts to the
    // concrete ClassificationEnsembleSolution (whose WeightCalculator RedrawChart reads). Confirm that
    // no other IClassificationEnsembleSolution implementation can be assigned to this view.
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Fills the partition selector and configures both chart axes to the fixed range [0, 1].
    /// </summary>
    public ClassificationEnsembleSolutionAccuracyToCoveredSamples()
      : base() {
      InitializeComponent();

      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;
      //configure axis
      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisX.Minimum = 0;
      this.chart.ChartAreas[0].AxisX.Maximum = 1;
      this.chart.ChartAreas[0].AxisX.Title = "Covered Samples in %";

      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisY.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisY.Minimum = 0;
      this.chart.ChartAreas[0].AxisY.Maximum = 1;
      this.chart.ChartAreas[0].AxisY.Title = "Accuracy";

      // Parent the label to the chart so its transparent background shows the chart behind it.
      AUCLabel.Parent = chart;
      AUCLabel.BackColor = Color.Transparent;
    }

    /// <summary>
    /// Rebuilds the chart for the partition selected in <c>SamplesComboBox</c>.
    /// For each threshold t_i = i / (maxPoints - 1), the accuracy over samples whose confidence is
    /// at least t_i and the covered fraction of the partition form one point (covered, accuracy).
    /// One extra point at 0 % coverage holds the accuracy of the strictest threshold.
    /// </summary>
    /// <exception cref="ArgumentException">The selected partition name is not recognized.</exception>
    private void RedrawChart() {
      this.chart.Series.Clear();
      if (Content == null) {
        // Don't leave the AUC of a previously displayed solution on screen.
        AUCLabel.Text = string.Empty;
        return;
      }

      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
      var solutions = Content.ClassificationSolutions;
      // Target values for every dataset row; indexed by row, not by partition position.
      double[] target = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();

      // indizes[j], estimatedClassValues[j] and confidences[j] all describe the j-th row of the partition.
      int[] indizes;
      double[] estimatedClassValues;
      double[] confidences;
      switch (SamplesComboBox.SelectedItem.ToString()) {
        case SamplesComboBoxAllSamples:
          indizes = Enumerable.Range(0, Content.ProblemData.Dataset.Rows).ToArray();
          estimatedClassValues = Content.EstimatedClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions,
                                                 Enumerable.Range(0, Content.ProblemData.Dataset.Rows),
                                                 estimatedClassValues,
                                                 weightCalc.GetAllClassDelegate()).ToArray();
          break;
        case SamplesComboBoxTrainingSamples:
          indizes = Content.ProblemData.TrainingIndices.ToArray();
          estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions,
                                                 Content.ProblemData.TrainingIndices,
                                                 estimatedClassValues,
                                                 weightCalc.GetTrainingClassDelegate()).ToArray();
          break;
        case SamplesComboBoxTestSamples:
          indizes = Content.ProblemData.TestIndices.ToArray();
          estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions,
                                                 Content.ProblemData.TestIndices,
                                                 estimatedClassValues,
                                                 weightCalc.GetTestClassDelegate()).ToArray();
          break;
        default:
          throw new ArgumentException();
      }

      // Also true for an empty partition, since All() on an empty sequence is true.
      if (estimatedClassValues.All(x => Double.IsNaN(x))) {
        AUCLabel.Text = "No values in this partition!";
        return;
      }

      // maxPoints threshold points plus one extra end point at 0 % coverage.
      double[] accuracy = new double[maxPoints + 1];
      double[] covered = new double[maxPoints + 1];
      OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();

      for (int i = 0; i < maxPoints; i++) {
        double confidenceValue = (1.0 / (maxPoints - 1)) * i;
        int notCovered = 0;
        accuracyCalc.Reset();

        for (int j = 0; j < indizes.Length; j++) {
          if (confidences[j] >= confidenceValue) {
            accuracyCalc.Add(target[indizes[j]], estimatedClassValues[j]);
          } else {
            notCovered++;
          }
        }

        // Accuracy and coverage of the same threshold share index i. The original code stored the
        // accuracy at i + 1, which paired each coverage value with the previous threshold's accuracy.
        // NOTE(review): when no sample reaches the threshold, accuracyCalc has seen no values; confirm
        // what OnlineAccuracyCalculator.Accuracy returns then (a NaN would propagate into the AUC).
        accuracy[i] = accuracyCalc.Accuracy;
        covered[i] = indizes.Length > 0 ? 1.0 - (double)notCovered / (double)indizes.Length : 0.0;
      }

      // Extend the curve to 0 % coverage, holding the accuracy of the strictest threshold.
      covered[maxPoints] = 0.0;
      accuracy[maxPoints] = accuracy[maxPoints - 1];

      // Coverage shrinks as the threshold grows; reverse so x values are ascending for plotting and AUC.
      Array.Reverse(accuracy);
      Array.Reverse(covered);

      Series area = this.chart.Series.Add(AREA);
      area.ChartType = SeriesChartType.Area;
      area.Color = Color.LightBlue;
      IEnumerable<IEnumerable<double>> areaPoints = CalculateAreaPoints(covered, accuracy);
      area.Points.DataBindXY(areaPoints.ElementAt(0), areaPoints.ElementAt(1));

      Series series = this.chart.Series.Add(ACCURACYCOVERED);
      series.Color = Color.Red;
      series.ChartType = SeriesChartType.FastPoint;
      series.MarkerStyle = MarkerStyle.Diamond;
      series.MarkerSize = 5;
      series.Points.DataBindXY(covered, accuracy);

      double auc = CalculateAreaUnderCurve(series);
      area.LegendToolTip = "AUC: " + auc;

      AUCLabel.Text = "AUC: " + auc;
    }

    /// <summary>
    /// Builds the outline of a step-shaped area that always follows the lower of two neighbouring
    /// accuracy values. Between consecutive points it emits an intermediate corner point, then the point itself.
    /// </summary>
    /// <param name="covered">X values (covered fraction), expected in ascending order.</param>
    /// <param name="accuracy">Y values, same length as <paramref name="covered"/>.</param>
    /// <returns>A two-element sequence: the new x values, then the matching y values.</returns>
    private IEnumerable<IEnumerable<double>> CalculateAreaPoints(double[] covered, double[] accuracy) {
      List<double> newCovered = new List<double>();
      List<double> worseAccuracy = new List<double>();
      newCovered.Add(covered[0]);
      worseAccuracy.Add(accuracy[0]);
      for (int i = 1; i < covered.Length; i++) {
        // NOTE(review): Double.Epsilon is the smallest subnormal value, so adding/subtracting it leaves any
        // x >= ~1e-308 unchanged. In practice the corner points share x with their neighbour, which
        // yields vertical steps.
        if (accuracy[i] > accuracy[i - 1]) {
          // Accuracy rises: stay at the lower, previous level until this x, then step up.
          worseAccuracy.Add(accuracy[i - 1]);
          newCovered.Add(covered[i] - Double.Epsilon);
        } else {
          // Accuracy falls (or stays): step down right at the previous x, then continue at the lower level.
          worseAccuracy.Add(accuracy[i]);
          newCovered.Add(covered[i - 1] + Double.Epsilon);
        }
        worseAccuracy.Add(accuracy[i]);
        newCovered.Add(covered[i]);
      }
      return new List<IEnumerable<double>>() { newCovered, worseAccuracy };
    }

    /// <summary>
    /// Trapezoidal-rule area under the series' points, in point order.
    /// A single point yields 0.
    /// </summary>
    /// <exception cref="ArgumentException">The series has no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    #region events
    // Redraw whenever the ensemble's model or problem data changes.
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.