Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 18190

Last change on this file since 18190 was 18020, checked in by gkronber, 3 years ago

#3125: made a few small code simplifications

File size: 14.2 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[4417]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
[13003]27using HeuristicLab.Algorithms.DataAnalysis;
[12493]28using HeuristicLab.Common;
[4417]29using HeuristicLab.MainForm;
[13003]30using HeuristicLab.Optimization;
[17976]31using HeuristicLab.PluginInfrastructure;
[7701]32
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Displays the error characteristics curve of a regression solution.
  /// The x-axis shows the selected residual measure ("Absolute error", "Squared error" or "Relative error").
  /// The y-axis shows the ratio of samples whose residual is at most that value, for the selected partition
  /// (training, test or all samples).
  /// The content solution is compared against baseline solutions (a constant model predicting the training mean
  /// and a linear regression model) and against further regression solutions that are dropped onto the chart.
  /// </summary>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      residualComboBox.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Title = residualComboBox.SelectedItem.ToString();
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      // the x-axis maximum starts at 0 and is grown in AddSeries to cover the largest residual
      chart.ChartAreas[0].AxisX.Maximum = 0.0;
      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      chart.ChartAreas[0].CursorX.Interval = 0.01;

      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].CursorY.Interval = 0.01;
    }

    // the view holds one regression solution as content but also contains several other regression solutions for comparison
    // the following invariants must hold
    // (Solutions.IsEmpty && Content == null) ||
    // (Solutions[0] == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.ProblemData.TargetVariable))

    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    private readonly List<IRegressionSolution> solutions = new List<IRegressionSolution>();

    /// <summary>
    /// All solutions shown in the chart: the content first, followed by the baseline solutions and any solutions
    /// dropped onto the chart.
    /// </summary>
    public IEnumerable<IRegressionSolution> Solutions {
      get { return solutions.AsEnumerable(); }
    }

    /// <summary>The problem data of the content solution, or null when there is no content.</summary>
    public IRegressionProblemData ProblemData {
      get {
        if (Content == null) return null;
        return Content.ProblemData;
      }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
      else {
        // recalculate baseline solutions (for symbolic regression models the features used in the model might have changed)
        RecreateSolutions();
        UpdateChart();
      }
    }
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
      else {
        // recalculate baseline solutions
        RecreateSolutions();
        UpdateChart();
      }
    }
    protected override void OnContentChanged() {
      base.OnContentChanged();
      RecreateSolutions();
      ReadOnly = Content == null;
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the solutions list so that the invariants documented at the top of the class hold.
    /// The list is empty when there is no content.
    /// Otherwise it holds the content followed by freshly created baseline solutions.
    /// Solutions that were previously dropped onto the chart are discarded.
    /// </summary>
    private void RecreateSolutions() {
      solutions.Clear();
      if (Content == null) return;
      solutions.Add(Content); // the content object is always stored as the first element in solutions
      solutions.AddRange(CreateBaselineSolutions());
    }

    /// <summary>
    /// Clears the chart and re-adds one series per solution for the selected partition and residual measure.
    /// The chart stays empty when there is no content or when the selected partition (training/test) has no rows.
    /// </summary>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();
      chart.ChartAreas[0].AxisX.Maximum = 0.0;
      chart.ChartAreas[0].CursorX.Interval = 0.01;

      if (Content == null) return;
      var selectedSamples = cmbSamples.SelectedItem.ToString();
      if (selectedSamples == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
      if (selectedSamples == TestSamples && !ProblemData.TestIndices.Any()) return;

      foreach (var sol in Solutions) {
        AddSeries(sol);
      }

      chart.ChartAreas[0].AxisX.Title = string.Format("{0} ({1})", residualComboBox.SelectedItem, Content.ProblemData.TargetVariable);
    }

    /// <summary>
    /// Adds a series showing the error characteristics curve of <paramref name="solution"/>.
    /// The x-axis maximum is grown if this solution's largest residual exceeds it.
    /// Does nothing if a series with the same name already exists.
    /// </summary>
    protected void AddSeries(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      Series solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));

      // Grow the x-axis so that it covers the largest residual. The residual is rounded up to the next larger
      // multiple of its leading power of ten (e.g. 0.37 -> 0.4, 250 -> 300).
      // This is skipped for an empty residual list, where Max()/Min() would throw. That happens e.g. when all
      // values were filtered as NaN/infinite, or all targets are zero for the relative error.
      // It is also skipped for a largest residual of 0, where Log10 yields -Infinity and the rounding yields NaN.
      if (residuals.Count > 0) {
        var maxValue = residuals.Max();
        if (maxValue > 0 && maxValue >= chart.ChartAreas[0].AxisX.Maximum) {
          double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
          var maximum = scale * (1 + (int)(maxValue / scale));
          chart.ChartAreas[0].AxisX.Maximum = maximum;
          // NOTE(review): the cursor interval is derived from the smallest residual, not from the axis range — confirm intent
          chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
        }
      }

      UpdateSeries(residuals, solutionSeries);

      solutionSeries.ToolTip = "Area over curve: " + CalculateAreaOverCurve(solutionSeries);
      solutionSeries.LegendToolTip = "Double-click to open model";
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative curve of <paramref name="residuals"/>.
    /// The list is sorted in place. The curve starts at (0, 0); for each sorted residual r_i it contains
    /// (r_i, (i + 1) / n). Residuals beyond the current x-axis maximum are cut off at the maximum.
    /// If the last point lies left of the axis maximum, the curve is extended to (maximum, 1).
    /// The series is left empty when there are no (non-NaN) residuals.
    /// </summary>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (!residuals.Any() || residuals.All(double.IsNaN)) return;

      series.Points.AddXY(0, 0);
      for (int i = 0; i < residuals.Count; i++) {
        var point = new DataPoint();
        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
          point.YValues[0] = ((double)i) / residuals.Count;
          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
          series.Points.Add(point);
          break;
        }

        point.XValue = residuals[i];
        point.YValues[0] = ((double)i + 1) / residuals.Count;
        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
        series.Points.Add(point);
      }

      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
        var point = new DataPoint();
        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
        point.YValues[0] = 1;
        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
        series.Points.Add(point);
      }
    }

    /// <summary>
    /// Returns the target values of the content's problem data for the partition selected in cmbSamples.
    /// </summary>
    /// <exception cref="NotSupportedException">If an unknown partition is selected.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
        case TestSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
        case AllSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Returns the estimated values of <paramref name="solution"/> for the partition selected in cmbSamples.
    /// </summary>
    /// <exception cref="NotSupportedException">If an unknown partition is selected.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return solution.EstimatedTrainingValues;
        case TestSamples:
          return solution.EstimatedTestValues;
        case AllSamples:
          return solution.EstimatedValues;
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Computes the residual measure selected in residualComboBox pairwise for original and estimated values.
    /// NaN and infinite residuals are dropped.
    /// For the relative error, samples whose original value is (almost) zero are dropped as well, since their
    /// relative error is undefined.
    /// </summary>
    /// <exception cref="NotSupportedException">If an unknown residual measure is selected.</exception>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      switch (residualComboBox.SelectedItem.ToString()) {
        case "Absolute error":
          return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y))
            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
        case "Squared error":
          return originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y))
            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
        case "Relative error":
          // -1 marks entries whose original value is 0. Perfect predictions (relative error exactly 0) must be
          // kept, so the filter is r >= 0 rather than r > 0.
          return originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? -1 : Math.Abs((x - y) / x))
            .Where(r => r >= 0 && !double.IsNaN(r) && !double.IsInfinity(r)) // remove entries where the original value is 0
            .ToList();
        default: throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Computes the area between the curve and the line y = 1 using the trapezoidal rule over consecutive points.
    /// Returns 0 for an empty series.
    /// </summary>
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 1) return 0;

      double aoc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = 1 - series.Points[i - 1].YValues[0];
        double y2 = 1 - series.Points[i].YValues[0];

        aoc += (y1 + y2) * width / 2;
      }

      return aoc;
    }

    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
      else UpdateChart();
    }

    // double-clicking a legend entry opens the solution stored in the series' Tag
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;

      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
    }

    /// <summary>
    /// Creates the baseline solutions shown next to the content.
    /// These are a constant model (omitted when there are no training rows) and a linear regression model
    /// (omitted when it cannot be created).
    /// </summary>
    protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
      var constantSolution = CreateConstantSolution();
      if (constantSolution != null) yield return constantSolution;

      var linearRegressionSolution = CreateLinearSolution();
      if (linearRegressionSolution != null) yield return linearRegressionSolution;
    }

    // Solution of a constant model that predicts the mean target value of the training partition;
    // null if the training partition is empty.
    private IRegressionSolution CreateConstantSolution() {
      if (!ProblemData.TrainingIndices.Any()) return null;

      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
      var model = new ConstantModel(averageTrainingTarget, ProblemData.TargetVariable);
      var solution = model.CreateRegressionSolution(ProblemData);
      solution.Name = "Baseline (constant)";
      return solution;
    }

    // Linear regression solution fitted on a clone of the problem data.
    // If LinearRegression.CreateSolution throws NotSupportedException or ArgumentException, an error dialog is
    // shown and null is returned.
    private IRegressionSolution CreateLinearSolution() {
      try {
        var solution = LinearRegression.CreateSolution((IRegressionProblemData)ProblemData.Clone(), out _, out _);
        solution.Name = "Baseline (linear)";
        return solution;
      } catch (NotSupportedException e) {
        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
      } catch (ArgumentException e) {
        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
      }
      return null;
    }

    // show a hand cursor over legend items to hint that they can be double-clicked
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem) {
        Cursor = Cursors.Hand;
      } else {
        Cursor = Cursors.Default;
      }
    }

    /// <summary>
    /// Extracts the regression solution carried by a drag-and-drop operation.
    /// The solution may be carried directly or wrapped in an <see cref="IResult"/>.
    /// Returns null when the dragged data contains no regression solution.
    /// </summary>
    private static IRegressionSolution GetDraggedSolution(DragEventArgs e) {
      if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return null;

      var data = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
      var solution = data as IRegressionSolution;
      if (solution != null) return solution;

      var result = data as IResult;
      return result != null ? result.Value as IRegressionSolution : null;
    }

    /// <summary>
    /// Returns true if <paramref name="solution"/> may be added for comparison. This requires that the view is
    /// editable, that there is content, and that the solution predicts the same target variable as the content.
    /// The last condition keeps the invariant documented at the top of the class; without it, residuals would be
    /// computed against the wrong variable.
    /// </summary>
    private bool CanAddSolution(IRegressionSolution solution) {
      return solution != null
        && !ReadOnly
        && Content != null
        && solution.ProblemData != null
        && solution.ProblemData.TargetVariable == Content.ProblemData.TargetVariable;
    }

    private void chart_DragDrop(object sender, DragEventArgs e) {
      var solution = GetDraggedSolution(e);
      if (!CanAddSolution(solution)) return;

      // a clone of the dropped solution is stored, not the dragged object itself
      solutions.Add((IRegressionSolution)solution.Clone());
      UpdateChart();
    }

    private void chart_DragEnter(object sender, DragEventArgs e) {
      e.Effect = CanAddSolution(GetDraggedSolution(e)) ? DragDropEffects.Copy : DragDropEffects.None;
    }

    private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
      UpdateChart();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.