#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Windows.Forms;
using System.Windows.Forms.DataVisualization.Charting;
using HeuristicLab.MainForm;

namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View that plots, for a classification ensemble, the accuracy obtained on the samples whose
  /// ensemble confidence reaches a threshold against the fraction of samples that reach it
  /// ("covered" samples). The threshold is swept evenly over [0, 1] and the area under the
  /// resulting curve is reported as "AUC".
  /// </summary>
  [View("Accuracy Covered Dependence")]
  [Content(typeof(IClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionAccuracyToCoveredSamples : DataAnalysisSolutionEvaluationView {
    private const string ACCURACYCOVERED = "Accuracy to Covered percentage";
    private const string AREA = "Area";

    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    // Number of confidence thresholds evaluated (evenly spaced in [0, 1], both ends included).
    private const int maxPoints = 101;

    // NOTE(review): the view is registered for IClassificationEnsembleSolution but this cast
    // requires the concrete ClassificationEnsembleSolution; other implementations would fail here.
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Creates the view, fills the partition selector and configures both chart axes to the
    /// fixed range [0, 1].
    /// </summary>
    public ClassificationEnsembleSolutionAccuracyToCoveredSamples()
      : base() {
      InitializeComponent();

      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;

      //configure axis
      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisX.Minimum = 0;
      this.chart.ChartAreas[0].AxisX.Maximum = 1;
      this.chart.ChartAreas[0].AxisX.Title = "Covered Samples in %";

      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisY.IsStartedFromZero = true;
      this.chart.ChartAreas[0].AxisY.Minimum = 0;
      this.chart.ChartAreas[0].AxisY.Maximum = 1;
      this.chart.ChartAreas[0].AxisY.Title = "Accuracy";

      // Overlay the AUC label on the chart with a see-through background.
      AUCLabel.Parent = chart;
      AUCLabel.BackColor = Color.Transparent;
    }

    /// <summary>
    /// Rebuilds both chart series (the step-shaped area and the accuracy/coverage points) for the
    /// partition chosen in <c>SamplesComboBox</c> and updates the AUC label.
    /// </summary>
    /// <exception cref="ArgumentException">The selected partition entry is not one of the known items.</exception>
    private void RedrawChart() {
      this.chart.Series.Clear();
      if (Content == null) {
        // Nothing to show: do not leave the AUC of a previously displayed solution behind.
        AUCLabel.Text = string.Empty;
        return;
      }

      // One slot per threshold plus one closing point at 0 % coverage.
      double[] accuracy = new double[maxPoints + 1];
      double[] covered = new double[maxPoints + 1];
      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;
      var solutions = Content.ClassificationSolutions;
      double[] estimatedClassValues = null;
      double[] target;
      OnlineAccuracyCalculator accuracyCalc = new OnlineAccuracyCalculator();

      int[] indizes;
      double[] confidences;

      // Target values for every dataset row; indexed by row number below.
      target = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();

      // estimatedClassValues[j] and confidences[j] refer to dataset row indizes[j].
      switch (SamplesComboBox.SelectedItem.ToString()) {
        case SamplesComboBoxAllSamples:
          indizes = Enumerable.Range(0, Content.ProblemData.Dataset.Rows).ToArray();
          estimatedClassValues = Content.EstimatedClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions, Enumerable.Range(0, Content.ProblemData.Dataset.Rows), estimatedClassValues, weightCalc.GetAllClassDelegate()).ToArray();
          break;
        case SamplesComboBoxTrainingSamples:
          indizes = Content.ProblemData.TrainingIndices.ToArray();
          estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions, Content.ProblemData.TrainingIndices, estimatedClassValues, weightCalc.GetTrainingClassDelegate()).ToArray();
          break;
        case SamplesComboBoxTestSamples:
          indizes = Content.ProblemData.TestIndices.ToArray();
          estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
          confidences = weightCalc.GetConfidence(solutions, Content.ProblemData.TestIndices, estimatedClassValues, weightCalc.GetTestClassDelegate()).ToArray();
          break;
        default:
          throw new ArgumentException();
      }

      // All() is true for an empty partition too, so this also covers "no rows selected".
      if (estimatedClassValues.All(x => Double.IsNaN(x))) {
        AUCLabel.Text = "No values in this partition!";
        return;
      }

      int row;
      for (int i = 0; i < maxPoints; i++) {
        double confidenceValue = (1.0 / (maxPoints - 1)) * i;
        int notCovered = 0;

        for (int j = 0; j < indizes.Length; j++) {
          row = indizes[j];
          if (confidences[j] >= confidenceValue) {
            accuracyCalc.Add(target[row], estimatedClassValues[j]);
          } else {
            notCovered++;
          }
        }

        // Both arrays use index i for threshold i, so each plotted point pairs the coverage of a
        // threshold with the accuracy measured at that same threshold.
        // NOTE(review): if no sample reaches the threshold, the calculator received no values;
        // confirm what OnlineAccuracyCalculator.Accuracy returns in that case (possibly NaN).
        accuracy[i] = accuracyCalc.Accuracy;
        if (indizes.Length > 0) {
          covered[i] = 1.0 - (double)notCovered / (double)indizes.Length;
        }
        accuracyCalc.Reset();
      }

      // Closing point: extend the curve to 0 % coverage at the accuracy of the strictest threshold.
      accuracy[maxPoints] = accuracy[maxPoints - 1];
      covered[maxPoints] = 0.0;

      // Coverage shrinks as the threshold grows; reverse so the x values run from 0 up to 1.
      accuracy = accuracy.Reverse().ToArray();
      covered = covered.Reverse().ToArray();

      Series area = this.chart.Series.Add(AREA);
      area.ChartType = SeriesChartType.Area;
      area.Color = Color.LightBlue;
      IEnumerable<IEnumerable<double>> areaPoints = CalculateAreaPoints(covered, accuracy);
      area.Points.DataBindXY(areaPoints.ElementAt(0), areaPoints.ElementAt(1));

      Series series = this.chart.Series.Add(ACCURACYCOVERED);
      series.Color = Color.Red;
      series.ChartType = SeriesChartType.FastPoint;
      series.MarkerStyle = MarkerStyle.Diamond;
      series.MarkerSize = 5;
      series.Points.DataBindXY(covered, accuracy);

      // The AUC is the trapezoidal area under the red point series, not under the blue step area.
      double auc = CalculateAreaUnderCurve(series);
      area.LegendToolTip = "AUC: " + auc;
      AUCLabel.Text = "AUC: " + auc;
    }

    /// <summary>
    /// Builds the outline of a step-shaped area below the point curve. Between two neighbouring
    /// points the horizontal segment is drawn at the lower of their two accuracies, with a
    /// vertical jump at the x position where the accuracy changes.
    /// </summary>
    /// <param name="covered">X values (coverage), expected in ascending order.</param>
    /// <param name="accuracy">Y values (accuracy), same length as <paramref name="covered"/>.</param>
    /// <returns>A two-element sequence: the x values first, then the matching y values.</returns>
    private IEnumerable<IEnumerable<double>> CalculateAreaPoints(double[] covered, double[] accuracy) {
      List<double> newCovered = new List<double>();
      List<double> worseAccuracy = new List<double>();

      newCovered.Add(covered[0]);
      worseAccuracy.Add(accuracy[0]);
      for (int i = 1; i < covered.Length; i++) {
        // NOTE(review): Double.Epsilon is the smallest positive denormal, so adding/subtracting it
        // leaves any x value that is not itself tiny unchanged; the step is effectively vertical.
        if (accuracy[i] > accuracy[i - 1]) {
          // Rising step: stay at the previous (lower) accuracy, then jump up at covered[i].
          worseAccuracy.Add(accuracy[i - 1]);
          newCovered.Add(covered[i] - Double.Epsilon);
        } else {
          // Falling or flat step: drop to the current (lower) accuracy right at covered[i - 1].
          worseAccuracy.Add(accuracy[i]);
          newCovered.Add(covered[i - 1] + Double.Epsilon);
        }
        worseAccuracy.Add(accuracy[i]);
        newCovered.Add(covered[i]);
      }
      return new List<IEnumerable<double>>() { newCovered, worseAccuracy };
    }

    /// <summary>
    /// Computes the area under the series with the trapezoidal rule over consecutive points,
    /// using each point's X value and first Y value.
    /// </summary>
    /// <param name="series">Series whose points are expected in ascending X order.</param>
    /// <returns>The trapezoidal area; 0 for a single point.</returns>
    /// <exception cref="ArgumentException">The series contains no points.</exception>
    private double CalculateAreaUnderCurve(Series series) {
      if (series.Points.Count < 1) throw new ArgumentException("Could not calculate area under curve if less than 1 data points were given.");

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = series.Points[i - 1].YValues[0];
        double y2 = series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    #region events
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    #endregion
  }
}