Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 17990

Last change on this file since 17990 was 17976, checked in by mkommend, 4 years ago

#3125: Added error handling to ECC View and minor code improvements.

File size: 14.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.Algorithms.DataAnalysis;
28using HeuristicLab.Common;
29using HeuristicLab.MainForm;
30using HeuristicLab.Optimization;
31using HeuristicLab.PluginInfrastructure;
32
33namespace HeuristicLab.Problems.DataAnalysis.Views {
34  [View("Error Characteristics Curve")]
35  [Content(typeof(IRegressionSolution))]
36  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
37    protected const string TrainingSamples = "Training";
38    protected const string TestSamples = "Test";
39    protected const string AllSamples = "All Samples";
40
41    public RegressionSolutionErrorCharacteristicsCurveView()
42      : base() {
43      InitializeComponent();
44
45      cmbSamples.Items.Add(TrainingSamples);
46      cmbSamples.Items.Add(TestSamples);
47      cmbSamples.Items.Add(AllSamples);
48
49      cmbSamples.SelectedIndex = 0;
50
51      residualComboBox.SelectedIndex = 0;
52
53      chart.CustomizeAllChartAreas();
54      chart.ChartAreas[0].AxisX.Title = residualComboBox.SelectedItem.ToString();
55      chart.ChartAreas[0].AxisX.Minimum = 0.0;
56      chart.ChartAreas[0].AxisX.Maximum = 0.0;
57      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
58      chart.ChartAreas[0].CursorX.Interval = 0.01;
59
60      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
61      chart.ChartAreas[0].AxisY.Minimum = 0.0;
62      chart.ChartAreas[0].AxisY.Maximum = 1.0;
63      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
64      chart.ChartAreas[0].CursorY.Interval = 0.01;
65    }
66
67    // the view holds one regression solution as content but also contains several other regression solutions for comparison
68    // the following invariants must hold
69    // (Solutions.IsEmpty && Content == null) ||
70    // (Solutions[0] == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.TargetVariable))
71
72    public new IRegressionSolution Content {
73      get { return (IRegressionSolution)base.Content; }
74      set { base.Content = value; }
75    }
76
77    private readonly List<IRegressionSolution> solutions = new List<IRegressionSolution>();
78    public IEnumerable<IRegressionSolution> Solutions {
79      get { return solutions.AsEnumerable(); }
80    }
81
82    public IRegressionProblemData ProblemData {
83      get {
84        if (Content == null) return null;
85        return Content.ProblemData;
86      }
87    }
88
89    protected override void RegisterContentEvents() {
90      base.RegisterContentEvents();
91      Content.ModelChanged += new EventHandler(Content_ModelChanged);
92      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
93    }
94    protected override void DeregisterContentEvents() {
95      base.DeregisterContentEvents();
96      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
97      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
98    }
99
100    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
101      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
102      else {
103        // recalculate baseline solutions (for symbolic regression models the features used in the model might have changed)
104        solutions.Clear();
105        solutions.Add(Content); // re-add the first solution
106        var baselineSolutions = CreateBaselineSolutions();
107        solutions.AddRange(baselineSolutions);
108        UpdateChart();
109      }
110    }
111    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
112      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
113      else {
114        // recalculate baseline solutions
115        solutions.Clear();
116        solutions.Add(Content); // re-add the first solution
117        var baselineSolutions = CreateBaselineSolutions();
118        solutions.AddRange(baselineSolutions);
119        UpdateChart();
120      }
121    }
122    protected override void OnContentChanged() {
123      base.OnContentChanged();
124      // the content object is always stored as the first element in solutions
125      solutions.Clear();
126      ReadOnly = Content == null;
127      if (Content != null) {
128        solutions.Add(Content);
129        var baselineSolutions = CreateBaselineSolutions();
130        solutions.AddRange(baselineSolutions);
131      }
132      UpdateChart();
133    }
134
135    protected virtual void UpdateChart() {
136      chart.Series.Clear();
137      chart.Annotations.Clear();
138      chart.ChartAreas[0].AxisX.Maximum = 0.0;
139      chart.ChartAreas[0].CursorX.Interval = 0.01;
140
141      if (Content == null) return;
142      if (cmbSamples.SelectedItem.ToString() == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
143      if (cmbSamples.SelectedItem.ToString() == TestSamples && !ProblemData.TestIndices.Any()) return;
144
145      foreach (var sol in Solutions) {
146        AddSeries(sol);
147      }
148
149      chart.ChartAreas[0].AxisX.Title = string.Format("{0} ({1})", residualComboBox.SelectedItem, Content.ProblemData.TargetVariable);
150    }
151
152    protected void AddSeries(IRegressionSolution solution) {
153      if (chart.Series.Any(s => s.Name == solution.Name)) return;
154
155      Series solutionSeries = new Series(solution.Name);
156      solutionSeries.Tag = solution;
157      solutionSeries.ChartType = SeriesChartType.FastLine;
158      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
159
160      var maxValue = residuals.Max();
161      if (maxValue >= chart.ChartAreas[0].AxisX.Maximum) {
162        double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
163        var maximum = scale * (1 + (int)(maxValue / scale));
164        chart.ChartAreas[0].AxisX.Maximum = maximum;
165        chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
166      }
167
168      UpdateSeries(residuals, solutionSeries);
169
170      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
171      solutionSeries.LegendToolTip = "Double-click to open model";
172      chart.Series.Add(solutionSeries);
173    }
174
175    protected void UpdateSeries(List<double> residuals, Series series) {
176      series.Points.Clear();
177      residuals.Sort();
178      if (!residuals.Any() || residuals.All(double.IsNaN)) return;
179
180      series.Points.AddXY(0, 0);
181      for (int i = 0; i < residuals.Count; i++) {
182        var point = new DataPoint();
183        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
184          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
185          point.YValues[0] = ((double)i) / residuals.Count;
186          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
187          series.Points.Add(point);
188          break;
189        }
190
191        point.XValue = residuals[i];
192        point.YValues[0] = ((double)i + 1) / residuals.Count;
193        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
194        series.Points.Add(point);
195      }
196
197      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
198        var point = new DataPoint();
199        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
200        point.YValues[0] = 1;
201        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
202        series.Points.Add(point);
203      }
204    }
205
206    protected IEnumerable<double> GetOriginalValues() {
207      IEnumerable<double> originalValues;
208      switch (cmbSamples.SelectedItem.ToString()) {
209        case TrainingSamples:
210          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
211          break;
212        case TestSamples:
213          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
214          break;
215        case AllSamples:
216          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
217          break;
218        default:
219          throw new NotSupportedException();
220      }
221      return originalValues;
222    }
223
224    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
225      IEnumerable<double> estimatedValues;
226      switch (cmbSamples.SelectedItem.ToString()) {
227        case TrainingSamples:
228          estimatedValues = solution.EstimatedTrainingValues;
229          break;
230        case TestSamples:
231          estimatedValues = solution.EstimatedTestValues;
232          break;
233        case AllSamples:
234          estimatedValues = solution.EstimatedValues;
235          break;
236        default:
237          throw new NotSupportedException();
238      }
239      return estimatedValues;
240    }
241
242    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
243      switch (residualComboBox.SelectedItem.ToString()) {
244        case "Absolute error":
245          return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y))
246            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
247        case "Squared error":
248          return originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y))
249            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
250        case "Relative error":
251          return originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? -1 : Math.Abs((x - y) / x))
252            .Where(r => r > 0 && !double.IsNaN(r) && !double.IsInfinity(r)) // remove entries where the original value is 0
253            .ToList();
254        default: throw new NotSupportedException();
255      }
256    }
257
258    private double CalculateAreaOverCurve(Series series) {
259      if (series.Points.Count < 1) return 0;
260
261      double auc = 0.0;
262      for (int i = 1; i < series.Points.Count; i++) {
263        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
264        double y1 = 1 - series.Points[i - 1].YValues[0];
265        double y2 = 1 - series.Points[i].YValues[0];
266
267        auc += (y1 + y2) * width / 2;
268      }
269
270      return auc;
271    }
272
273    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
274      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
275      else UpdateChart();
276    }
277
278    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
279      HitTestResult result = chart.HitTest(e.X, e.Y);
280      if (result.ChartElementType != ChartElementType.LegendItem) return;
281
282      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
283    }
284
285    protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
286      var constantSolution = CreateConstantSolution();
287      if (constantSolution != null) yield return CreateConstantSolution();
288
289      var linearRegressionSolution = CreateLinearSolution();
290      if (linearRegressionSolution != null) yield return CreateLinearSolution();
291    }
292
293    private IRegressionSolution CreateConstantSolution() {
294      if (!ProblemData.TrainingIndices.Any()) return null;
295
296      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
297      var model = new ConstantModel(averageTrainingTarget, ProblemData.TargetVariable);
298      var solution = model.CreateRegressionSolution(ProblemData);
299      solution.Name = "Baseline (constant)";
300      return solution;
301    }
302    private IRegressionSolution CreateLinearSolution() {
303      try {
304        var solution = LinearRegression.CreateSolution((IRegressionProblemData)ProblemData.Clone(), out _, out _);
305        solution.Name = "Baseline (linear)";
306        return solution;
307      } catch (NotSupportedException e) {
308        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
309      } catch (ArgumentException e) {
310        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
311      }
312      return null;
313    }
314
315    private void chart_MouseMove(object sender, MouseEventArgs e) {
316      HitTestResult result = chart.HitTest(e.X, e.Y);
317      if (result.ChartElementType == ChartElementType.LegendItem) {
318        Cursor = Cursors.Hand;
319      } else {
320        Cursor = Cursors.Default;
321      }
322    }
323
324    private void chart_DragDrop(object sender, DragEventArgs e) {
325      if (e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) {
326
327        var data = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
328        var dataAsRegressionSolution = data as IRegressionSolution;
329        var dataAsResult = data as IResult;
330
331        if (dataAsRegressionSolution != null) {
332          solutions.Add((IRegressionSolution)dataAsRegressionSolution.Clone());
333        } else if (dataAsResult != null && dataAsResult.Value is IRegressionSolution) {
334          solutions.Add((IRegressionSolution)dataAsResult.Value.Clone());
335        }
336
337        UpdateChart();
338      }
339    }
340
341    private void chart_DragEnter(object sender, DragEventArgs e) {
342      e.Effect = DragDropEffects.None;
343      if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;
344
345      var data = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
346      var dataAsRegressionSolution = data as IRegressionSolution;
347      var dataAsResult = data as IResult;
348
349      if (!ReadOnly &&
350        (dataAsRegressionSolution != null || (dataAsResult != null && dataAsResult.Value is IRegressionSolution))) {
351        e.Effect = DragDropEffects.Copy;
352      }
353    }
354
355    private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
356      UpdateChart();
357    }
358  }
359}
Note: See TracBrowser for help on using the repository browser.