source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 11093

Last change on this file since 11093 was 11093, checked in by gkronber, 5 years ago

#2197: fixed bugs in views for data analysis solutions that might occur if the problem does not have training samples (e.g. when the data is loaded into an existing solution)

File size: 10.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.MainForm;
28
29namespace HeuristicLab.Problems.DataAnalysis.Views {
30  [View("Error Characteristics Curve")]
31  [Content(typeof(IRegressionSolution))]
32  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    // Display names of the sample partitions offered in the cmbSamples selector. The selected
    // entry is compared against these constants to decide which rows are evaluated.
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";
36
37    public RegressionSolutionErrorCharacteristicsCurveView()
38      : base() {
39      InitializeComponent();
40
41      cmbSamples.Items.Add(TrainingSamples);
42      cmbSamples.Items.Add(TestSamples);
43      cmbSamples.Items.Add(AllSamples);
44
45      cmbSamples.SelectedIndex = 0;
46
47      chart.CustomizeAllChartAreas();
48      chart.ChartAreas[0].AxisX.Title = "Absolute Error";
49      chart.ChartAreas[0].AxisX.Minimum = 0.0;
50      chart.ChartAreas[0].AxisX.Maximum = 1.0;
51      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
52      chart.ChartAreas[0].CursorX.Interval = 0.01;
53
54      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
55      chart.ChartAreas[0].AxisY.Minimum = 0.0;
56      chart.ChartAreas[0].AxisY.Maximum = 1.0;
57      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
58      chart.ChartAreas[0].CursorY.Interval = 0.01;
59    }
60
61    public new IRegressionSolution Content {
62      get { return (IRegressionSolution)base.Content; }
63      set { base.Content = value; }
64    }
65    public IRegressionProblemData ProblemData {
66      get {
67        if (Content == null) return null;
68        return Content.ProblemData;
69      }
70    }
71
72    protected override void RegisterContentEvents() {
73      base.RegisterContentEvents();
74      Content.ModelChanged += new EventHandler(Content_ModelChanged);
75      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
76    }
77    protected override void DeregisterContentEvents() {
78      base.DeregisterContentEvents();
79      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
80      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
81    }
82
83    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
84      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
85      else UpdateChart();
86    }
87    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
88      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
89      else {
90        UpdateChart();
91      }
92    }
    /// <summary>
    /// Runs the base class handling and then rebuilds the chart for the new content.
    /// </summary>
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }
97
98    protected virtual void UpdateChart() {
99      chart.Series.Clear();
100      chart.Annotations.Clear();
101
102      if (Content == null) return;
103      if (cmbSamples.SelectedItem.ToString() == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
104      if (cmbSamples.SelectedItem.ToString() == TestSamples && !ProblemData.TestIndices.Any()) return;
105
106      if (Content.ProblemData.TrainingIndices.Any()) {
107        var constantModel = CreateConstantModel();
108        var originalValues = GetOriginalValues().ToList();
109        var baselineEstimatedValues = GetEstimatedValues(constantModel);
110        var baselineResiduals = GetResiduals(originalValues, baselineEstimatedValues);
111
112        Series baselineSeries = new Series("Baseline");
113        baselineSeries.ChartType = SeriesChartType.FastLine;
114        UpdateSeries(baselineResiduals, baselineSeries);
115        baselineSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(baselineSeries);
116        baselineSeries.Tag = constantModel;
117        baselineSeries.LegendToolTip = "Double-click to open model";
118        chart.Series.Add(baselineSeries);
119      }
120
121      AddRegressionSolution(Content);
122    }
123
124    protected void AddRegressionSolution(IRegressionSolution solution) {
125      if (chart.Series.Any(s => s.Name == solution.Name)) return;
126
127      Series solutionSeries = new Series(solution.Name);
128      solutionSeries.Tag = solution;
129      solutionSeries.ChartType = SeriesChartType.FastLine;
130      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
131     
132      chart.ChartAreas[0].AxisX.Maximum = Math.Ceiling(residuals.Max());
133      chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
134
135      UpdateSeries(residuals, solutionSeries);
136
137      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
138      solutionSeries.LegendToolTip = "Double-click to open model";
139      chart.Series.Add(solutionSeries);
140    }
141
142    protected void UpdateSeries(List<double> residuals, Series series) {
143      series.Points.Clear();
144      residuals.Sort();
145      if (!residuals.Any() || residuals.All(double.IsNaN)) return;
146
147      series.Points.AddXY(0, 0);
148      for (int i = 0; i < residuals.Count; i++) {
149        var point = new DataPoint();
150        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
151          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
152          point.YValues[0] = ((double)i) / residuals.Count;
153          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
154          series.Points.Add(point);
155          break;
156        }
157
158        point.XValue = residuals[i];
159        point.YValues[0] = ((double)i + 1) / residuals.Count;
160        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
161        series.Points.Add(point);
162      }
163
164      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
165        var point = new DataPoint();
166        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
167        point.YValues[0] = 1;
168        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
169        series.Points.Add(point);
170      }
171    }
172
173    protected IEnumerable<double> GetOriginalValues() {
174      IEnumerable<double> originalValues;
175      switch (cmbSamples.SelectedItem.ToString()) {
176        case TrainingSamples:
177          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
178          break;
179        case TestSamples:
180          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
181          break;
182        case AllSamples:
183          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
184          break;
185        default:
186          throw new NotSupportedException();
187      }
188      return originalValues;
189    }
190
191    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
192      IEnumerable<double> estimatedValues;
193      switch (cmbSamples.SelectedItem.ToString()) {
194        case TrainingSamples:
195          estimatedValues = solution.EstimatedTrainingValues;
196          break;
197        case TestSamples:
198          estimatedValues = solution.EstimatedTestValues;
199          break;
200        case AllSamples:
201          estimatedValues = solution.EstimatedValues;
202          break;
203        default:
204          throw new NotSupportedException();
205      }
206      return estimatedValues;
207    }
208
209    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
210      return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
211    }
212
213    private double CalculateAreaOverCurve(Series series) {
214      if (series.Points.Count < 1) return 0;
215
216      double auc = 0.0;
217      for (int i = 1; i < series.Points.Count; i++) {
218        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
219        double y1 = 1 - series.Points[i - 1].YValues[0];
220        double y2 = 1 - series.Points[i].YValues[0];
221
222        auc += (y1 + y2) * width / 2;
223      }
224
225      return auc;
226    }
227
228    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
229      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
230      else UpdateChart();
231    }
232
233    #region Baseline
234    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
235      HitTestResult result = chart.HitTest(e.X, e.Y);
236      if (result.ChartElementType != ChartElementType.LegendItem) return;
237
238      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
239    }
240
241    private IRegressionSolution CreateConstantModel() {
242      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
243      var model = new ConstantRegressionModel(averageTrainingTarget);
244      var solution = new ConstantRegressionSolution(model, (IRegressionProblemData)ProblemData.Clone());
245      solution.Name = "Baseline";
246      return solution;
247    }
248    #endregion
249
250    private void chart_MouseMove(object sender, MouseEventArgs e) {
251      HitTestResult result = chart.HitTest(e.X, e.Y);
252      if (result.ChartElementType == ChartElementType.LegendItem) {
253        Cursor = Cursors.Hand;
254      } else {
255        Cursor = Cursors.Default;
256      }
257    }
258  }
259}
Note: See TracBrowser for help on using the repository browser.