Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 11301

Last change on this file since 11301 was 11171, checked in by ascheibe, 10 years ago

#2115 merged r11170 (copyright update) into trunk

File size: 10.1 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
[11171]3 * Copyright (C) 2002-2014 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[4417]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.MainForm;
[7701]28
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Plots the error characteristics curve of a regression solution: for every absolute error x on the
  /// x-axis the curve shows the ratio of samples whose absolute residual is at most x. A constant baseline
  /// solution that predicts the mean training target is plotted alongside for comparison.
  /// </summary>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    // Name of the constant baseline solution and of its chart series.
    private const string BaselineName = "Baseline";

    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Title = "Absolute Error";
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      chart.ChartAreas[0].AxisX.Maximum = 1.0;
      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      chart.ChartAreas[0].CursorX.Interval = 0.01;

      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].CursorY.Interval = 0.01;
    }

    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>The problem data of the current content, or null when there is no content.</summary>
    public IRegressionProblemData ProblemData {
      get {
        if (Content == null) return null;
        return Content.ProblemData;
      }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    // Both handlers marshal to the UI thread before redrawing the chart.
    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
      else UpdateChart();
    }
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
      else UpdateChart();
    }
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the chart for the current content and the selected sample partition.
    /// Nothing is drawn when there is no content or when the selected training/test partition is empty.
    /// </summary>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();

      if (Content == null) return;
      if (cmbSamples.SelectedItem.ToString() == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
      if (cmbSamples.SelectedItem.ToString() == TestSamples && !ProblemData.TestIndices.Any()) return;

      // The solution must be added before the baseline is built: AddRegressionSolution sets
      // AxisX.Maximum, and UpdateSeries clips every curve (and thereby its area over curve) at that
      // maximum. Building the baseline first would clip it at the stale maximum left over from the
      // previous update (1.0 after construction).
      AddRegressionSolution(Content);

      // The baseline predicts the mean training target, so it needs at least one training row.
      if (!ProblemData.TrainingIndices.Any()) return;
      // A solution that is itself named like the baseline already occupies that series name.
      if (chart.Series.Any(s => s.Name == BaselineName)) return;

      var constantModel = CreateConstantModel();
      var originalValues = GetOriginalValues().ToList();
      var baselineEstimatedValues = GetEstimatedValues(constantModel);
      var baselineResiduals = GetResiduals(originalValues, baselineEstimatedValues);

      Series baselineSeries = new Series(BaselineName);
      baselineSeries.ChartType = SeriesChartType.FastLine;
      UpdateSeries(baselineResiduals, baselineSeries);
      baselineSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(baselineSeries);
      baselineSeries.Tag = constantModel;
      baselineSeries.LegendToolTip = "Double-click to open model";
      // Insert at the front so the baseline stays the first legend entry.
      chart.Series.Insert(0, baselineSeries);
    }

    /// <summary>
    /// Adds the error characteristics curve of <paramref name="solution"/> to the chart and rescales the
    /// x-axis to the ceiling of the solution's largest finite absolute residual. Does nothing when a series
    /// with the solution's name is already present.
    /// </summary>
    protected void AddRegressionSolution(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      Series solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));

      // Scale the axis from finite residuals only: Max()/Min() throw on an empty sequence, and a NaN or
      // infinite residual would turn the axis maximum or cursor interval into NaN/infinity.
      var finiteResiduals = residuals.Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
      if (finiteResiduals.Any()) {
        double axisMaximum = Math.Ceiling(finiteResiduals.Max());
        // A perfect fit gives a maximum of 0, which would collapse the axis onto its minimum of 0.
        chart.ChartAreas[0].AxisX.Maximum = axisMaximum > 0 ? axisMaximum : 1.0;
        chart.ChartAreas[0].CursorX.Interval = finiteResiduals.Min() / 100;
      }

      UpdateSeries(residuals, solutionSeries);

      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
      solutionSeries.LegendToolTip = "Double-click to open model";
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative ratio curve of
    /// <paramref name="residuals"/>, which is sorted in place. The curve starts at (0, 0). Residuals beyond
    /// AxisX.Maximum are collapsed into a single point at the maximum. If the curve ends before the
    /// maximum, a final point (maximum, 1) is appended. Leaves the series empty when there are no residuals
    /// or all residuals are NaN.
    /// </summary>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (!residuals.Any() || residuals.All(double.IsNaN)) return;

      double axisMaximum = chart.ChartAreas[0].AxisX.Maximum;
      series.Points.AddXY(0, 0);
      for (int i = 0; i < residuals.Count; i++) {
        if (residuals[i] > axisMaximum) {
          // The first i (sorted) residuals lie within the axis range; cut the curve at the maximum.
          series.Points.Add(CreateDataPoint(axisMaximum, ((double)i) / residuals.Count));
          break;
        }
        series.Points.Add(CreateDataPoint(residuals[i], ((double)i + 1) / residuals.Count));
      }

      if (series.Points.Last().XValue < axisMaximum) {
        series.Points.Add(CreateDataPoint(axisMaximum, 1));
      }
    }

    // Creates a curve point at (x, y) with the error/ratio tooltip shared by all points.
    private static DataPoint CreateDataPoint(double x, double y) {
      var point = new DataPoint();
      point.XValue = x;
      point.YValues[0] = y;
      point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
      return point;
    }

    /// <summary>Target values of the sample partition selected in the combo box.</summary>
    /// <exception cref="NotSupportedException">The selected item is not a known partition.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
        case TestSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
        case AllSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>Estimated values of <paramref name="solution"/> for the selected sample partition.</summary>
    /// <exception cref="NotSupportedException">The selected item is not a known partition.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return solution.EstimatedTrainingValues;
        case TestSamples:
          return solution.EstimatedTestValues;
        case AllSamples:
          return solution.EstimatedValues;
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Pairwise absolute differences of the two sequences, truncated to the shorter one. Subclasses may
    /// override this to plot a different residual measure.
    /// </summary>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
    }

    /// <summary>
    /// Area between the curve and the line ratio = 1, integrated with the trapezoidal rule over the
    /// series' points. Smaller values mean the ratio approaches 1 at smaller errors. Returns 0 for an
    /// empty series.
    /// </summary>
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 1) return 0;

      double area = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = 1 - series.Points[i - 1].YValues[0];
        double y2 = 1 - series.Points[i].YValues[0];

        area += (y1 + y2) * width / 2;
      }

      return area;
    }

    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
      else UpdateChart();
    }

    #region Baseline
    // Opens the solution stored in a series' Tag when its legend entry is double-clicked.
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;

      var solution = result.Series != null ? result.Series.Tag as IRegressionSolution : null;
      if (solution == null) return;
      MainFormManager.MainForm.ShowContent(solution);
    }

    // Builds the baseline: a constant model predicting the mean target over the training partition,
    // wrapped in a solution on a clone of the problem data.
    private IRegressionSolution CreateConstantModel() {
      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
      var model = new ConstantRegressionModel(averageTrainingTarget);
      var solution = new ConstantRegressionSolution(model, (IRegressionProblemData)ProblemData.Clone());
      solution.Name = BaselineName;
      return solution;
    }
    #endregion

    // Shows a hand cursor over legend entries to signal that they can be double-clicked.
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      Cursor = result.ChartElementType == ChartElementType.LegendItem ? Cursors.Hand : Cursors.Default;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.