Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ClassificationEnsembleVoting/HeuristicLab.Problems.DataAnalysis.Views/3.4/Classification/ClassificationEnsembleSolutionAccuracyToCoveredSamples.cs @ 8811

Last change on this file since 8811 was 8811, checked in by sforsten, 11 years ago

#1776:

File size: 9.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.Data;
29using HeuristicLab.MainForm;
30
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View for a classification ensemble that plots the accuracy on the samples whose ensemble
  /// confidence reaches a threshold against the fraction of samples that reach it ("covered"
  /// samples). The threshold is swept in equal steps over [0, 1]. The area under the resulting
  /// curve (AUC) is shown in a label and as the legend tooltip of the area series.
  /// </summary>
  [View("Accuracy Covered Dependence")]
  [Content(typeof(IClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionAccuracyToCoveredSamples : DataAnalysisSolutionEvaluationView {
    private const string ACCURACYCOVERED = "Accuracy to Covered percentage";
    private const string AREA = "Area";

    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    // Number of confidence thresholds evaluated: i / (maxPoints - 1) for i = 0 .. maxPoints - 1,
    // i.e. 0.00, 0.01, ..., 1.00.
    private const int maxPoints = 101;

    /// <summary>
    /// The displayed ensemble solution.
    /// NOTE(review): the view is registered for IClassificationEnsembleSolution but this getter casts
    /// to the concrete ClassificationEnsembleSolution; any other implementation of the interface would
    /// make the cast throw.
    /// </summary>
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Creates the view, fills the sample-selection combo box (defaulting to all samples) and
    /// configures both chart axes to the fixed range [0, 1] with zooming enabled.
    /// </summary>
    public ClassificationEnsembleSolutionAccuracyToCoveredSamples()
      : base() {
      InitializeComponent();

      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;
      //configure axis
      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisX.Minimum = 0;
      this.chart.ChartAreas[0].AxisX.Maximum = 1;
      this.chart.ChartAreas[0].AxisX.Title = "Covered Samples in %";

      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisY.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisY.Minimum = 0;
      this.chart.ChartAreas[0].AxisY.Maximum = 1;
      this.chart.ChartAreas[0].AxisY.Title = "Accuracy";

      // Overlay the label on the chart so its transparent background shows the plot behind it.
      AUCLabel.Parent = chart;
      AUCLabel.BackColor = Color.Transparent;
    }

    /// <summary>
    /// Rebuilds the chart for the current content and the sample selection in the combo box.
    /// The chart is left empty and the AUC label cleared when there is no content, the
    /// selection is not one of the known entries, or the selected samples are empty.
    /// </summary>
    private void RedrawChart() {
      this.chart.Series.Clear();
      // Clear first so a previously computed AUC is never shown next to an empty chart.
      AUCLabel.Text = string.Empty;
      if (Content == null) return;

      double[] classValues, estimatedClassValues, confidences;
      if (!TryGetSampleData(out classValues, out estimatedClassValues, out confidences)) return;
      // With no samples the covered fraction would be 0/0 (NaN), which cannot be charted meaningfully.
      if (classValues.Length == 0) return;

      double[] covered, accuracy;
      CalculateAccuracyToCovered(classValues, estimatedClassValues, confidences, out covered, out accuracy);

      Series area = this.chart.Series.Add(AREA);
      area.ChartType = SeriesChartType.Area;
      area.Color = Color.LightBlue;
      IEnumerable<IEnumerable<double>> areaPoints = CalculateAreaPoints(covered, accuracy);
      area.Points.DataBindXY(areaPoints.ElementAt(0), areaPoints.ElementAt(1));

      Series series = this.chart.Series.Add(ACCURACYCOVERED);
      series.Color = Color.Red;
      series.ChartType = SeriesChartType.FastPoint;
      series.MarkerStyle = MarkerStyle.Diamond;
      series.MarkerSize = 5;
      series.Points.DataBindXY(covered, accuracy);

      double auc = CalculateAreaUnderCurve(series);
      area.LegendToolTip = "AUC: " + auc;

      AUCLabel.Text = "AUC: " + auc;
    }

    /// <summary>
    /// Collects, for the samples selected in the combo box, the target values, the ensemble's
    /// estimated class values and the ensemble confidences. All three arrays are aligned by position.
    /// For training/test samples the rows are exactly ProblemData.TrainingIndices / TestIndices, the
    /// same rows that EstimatedTrainingClassValues / EstimatedTestClassValues are computed on. This
    /// keeps the arrays aligned even when the training and test partitions overlap.
    /// </summary>
    /// <returns>false if nothing is selected or the selection is not one of the known entries.</returns>
    private bool TryGetSampleData(out double[] classValues, out double[] estimatedClassValues, out double[] confidences) {
      classValues = estimatedClassValues = confidences = null;
      object selectedItem = SamplesComboBox.SelectedItem;
      if (selectedItem == null) return false;
      string selection = selectedItem.ToString();

      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
      var solutions = Content.ClassificationSolutions;
      var problemData = Content.ProblemData;

      if (selection == SamplesComboBoxAllSamples) {
        int rowCount = problemData.Dataset.Rows;
        estimatedClassValues = Content.EstimatedClassValues.ToArray();
        classValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToArray();
        confidences = weightCalc.GetConfidence(solutions, Enumerable.Range(0, rowCount), estimatedClassValues).ToArray();
        return true;
      }

      int[] rows;
      if (selection == SamplesComboBoxTrainingSamples) {
        rows = problemData.TrainingIndices.ToArray();
        estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
      } else if (selection == SamplesComboBoxTestSamples) {
        rows = problemData.TestIndices.ToArray();
        estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
      } else {
        return false;
      }

      double[] allClassValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToArray();
      classValues = new double[rows.Length];
      confidences = new double[rows.Length];
      for (int i = 0; i < rows.Length; i++) {
        int row = rows[i];
        classValues[i] = allClassValues[row];
        // Only the ensemble members that treat this row as a training/test sample contribute.
        confidences[i] = weightCalc.GetConfidence(GetRelevantSolutions(selection, solutions, row), row, estimatedClassValues[i]);
      }
      return true;
    }

    /// <summary>
    /// Sweeps the confidence threshold t_k = k / (maxPoints - 1), k = 0 .. maxPoints - 1. For each
    /// threshold it computes the covered fraction c_k (share of samples with confidence &gt;= t_k)
    /// and the accuracy a_k on those covered samples.
    /// Before the final reversal the point arrays are
    /// covered = [c_0, c_1, ..., c_last, 0] and accuracy = [a_0, a_0, a_1, ..., a_last].
    /// So point k (k &gt;= 1) pairs coverage c_k with the accuracy of the previous threshold, and a
    /// trailing point at coverage 0 is added. Both arrays are then reversed so coverage runs from
    /// 0 upwards.
    /// </summary>
    /// <param name="classValues">Target values; must be non-empty.</param>
    /// <param name="estimatedClassValues">Estimates aligned with <paramref name="classValues"/>.</param>
    /// <param name="confidences">Confidences aligned with <paramref name="classValues"/>.</param>
    /// <param name="covered">Receives the x values (covered fraction), maxPoints + 1 entries.</param>
    /// <param name="accuracy">Receives the y values (accuracy), maxPoints + 1 entries.</param>
    private static void CalculateAccuracyToCovered(double[] classValues, double[] estimatedClassValues, double[] confidences,
                                                   out double[] covered, out double[] accuracy) {
      int rows = classValues.Length;
      accuracy = new double[maxPoints + 1];
      covered = new double[maxPoints + 1];
      OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();

      for (int i = 0; i < maxPoints; i++) {
        double confidenceValue = (1.0 / (maxPoints - 1)) * i;
        int notCovered = 0;

        for (int j = 0; j < rows; j++) {
          if (confidences[j] >= confidenceValue) {
            accuracyCalc.Add(classValues[j], estimatedClassValues[j]);
          } else {
            notCovered++;
          }
        }

        accuracy[i + 1] = accuracyCalc.Accuracy;
        covered[i] = 1.0 - (double)notCovered / (double)rows;
        accuracyCalc.Reset();
      }

      accuracy[0] = accuracy[1];
      covered[maxPoints] = 0.0;

      Array.Reverse(accuracy);
      Array.Reverse(covered);
    }

    /// <summary>
    /// Builds the outline of the shaded area. Between each pair of consecutive points it uses the
    /// lower of the two accuracies, inserting an extra point so the outline forms a step.
    /// </summary>
    /// <returns>A two-element list: the x values (coverage) followed by the y values (accuracy).</returns>
    private IEnumerable<IEnumerable<double>> CalculateAreaPoints(double[] covered, double[] accuracy) {
      List<double> newCovered = new List<double>();
      List<double> worseAccuracy = new List<double>();
      newCovered.Add(covered[0]);
      worseAccuracy.Add(accuracy[0]);
      for (int i = 1; i < covered.Length; i++) {
        // Double.Epsilon is the smallest positive subnormal double, so for coverage values of ordinary
        // magnitude the +/- offset is lost to rounding and the step point shares its x with a neighbour.
        if (accuracy[i] > accuracy[i - 1]) {
          worseAccuracy.Add(accuracy[i - 1]);
          newCovered.Add(covered[i] - Double.Epsilon);
        } else {
          worseAccuracy.Add(accuracy[i]);
          newCovered.Add(covered[i - 1] + Double.Epsilon);
        }
        worseAccuracy.Add(accuracy[i]);
        newCovered.Add(covered[i]);
      }
      return new List<IEnumerable<double>>() { newCovered, worseAccuracy };
    }

    /// <summary>
    /// Returns the ensemble members that are relevant for <paramref name="curRow"/> under the given
    /// sample selection:
    /// all solutions for "All Samples";
    /// those whose problem data treats the row as a training (respectively test) sample for
    /// "Training Samples" ("Test Samples");
    /// an empty list for any other selection.
    /// </summary>
    protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
      if (samplesSelection == SamplesComboBoxAllSamples)
        return solutions;
      else if (samplesSelection == SamplesComboBoxTrainingSamples)
        return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
      else if (samplesSelection == SamplesComboBoxTestSamples)
        return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
      else
        return new List<IClassificationSolution>();
    }

    /// <summary>
    /// Trapezoidal-rule area under the series (first Y value of each point over its X value),
    /// summed over consecutive points in series order.
    /// </summary>
    /// <exception cref="ArgumentException">The series has no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    #region events
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    // Any change of content, model, problem data or sample selection invalidates the curve.
    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.