Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HiveStatistics/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 12803

Last change on this file since 12803 was 12689, checked in by dglaser, 9 years ago

#2388: Merged trunk into HiveStatistics branch

File size: 10.5 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[4417]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
[12515]27using HeuristicLab.Common;
[4417]28using HeuristicLab.MainForm;
[7701]29
[5829]30namespace HeuristicLab.Problems.DataAnalysis.Views {
[6642]31  [View("Error Characteristics Curve")]
32  [Content(typeof(IRegressionSolution))]
33  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    // Display names of the sample partitions. The constructor adds these three strings to
    // cmbSamples, and GetOriginalValues/GetEstimatedValues switch on the selected string.
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";
[4417]37
[6642]38    public RegressionSolutionErrorCharacteristicsCurveView()
39      : base() {
[4417]40      InitializeComponent();
41
42      cmbSamples.Items.Add(TrainingSamples);
43      cmbSamples.Items.Add(TestSamples);
[6642]44      cmbSamples.Items.Add(AllSamples);
45
[4417]46      cmbSamples.SelectedIndex = 0;
47
[12515]48      residualComboBox.SelectedIndex = 0;
49
[4651]50      chart.CustomizeAllChartAreas();
[12515]51      chart.ChartAreas[0].AxisX.Title = residualComboBox.SelectedItem.ToString();
[4417]52      chart.ChartAreas[0].AxisX.Minimum = 0.0;
[12365]53      chart.ChartAreas[0].AxisX.Maximum = 0.0;
[6642]54      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
55      chart.ChartAreas[0].CursorX.Interval = 0.01;
56
[10500]57      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
[4417]58      chart.ChartAreas[0].AxisY.Minimum = 0.0;
59      chart.ChartAreas[0].AxisY.Maximum = 1.0;
60      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
[6642]61      chart.ChartAreas[0].CursorY.Interval = 0.01;
[4417]62    }
63
[6642]64    public new IRegressionSolution Content {
65      get { return (IRegressionSolution)base.Content; }
[4417]66      set { base.Content = value; }
67    }
[6642]68    public IRegressionProblemData ProblemData {
69      get {
70        if (Content == null) return null;
71        return Content.ProblemData;
72      }
73    }
[4417]74
75    protected override void RegisterContentEvents() {
76      base.RegisterContentEvents();
[5664]77      Content.ModelChanged += new EventHandler(Content_ModelChanged);
[4417]78      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
79    }
80    protected override void DeregisterContentEvents() {
81      base.DeregisterContentEvents();
[5664]82      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
[4417]83      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
84    }
85
[6642]86    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
87      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
88      else UpdateChart();
[4417]89    }
[6642]90    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
91      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
92      else {
93        UpdateChart();
94      }
[4417]95    }
96    protected override void OnContentChanged() {
97      base.OnContentChanged();
[6642]98      UpdateChart();
[4417]99    }
100
[6642]101    protected virtual void UpdateChart() {
102      chart.Series.Clear();
103      chart.Annotations.Clear();
[12689]104      chart.ChartAreas[0].AxisX.Maximum = 0.0;
105      chart.ChartAreas[0].CursorX.Interval = 0.01;
[11093]106
[6642]107      if (Content == null) return;
[11093]108      if (cmbSamples.SelectedItem.ToString() == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
109      if (cmbSamples.SelectedItem.ToString() == TestSamples && !ProblemData.TestIndices.Any()) return;
[4417]110
[11093]111      if (Content.ProblemData.TrainingIndices.Any()) {
[11367]112        AddRegressionSolution(CreateConstantSolution());
[11093]113      }
[4417]114
[6642]115      AddRegressionSolution(Content);
[12515]116
117      chart.ChartAreas[0].AxisX.Title = residualComboBox.SelectedItem.ToString();
[6642]118    }
[4417]119
[6642]120    protected void AddRegressionSolution(IRegressionSolution solution) {
121      if (chart.Series.Any(s => s.Name == solution.Name)) return;
[4417]122
[6642]123      Series solutionSeries = new Series(solution.Name);
124      solutionSeries.Tag = solution;
125      solutionSeries.ChartType = SeriesChartType.FastLine;
[11093]126      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
[12365]127
128      var maxValue = residuals.Max();
[12586]129      if (maxValue >= chart.ChartAreas[0].AxisX.Maximum) {
130        double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
131        var maximum = scale * (1 + (int)(maxValue / scale));
132        chart.ChartAreas[0].AxisX.Maximum = maximum;
133        chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
134      }
[11093]135
136      UpdateSeries(residuals, solutionSeries);
137
[6642]138      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
[8105]139      solutionSeries.LegendToolTip = "Double-click to open model";
[6642]140      chart.Series.Add(solutionSeries);
141    }
[5417]142
[6642]143    protected void UpdateSeries(List<double> residuals, Series series) {
144      series.Points.Clear();
145      residuals.Sort();
[6982]146      if (!residuals.Any() || residuals.All(double.IsNaN)) return;
[4417]147
[6642]148      series.Points.AddXY(0, 0);
149      for (int i = 0; i < residuals.Count; i++) {
150        var point = new DataPoint();
151        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
152          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
[6750]153          point.YValues[0] = ((double)i) / residuals.Count;
[6642]154          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
155          series.Points.Add(point);
156          break;
157        }
[4417]158
[6642]159        point.XValue = residuals[i];
[6982]160        point.YValues[0] = ((double)i + 1) / residuals.Count;
[6642]161        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
162        series.Points.Add(point);
163      }
[4417]164
[6642]165      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
166        var point = new DataPoint();
167        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
168        point.YValues[0] = 1;
169        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
170        series.Points.Add(point);
171      }
172    }
[4417]173
[6642]174    protected IEnumerable<double> GetOriginalValues() {
175      IEnumerable<double> originalValues;
176      switch (cmbSamples.SelectedItem.ToString()) {
177        case TrainingSamples:
[8139]178          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
[6642]179          break;
180        case TestSamples:
[8139]181          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
[6642]182          break;
183        case AllSamples:
[6740]184          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
[6642]185          break;
186        default:
187          throw new NotSupportedException();
188      }
189      return originalValues;
190    }
[4417]191
[6642]192    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
193      IEnumerable<double> estimatedValues;
194      switch (cmbSamples.SelectedItem.ToString()) {
195        case TrainingSamples:
196          estimatedValues = solution.EstimatedTrainingValues;
197          break;
198        case TestSamples:
199          estimatedValues = solution.EstimatedTestValues;
200          break;
201        case AllSamples:
202          estimatedValues = solution.EstimatedValues;
203          break;
204        default:
205          throw new NotSupportedException();
[4417]206      }
[6642]207      return estimatedValues;
[4417]208    }
209
[6642]210    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
[12515]211      switch (residualComboBox.SelectedItem.ToString()) {
212        case "Absolute error": return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
213        case "Squared error": return originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y)).ToList();
214        case "Relative error": return originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? -1 : Math.Abs((x - y) / x))
215          .Where(x => x > 0) // remove entries where the original value is 0
216          .ToList();
[12586]217        default: throw new NotSupportedException();
[12515]218      }
[4417]219    }
220
[6642]221    private double CalculateAreaOverCurve(Series series) {
[6982]222      if (series.Points.Count < 1) return 0;
[4417]223
224      double auc = 0.0;
225      for (int i = 1; i < series.Points.Count; i++) {
226        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
[6642]227        double y1 = 1 - series.Points[i - 1].YValues[0];
228        double y2 = 1 - series.Points[i].YValues[0];
[4417]229
230        auc += (y1 + y2) * width / 2;
231      }
232
233      return auc;
234    }
235
[6642]236    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
237      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
238      else UpdateChart();
[4417]239    }
[7043]240
[7701]241    #region Baseline
[7700]242    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
[7043]243      HitTestResult result = chart.HitTest(e.X, e.Y);
244      if (result.ChartElementType != ChartElementType.LegendItem) return;
245
246      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
247    }
248
[11367]249    private ConstantRegressionSolution CreateConstantSolution() {
[8139]250      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
[8963]251      var model = new ConstantRegressionModel(averageTrainingTarget);
[11093]252      var solution = new ConstantRegressionSolution(model, (IRegressionProblemData)ProblemData.Clone());
[7700]253      solution.Name = "Baseline";
[7043]254      return solution;
255    }
[7701]256    #endregion
[7700]257
[7701]258    private void chart_MouseMove(object sender, MouseEventArgs e) {
259      HitTestResult result = chart.HitTest(e.X, e.Y);
[8102]260      if (result.ChartElementType == ChartElementType.LegendItem) {
[7701]261        Cursor = Cursors.Hand;
[8102]262      } else {
[7701]263        Cursor = Cursors.Default;
[8102]264      }
[7700]265    }
[12515]266
267    private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
268      UpdateChart();
269    }
[4417]270  }
271}
Note: See TracBrowser for help on using the repository browser.