#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Windows.Forms;
using HeuristicLab.Common;
using HeuristicLab.Data;
using HeuristicLab.MainForm;
using HeuristicLab.MainForm.WindowsForms;
using HeuristicLab.Problems.DataAnalysis.Interfaces;

namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View that lists, per selected sample, the target value, the ensemble's estimated class
  /// value, whether the estimate matches the target, a confidence value, the number of votes
  /// per class and the individual estimate of every checked solution of a
  /// <see cref="ClassificationEnsembleSolution"/>.
  /// </summary>
  [View("Estimated Class Values")]
  [Content(typeof(ClassificationEnsembleSolution))]
  public partial class ClassificationEnsembleSolutionEstimatedClassValuesView :
    ClassificationSolutionEstimatedClassValuesView {
    private const string RowColumnName = "Row";
    private const string TargetClassValuesColumnName = "Target Variable";
    private const string EstimatedClassValuesColumnName = "Estimated Class Values";
    private const string CorrectClassificationColumnName = "Correct Classification";
    private const string ConfidenceColumnName = "Confidence";

    private const string SamplesComboBoxAllSamples = "All Samples";
    private const string SamplesComboBoxTrainingSamples = "Training Samples";
    private const string SamplesComboBoxTestSamples = "Test Samples";

    /// <summary>
    /// The displayed ensemble solution (the base content cast to the ensemble type).
    /// </summary>
    public new ClassificationEnsembleSolution Content {
      get { return (ClassificationEnsembleSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>
    /// Creates the view, fills the sample selection combo box, selects "All Samples" and
    /// hooks the row pre-paint handler that colors rows by classification correctness.
    /// </summary>
    public ClassificationEnsembleSolutionEstimatedClassValuesView()
      : base() {
      InitializeComponent();
      SamplesComboBox.Items.AddRange(new string[] { SamplesComboBoxAllSamples, SamplesComboBoxTrainingSamples, SamplesComboBoxTestSamples });
      SamplesComboBox.SelectedIndex = 0;
      matrixView.DataGridView.RowPrePaint += new DataGridViewRowPrePaintEventHandler(DataGridView_RowPrePaint);
    }

    /// <summary>
    /// Rebuilds the matrix whenever a different sample partition is selected.
    /// </summary>
    private void SamplesComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      UpdateEstimatedValues();
    }

    /// <summary>
    /// Rebuilds the string matrix shown in the matrix view for the currently selected sample
    /// partition and updates the two average-confidence text fields.
    /// Re-invokes itself through <c>Invoke</c> when <c>InvokeRequired</c> is true.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The combo box holds an item other than the three known sample selections.
    /// </exception>
    protected override void UpdateEstimatedValues() {
      if (InvokeRequired) {
        Invoke((Action)UpdateEstimatedValues);
        return;
      }
      if (Content == null) {
        matrixView.Content = null;
        return;
      }

      // Id, target, estimated class, correct classification, confidence.
      const int fixedColumns = 5;

      string samplesSelection = SamplesComboBox.SelectedItem.ToString();
      int[] indizes;
      double[] estimatedClassValues;
      switch (samplesSelection) {
        case SamplesComboBoxAllSamples: {
            indizes = Enumerable.Range(0, Content.ProblemData.Dataset.Rows).ToArray();
            estimatedClassValues = Content.EstimatedClassValues.ToArray();
            break;
          }
        case SamplesComboBoxTrainingSamples: {
            indizes = Content.ProblemData.TrainingIndizes.ToArray();
            estimatedClassValues = Content.EstimatedTrainingClassValues.ToArray();
            break;
          }
        case SamplesComboBoxTestSamples: {
            indizes = Content.ProblemData.TestIndizes.ToArray();
            estimatedClassValues = Content.EstimatedTestClassValues.ToArray();
            break;
          }
        default:
          throw new ArgumentException("Unknown sample selection: " + samplesSelection);
      }

      // Snapshot the checked solutions once so that the column count, the per-model values
      // and the column names are all derived from the same set of solutions.
      IEnumerable<IClassificationSolution> solutions = Content.ClassificationSolutions.CheckedItems.ToList();
      int classValuesCount = Content.ProblemData.ClassValues.Count;
      int solutionsCount = solutions.Count();

      string[,] values = new string[indizes.Length, fixedColumns + classValuesCount + solutionsCount];
      // Target values of ALL dataset rows; must be indexed by dataset row, not by loop position.
      double[] target = Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray();
      List<List<double?>> estimatedValuesVector = GetEstimatedValues(samplesSelection, indizes, solutions);
      IClassificationEnsembleSolutionWeightCalculator weightCalc = Content.WeightCalculator;

      // needed to calculate average confidences of correct and wrong estimated classes
      // index 0: correctly classified samples, index 1: wrongly classified samples
      double[] confidence = new double[2];
      int[] classified = new int[2];

      // For "All Samples" the confidences of all rows are requested in a single call.
      double[] confidences = null;
      if (samplesSelection == SamplesComboBoxAllSamples) {
        confidences = weightCalc.GetConfidence(solutions, indizes, estimatedClassValues).ToArray();
      }

      for (int i = 0; i < indizes.Length; i++) {
        int row = indizes[i];
        values[i, 0] = row.ToString();
        values[i, 1] = target[row].ToString();

        //display only indices and target values if no models are present
        if (solutionsCount == 0) continue;

        values[i, 2] = estimatedClassValues[i].ToString();
        bool correctClassified = target[row].IsAlmost(estimatedClassValues[i]);
        values[i, 3] = correctClassified.ToString();

        double curConfidence;
        if (confidences != null) {
          curConfidence = confidences[i];
        } else {
          // Training/test view: only solutions that have this row in the selected partition
          // contribute to the confidence.
          curConfidence = weightCalc.GetConfidence(GetRelevantSolutions(samplesSelection, solutions, row), row, estimatedClassValues[i]);
        }

        int bucket = correctClassified ? 0 : 1;
        confidence[bucket] += curConfidence;
        classified[bucket]++;
        values[i, 4] = curConfidence.ToString();

        // Number of solutions that voted for each class value.
        var groups = estimatedValuesVector[i]
          .GroupBy(x => x)
          .Select(g => new { Key = g.Key, Count = g.Count() })
          .ToList();
        for (int classIndex = 0; classIndex < classValuesCount; classIndex++) {
          double classValue = Content.ProblemData.ClassValues[classIndex];
          var group = groups.Where(g => g.Key == classValue).SingleOrDefault();
          values[i, fixedColumns + classIndex] = group == null ? 0.ToString() : group.Count.ToString();
        }

        // Individual estimate of every solution; empty if the solution has no value for this row.
        for (int modelIndex = 0; modelIndex < estimatedValuesVector[i].Count; modelIndex++) {
          double? estimate = estimatedValuesVector[i][modelIndex];
          values[i, fixedColumns + classValuesCount + modelIndex] = estimate == null ? string.Empty : estimate.ToString();
        }
      }

      // Averages; yields NaN if no sample fell into the respective bucket.
      CorrectClassifiedConfidence.Text = (confidence[0] / (double)classified[0]).ToString();
      WrongClassifiedConfidence.Text = (confidence[1] / (double)classified[1]).ToString();

      StringMatrix matrix = new StringMatrix(values);
      List<string> columnNames = new List<string>() { "Id", TargetClassValuesColumnName, EstimatedClassValuesColumnName, CorrectClassificationColumnName, ConfidenceColumnName };
      columnNames.AddRange(Content.ProblemData.ClassNames);
      columnNames.AddRange(solutions.Select(s => s.Model.Name));
      matrix.ColumnNames = columnNames;
      matrix.SortableView = true;
      matrixView.Content = matrix;
    }

    /// <summary>
    /// Filters <paramref name="solutions"/> to those relevant for <paramref name="curRow"/>
    /// under the given sample selection.
    /// </summary>
    /// <param name="samplesSelection">One of the sample selection combo box entries.</param>
    /// <param name="solutions">Candidate solutions.</param>
    /// <param name="curRow">Dataset row index.</param>
    /// <returns>
    /// All solutions for "All Samples"; the solutions whose problem data marks the row as
    /// training (or test) sample for the respective selection; an empty list otherwise.
    /// </returns>
    protected IEnumerable<IClassificationSolution> GetRelevantSolutions(string samplesSelection, IEnumerable<IClassificationSolution> solutions, int curRow) {
      if (samplesSelection == SamplesComboBoxAllSamples)
        return solutions;
      else if (samplesSelection == SamplesComboBoxTrainingSamples)
        return solutions.Where(s => s.ProblemData.IsTrainingSample(curRow));
      else if (samplesSelection == SamplesComboBoxTestSamples)
        return solutions.Where(s => s.ProblemData.IsTestSample(curRow));
      else
        return new List<IClassificationSolution>();
    }

    /// <summary>
    /// Returns the positions in <paramref name="list"/> whose element equals
    /// <paramref name="value"/>.
    /// </summary>
    private IEnumerable<int> FindAllIndices(List<double?> list, double value) {
      List<int> indices = new List<int>();
      for (int i = 0; i < list.Count; i++) {
        if (list[i].Equals(value))
          indices.Add(i);
      }
      return indices;
    }

    /// <summary>
    /// Collects, for every entry of <paramref name="rows"/>, the estimated class value of each
    /// solution. A solution contributes <c>null</c> for a row that is not part of the selected
    /// partition in that solution's own problem data.
    /// </summary>
    /// <param name="samplesSelection">One of the sample selection combo box entries.</param>
    /// <param name="rows">Dataset row indices to evaluate.</param>
    /// <param name="solutions">Solutions to query.</param>
    /// <returns>
    /// One list per row (same order as <paramref name="rows"/>) with one entry per solution;
    /// an empty outer list if there are no solutions.
    /// </returns>
    private List<List<double?>> GetEstimatedValues(string samplesSelection, int[] rows, IEnumerable<IClassificationSolution> solutions) {
      List<List<double?>> values = new List<List<double?>>();
      int solutionIndex = 0;
      foreach (var solution in solutions) {
        double[] estimation = solution.GetEstimatedClassValues(rows).ToArray();
        for (int i = 0; i < rows.Length; i++) {
          var row = rows[i];
          // The per-row lists are created while processing the first solution.
          if (solutionIndex == 0) values.Add(new List<double?>());

          if (samplesSelection == SamplesComboBoxAllSamples)
            values[i].Add(estimation[i]);
          else if (samplesSelection == SamplesComboBoxTrainingSamples && solution.ProblemData.IsTrainingSample(row))
            values[i].Add(estimation[i]);
          else if (samplesSelection == SamplesComboBoxTestSamples && solution.ProblemData.IsTestSample(row))
            values[i].Add(estimation[i]);
          else
            values[i].Add(null);
        }
        solutionIndex++;
      }
      return values;
    }

    /// <summary>
    /// Colors a grid row green when its "Correct Classification" cell (column 3) reads true
    /// and red when it reads false. Rows whose cell is missing, empty or not a boolean are
    /// left untouched.
    /// </summary>
    private void DataGridView_RowPrePaint(object sender, DataGridViewRowPrePaintEventArgs e) {
      if (InvokeRequired) {
        Invoke(new EventHandler<DataGridViewRowPrePaintEventArgs>(DataGridView_RowPrePaint), sender, e);
        return;
      }

      object rawValue = matrixView.DataGridView[3, e.RowIndex].Value;
      if (rawValue == null) return;

      string cellValue = rawValue.ToString();
      if (string.IsNullOrEmpty(cellValue)) return;

      // Never throw from a paint handler: skip rows whose cell does not hold a boolean.
      bool correctClassified;
      if (!bool.TryParse(cellValue, out correctClassified)) return;

      matrixView.DataGridView.Rows[e.RowIndex].DefaultCellStyle.ForeColor = correctClassified ? Color.MediumSeaGreen : Color.Red;
    }
  }
}