Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 18060

Last change on this file since 18060 was 18020, checked in by gkronber, 3 years ago

#3125: made a few small code simplifications

File size: 14.2 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Windows.Forms;
using System.Windows.Forms.DataVisualization.Charting;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.MainForm;
using HeuristicLab.Optimization;
using HeuristicLab.PluginInfrastructure;

namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View that plots, for a regression solution and a set of comparison solutions, the ratio of
  /// samples (y axis) whose residual does not exceed a given error (x axis).
  /// </summary>
  /// <remarks>
  /// The residual measure (absolute, squared or relative error) and the data partition
  /// (training, test or all samples) are chosen through two combo boxes.
  /// </remarks>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    /// <summary>
    /// Creates the view, fills the partition combo box and configures both chart axes.
    /// </summary>
    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      residualComboBox.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      var area = chart.ChartAreas[0];

      // x axis: residual magnitude; the maximum is grown later while series are added
      area.AxisX.Title = residualComboBox.SelectedItem.ToString();
      area.AxisX.Minimum = 0.0;
      area.AxisX.Maximum = 0.0;
      area.AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      area.CursorX.Interval = 0.01;

      // y axis: ratio of residuals in [0, 1]
      area.AxisY.Title = "Ratio of Residuals";
      area.AxisY.Minimum = 0.0;
      area.AxisY.Maximum = 1.0;
      area.AxisY.MajorGrid.Interval = 0.2;
      area.CursorY.Interval = 0.01;
    }

    // The view holds one regression solution as content but also contains several other
    // regression solutions for comparison. The following invariant must hold:
    // (Solutions.IsEmpty && Content == null) ||
    // (Solutions[0] == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.ProblemData.TargetVariable))

    /// <summary>The regression solution shown by this view.</summary>
    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    private readonly List<IRegressionSolution> solutions = new List<IRegressionSolution>();

    /// <summary>All plotted solutions; the content is stored as the first element.</summary>
    public IEnumerable<IRegressionSolution> Solutions {
      get { return solutions.AsEnumerable(); }
    }

    /// <summary>Problem data of the content, or <c>null</c> when the view has no content.</summary>
    public IRegressionProblemData ProblemData {
      get { return Content == null ? null : Content.ProblemData; }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += Content_ModelChanged;
      Content.ProblemDataChanged += Content_ProblemDataChanged;
    }

    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= Content_ModelChanged;
      Content.ProblemDataChanged -= Content_ProblemDataChanged;
    }

    /// <summary>
    /// Rebuilds the solution list and the chart after the content's model changed.
    /// Re-dispatches itself through <c>Invoke</c> when <c>InvokeRequired</c> is set.
    /// </summary>
    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) {
        Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
        return;
      }
      // recalculate baseline solutions (for symbolic regression models the features used in the model might have changed)
      RebuildSolutions();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the solution list and the chart after the content's problem data changed.
    /// Re-dispatches itself through <c>Invoke</c> when <c>InvokeRequired</c> is set.
    /// </summary>
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) {
        Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
        return;
      }
      // recalculate baseline solutions
      RebuildSolutions();
      UpdateChart();
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      // the list is emptied before ReadOnly is touched, as in the original statement order
      solutions.Clear();
      ReadOnly = Content == null;
      if (Content != null) RebuildSolutions();
      UpdateChart();
    }

    /// <summary>
    /// Resets the solution list to the content (always the first element) followed by freshly
    /// created baseline solutions. Leaves the list empty when there is no content.
    /// </summary>
    private void RebuildSolutions() {
      solutions.Clear();
      if (Content == null) return;
      solutions.Add(Content);
      solutions.AddRange(CreateBaselineSolutions());
    }

    /// <summary>
    /// Clears the chart and re-adds one series per solution for the selected partition and
    /// residual measure. Nothing is plotted without content or when the selected training or
    /// test partition is empty.
    /// </summary>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();
      var area = chart.ChartAreas[0];
      area.AxisX.Maximum = 0.0;
      area.CursorX.Interval = 0.01;

      if (Content == null) return;
      var selectedSamples = cmbSamples.SelectedItem.ToString();
      if (selectedSamples == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
      if (selectedSamples == TestSamples && !ProblemData.TestIndices.Any()) return;

      foreach (var sol in Solutions) {
        AddSeries(sol);
      }
      // AddSeries may have raised the axis maximum after earlier series were already terminated
      ExtendSeriesToAxisMaximum();

      area.AxisX.Title = string.Format("{0} ({1})", residualComboBox.SelectedItem, Content.ProblemData.TargetVariable);
    }

    /// <summary>
    /// Continues every non-empty series horizontally up to the current x-axis maximum, so that
    /// curves added before the axis was enlarged do not stop short of the axis end.
    /// </summary>
    private void ExtendSeriesToAxisMaximum() {
      double axisMaximum = chart.ChartAreas[0].AxisX.Maximum;
      foreach (var series in chart.Series) {
        if (series.Points.Count == 0) continue;
        var lastPoint = series.Points[series.Points.Count - 1];
        if (lastPoint.XValue < axisMaximum) {
          series.Points.Add(CreatePoint(axisMaximum, lastPoint.YValues[0]));
        }
      }
    }

    /// <summary>
    /// Adds a series for <paramref name="solution"/> unless a series with the same name already
    /// exists, enlarging the x axis when the solution's largest residual does not fit.
    /// </summary>
    /// <param name="solution">Solution to plot; it is stored in the series' <c>Tag</c>.</param>
    protected void AddSeries(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      var solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));

      // Max()/Min() throw on an empty sequence (all residuals may have been filtered out)
      if (residuals.Count > 0) {
        var area = chart.ChartAreas[0];
        var maxValue = residuals.Max();
        // maxValue > 0 keeps Log10 finite; a zero maximum would yield scale == 0 and 0/0 == NaN
        if (maxValue > 0 && maxValue >= area.AxisX.Maximum) {
          // round up to the next multiple of the leading decimal magnitude, e.g. 37 -> 40
          double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
          area.AxisX.Maximum = scale * (1 + (int)(maxValue / scale));
          area.CursorX.Interval = residuals.Min() / 100;
        }
      }

      UpdateSeries(residuals, solutionSeries);

      solutionSeries.ToolTip = "Area over curve: " + CalculateAreaOverCurve(solutionSeries);
      solutionSeries.LegendToolTip = "Double-click to open model";
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative curve of
    /// <paramref name="residuals"/>, clipped at the current x-axis maximum.
    /// </summary>
    /// <param name="residuals">Residual values; the list is sorted in place.</param>
    /// <param name="series">Series whose points are cleared and rebuilt.</param>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (!residuals.Any() || residuals.All(double.IsNaN)) return;

      double axisMaximum = chart.ChartAreas[0].AxisX.Maximum;
      int count = residuals.Count;

      series.Points.AddXY(0, 0);
      for (int i = 0; i < count; i++) {
        if (residuals[i] > axisMaximum) {
          // first residual beyond the axis: close the curve at the axis end with the ratio
          // of samples seen so far (i of count) and stop
          series.Points.Add(CreatePoint(axisMaximum, ((double)i) / count));
          break;
        }
        series.Points.Add(CreatePoint(residuals[i], ((double)i + 1) / count));
      }

      if (series.Points.Last().XValue < axisMaximum) {
        series.Points.Add(CreatePoint(axisMaximum, 1));
      }
    }

    /// <summary>Creates a data point at (<paramref name="x"/>, <paramref name="y"/>) with its tooltip.</summary>
    private static DataPoint CreatePoint(double x, double y) {
      var point = new DataPoint();
      point.XValue = x;
      point.YValues[0] = y;
      point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
      return point;
    }

    /// <summary>Returns the target variable values of the selected partition.</summary>
    /// <exception cref="NotSupportedException">The selected partition is not one of the three known names.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
        case TestSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
        case AllSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>Returns the estimated values of <paramref name="solution"/> for the selected partition.</summary>
    /// <exception cref="NotSupportedException">The selected partition is not one of the three known names.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return solution.EstimatedTrainingValues;
        case TestSamples:
          return solution.EstimatedTestValues;
        case AllSamples:
          return solution.EstimatedValues;
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Computes the residuals selected in the residual combo box from pairs of original and
    /// estimated values. NaN and infinite residuals are removed.
    /// </summary>
    /// <remarks>
    /// For "Relative error", pairs whose original value is almost zero are removed because the
    /// error is undefined there; exact predictions (relative error 0) are kept.
    /// </remarks>
    /// <exception cref="NotSupportedException">The selected residual measure is unknown.</exception>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      switch (residualComboBox.SelectedItem.ToString()) {
        case "Absolute error":
          return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y))
            .Where(IsValidResidual).ToList();
        case "Squared error":
          return originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y))
            .Where(IsValidResidual).ToList();
        case "Relative error":
          // NaN marks entries where the original value is 0; the filter below removes them
          return originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? double.NaN : Math.Abs((x - y) / x))
            .Where(IsValidResidual).ToList();
        default: throw new NotSupportedException();
      }
    }

    /// <summary>True when <paramref name="r"/> is neither NaN nor infinite.</summary>
    private static bool IsValidResidual(double r) {
      return !double.IsNaN(r) && !double.IsInfinity(r);
    }

    /// <summary>
    /// Integrates <c>1 - y</c> over the points of <paramref name="series"/> using the
    /// trapezoidal rule. Returns 0 for a series without points.
    /// </summary>
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 1) return 0;

      double area = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        var left = series.Points[i - 1];
        var right = series.Points[i];
        double width = right.XValue - left.XValue;
        double leftHeight = 1 - left.YValues[0];
        double rightHeight = 1 - right.YValues[0];

        area += (leftHeight + rightHeight) * width / 2;
      }

      return area;
    }

    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
      else UpdateChart();
    }

    /// <summary>Opens the solution behind a double-clicked legend item.</summary>
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;

      // guard against legend items that carry no series or a tag that is not a solution
      if (result.Series != null && result.Series.Tag is IRegressionSolution solution) {
        MainFormManager.MainForm.ShowContent(solution);
      }
    }

    /// <summary>
    /// Yields the baseline solutions plotted next to the content: a constant model and a
    /// linear regression model, each only if it could be created.
    /// </summary>
    protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
      var constantSolution = CreateConstantSolution();
      if (constantSolution != null) yield return constantSolution;

      var linearRegressionSolution = CreateLinearSolution();
      if (linearRegressionSolution != null) yield return linearRegressionSolution;
    }

    /// <summary>
    /// Creates a solution that always predicts the average training target, or returns
    /// <c>null</c> when there are no training samples.
    /// </summary>
    private IRegressionSolution CreateConstantSolution() {
      if (!ProblemData.TrainingIndices.Any()) return null;

      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
      var model = new ConstantModel(averageTrainingTarget, ProblemData.TargetVariable);
      var solution = model.CreateRegressionSolution(ProblemData);
      solution.Name = "Baseline (constant)";
      return solution;
    }

    /// <summary>
    /// Creates a linear regression solution on a clone of the problem data. Shows an error
    /// dialog and returns <c>null</c> when creation fails with a
    /// <see cref="NotSupportedException"/> or an <see cref="ArgumentException"/>.
    /// </summary>
    private IRegressionSolution CreateLinearSolution() {
      try {
        var solution = LinearRegression.CreateSolution((IRegressionProblemData)ProblemData.Clone(), out _, out _);
        solution.Name = "Baseline (linear)";
        return solution;
      } catch (NotSupportedException e) {
        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
      } catch (ArgumentException e) {
        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
      }
      return null;
    }

    /// <summary>Shows a hand cursor while the mouse is over a legend item.</summary>
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      Cursor = result.ChartElementType == ChartElementType.LegendItem ? Cursors.Hand : Cursors.Default;
    }

    /// <summary>
    /// Adds a clone of a dropped regression solution (or of the solution wrapped in a dropped
    /// result) to the comparison list, provided it has the same target variable as the content.
    /// </summary>
    private void chart_DragDrop(object sender, DragEventArgs e) {
      if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

      var droppedSolution = AsRegressionSolution(e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat));
      if (HasSameTargetAsContent(droppedSolution)) {
        solutions.Add((IRegressionSolution)droppedSolution.Clone());
      }

      UpdateChart();
    }

    /// <summary>
    /// Signals a copy effect when the view is not read-only and the dragged data is a regression
    /// solution (directly or wrapped in a result) for the content's target variable.
    /// </summary>
    private void chart_DragEnter(object sender, DragEventArgs e) {
      e.Effect = DragDropEffects.None;
      if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

      var draggedSolution = AsRegressionSolution(e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat));
      if (!ReadOnly && HasSameTargetAsContent(draggedSolution)) {
        e.Effect = DragDropEffects.Copy;
      }
    }

    /// <summary>
    /// Returns <paramref name="data"/> as a regression solution when it is one or when it is a
    /// result whose value is one; otherwise <c>null</c>.
    /// </summary>
    private static IRegressionSolution AsRegressionSolution(object data) {
      if (data is IRegressionSolution solution) return solution;
      if (data is IResult result) return result.Value as IRegressionSolution;
      return null;
    }

    /// <summary>
    /// Checks the invariant of the solution list: a comparison solution must predict the same
    /// target variable as the content.
    /// </summary>
    private bool HasSameTargetAsContent(IRegressionSolution solution) {
      return solution != null
        && solution.ProblemData != null
        && ProblemData != null
        && solution.ProblemData.TargetVariable == ProblemData.TargetVariable;
    }

    private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      UpdateChart();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.