Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 17990

Last change on this file since 17990 was 17976, checked in by mkommend, 4 years ago

#3125: Added error handling to ECC View and minor code improvements.

File size: 14.4 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
[17180]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[4417]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
[13003]27using HeuristicLab.Algorithms.DataAnalysis;
[12493]28using HeuristicLab.Common;
[4417]29using HeuristicLab.MainForm;
[13003]30using HeuristicLab.Optimization;
[17976]31using HeuristicLab.PluginInfrastructure;
[7701]32
[5829]33namespace HeuristicLab.Problems.DataAnalysis.Views {
[6642]34  [View("Error Characteristics Curve")]
35  [Content(typeof(IRegressionSolution))]
36  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
// Display names of the selectable data partitions. The constructor adds them to cmbSamples,
// and the selected entry is compared against them to choose training, test, or all samples.
37    protected const string TrainingSamples = "Training";
38    protected const string TestSamples = "Test";
39    protected const string AllSamples = "All Samples";
[4417]40
/// <summary>
/// Initializes the designer components, fills the sample selection and applies the fixed chart layout.
/// </summary>
public RegressionSolutionErrorCharacteristicsCurveView()
  : base() {
  InitializeComponent();

  // offer the three data partitions and preselect the first entry of both combo boxes
  foreach (var partition in new[] { TrainingSamples, TestSamples, AllSamples }) {
    cmbSamples.Items.Add(partition);
  }
  cmbSamples.SelectedIndex = 0;
  residualComboBox.SelectedIndex = 0;

  chart.CustomizeAllChartAreas();
  var area = chart.ChartAreas[0];

  // x axis: residual magnitude, titled after the selected residual type
  var errorAxis = area.AxisX;
  errorAxis.Title = residualComboBox.SelectedItem.ToString();
  errorAxis.Minimum = 0.0;
  errorAxis.Maximum = 0.0;
  errorAxis.IntervalAutoMode = IntervalAutoMode.VariableCount;
  area.CursorX.Interval = 0.01;

  // y axis: ratio of residuals, fixed to the range [0, 1]
  var ratioAxis = area.AxisY;
  ratioAxis.Title = "Ratio of Residuals";
  ratioAxis.Minimum = 0.0;
  ratioAxis.Maximum = 1.0;
  ratioAxis.MajorGrid.Interval = 0.2;
  area.CursorY.Interval = 0.01;
}
66
[13003]67    // the view holds one regression solution as content but also contains several other regression solutions for comparison
68    // the following invariants must hold
69    // (Solutions.IsEmpty && Content == null) ||
70    // (Solutions[0] == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.TargetVariable))
71
/// <summary>
/// The content of the view, exposed as <see cref="IRegressionSolution"/> (hides the base property).
/// </summary>
public new IRegressionSolution Content {
  get => (IRegressionSolution)base.Content;
  set => base.Content = value;
}
[13003]76
// backing list of Solutions; it is rebuilt whenever the content, its model or its problem data changes
private readonly List<IRegressionSolution> solutions = new List<IRegressionSolution>();

/// <summary>
/// The solutions that are plotted by the view.
/// </summary>
public IEnumerable<IRegressionSolution> Solutions => solutions.AsEnumerable();
81
/// <summary>
/// The problem data of the current content, or null if the view has no content.
/// </summary>
public IRegressionProblemData ProblemData => Content?.ProblemData;
[4417]88
/// <summary>
/// Subscribes to the model and problem data change events of the content.
/// </summary>
protected override void RegisterContentEvents() {
  base.RegisterContentEvents();
  Content.ModelChanged += Content_ModelChanged;
  Content.ProblemDataChanged += Content_ProblemDataChanged;
}
/// <summary>
/// Removes the subscriptions made by RegisterContentEvents.
/// </summary>
protected override void DeregisterContentEvents() {
  base.DeregisterContentEvents();
  Content.ModelChanged -= Content_ModelChanged;
  Content.ProblemDataChanged -= Content_ProblemDataChanged;
}
99
/// <summary>
/// Rebuilds the solution list (content plus baselines) and redraws the chart after the model changed.
/// Re-dispatches itself through Invoke when InvokeRequired is set.
/// </summary>
protected virtual void Content_ModelChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
    return;
  }

  // the baselines are recalculated as well
  // (for symbolic regression models the features used in the model might have changed)
  solutions.Clear();
  solutions.Add(Content); // the content is always the first solution
  solutions.AddRange(CreateBaselineSolutions());
  UpdateChart();
}
/// <summary>
/// Rebuilds the solution list (content plus baselines) and redraws the chart after the problem data changed.
/// Re-dispatches itself through Invoke when InvokeRequired is set.
/// </summary>
protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
    return;
  }

  solutions.Clear();
  solutions.Add(Content); // the content is always the first solution
  solutions.AddRange(CreateBaselineSolutions());
  UpdateChart();
}
/// <summary>
/// Resets the solution list for the new content and redraws the chart.
/// The view is read-only while it has no content.
/// </summary>
protected override void OnContentChanged() {
  base.OnContentChanged();

  solutions.Clear();
  bool hasContent = Content != null;
  ReadOnly = !hasContent;

  if (hasContent) {
    // the content object is always stored as the first element, followed by the baselines
    solutions.Add(Content);
    solutions.AddRange(CreateBaselineSolutions());
  }

  UpdateChart();
}
134
/// <summary>
/// Clears the chart and plots one series per solution for the selected samples and residual type.
/// Nothing is plotted without content or when the selected training/test partition is empty.
/// </summary>
protected virtual void UpdateChart() {
  chart.Series.Clear();
  chart.Annotations.Clear();

  // reset the x axis; AddSeries enlarges the maximum again as needed
  var area = chart.ChartAreas[0];
  area.AxisX.Maximum = 0.0;
  area.CursorX.Interval = 0.01;

  if (Content == null) return;

  var selectedSamples = cmbSamples.SelectedItem.ToString();
  bool emptyTraining = selectedSamples == TrainingSamples && !ProblemData.TrainingIndices.Any();
  if (emptyTraining) return;
  bool emptyTest = selectedSamples == TestSamples && !ProblemData.TestIndices.Any();
  if (emptyTest) return;

  foreach (var sol in Solutions) {
    AddSeries(sol);
  }

  area.AxisX.Title = $"{residualComboBox.SelectedItem} ({Content.ProblemData.TargetVariable})";
}
[4417]151
/// <summary>
/// Creates the error characteristics curve of <paramref name="solution"/> and adds it to the chart.
/// A solution whose name is already used by a series of the chart is ignored.
/// </summary>
/// <param name="solution">The regression solution to plot; it is stored in the Tag of the new series.</param>
protected void AddSeries(IRegressionSolution solution) {
  if (chart.Series.Any(s => s.Name == solution.Name)) return;

  var solutionSeries = new Series(solution.Name) {
    Tag = solution,
    ChartType = SeriesChartType.FastLine
  };
  var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));

  // The residual list can be empty after filtering (e.g. relative errors of an all-zero target);
  // Max() and Min() throw an InvalidOperationException on empty sequences.
  if (residuals.Count > 0) {
    var axisX = chart.ChartAreas[0].AxisX;
    var maxValue = residuals.Max();

    // Log10 yields -Infinity for 0 (and NaN for negative values), which would make the computed
    // axis maximum meaningless, so the axis is only enlarged for positive maxima.
    if (maxValue > 0 && maxValue >= axisX.Maximum) {
      // smallest multiple of the leading power of ten that is strictly greater than maxValue
      double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
      axisX.Maximum = scale * (1 + (int)(maxValue / scale));
      chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
    }
  }

  UpdateSeries(residuals, solutionSeries);

  solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
  solutionSeries.LegendToolTip = "Double-click to open model";
  chart.Series.Add(solutionSeries);
}
[5417]174
/// <summary>
/// Replaces the points of <paramref name="series"/> with the cumulative distribution of the residuals.
/// </summary>
/// <param name="residuals">The residuals to plot; the list is sorted in place.</param>
/// <param name="series">The series whose points are rebuilt.</param>
protected void UpdateSeries(List<double> residuals, Series series) {
  series.Points.Clear();
  residuals.Sort();
  if (residuals.Count == 0 || residuals.All(double.IsNaN)) return;

  double axisMaximum = chart.ChartAreas[0].AxisX.Maximum;
  int total = residuals.Count;

  // builds a point whose tooltip shows its own coordinates
  DataPoint CreatePoint(double error, double ratio) {
    var created = new DataPoint();
    created.XValue = error;
    created.YValues[0] = ratio;
    created.ToolTip = "Error: " + created.XValue + "\n" + "Samples: " + created.YValues[0];
    return created;
  }

  series.Points.AddXY(0, 0);
  for (int i = 0; i < total; i++) {
    double residual = residuals[i];
    if (residual > axisMaximum) {
      // cut the curve at the right border of the axis; only the i smaller residuals are counted
      series.Points.Add(CreatePoint(axisMaximum, ((double)i) / total));
      break;
    }
    series.Points.Add(CreatePoint(residual, ((double)i + 1) / total));
  }

  // extend the curve horizontally up to the right border of the axis
  var lastPoint = series.Points[series.Points.Count - 1];
  if (lastPoint.XValue < axisMaximum) {
    series.Points.Add(CreatePoint(axisMaximum, 1));
  }
}
[4417]205
/// <summary>
/// Returns the target values of the partition selected in cmbSamples.
/// </summary>
/// <exception cref="NotSupportedException">The selected entry is none of the known partitions.</exception>
protected IEnumerable<double> GetOriginalValues() {
  var selectedSamples = cmbSamples.SelectedItem.ToString();

  if (selectedSamples == TrainingSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
  }
  if (selectedSamples == TestSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
  }
  if (selectedSamples == AllSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
  }
  throw new NotSupportedException();
}
[4417]223
/// <summary>
/// Returns the estimated values of <paramref name="solution"/> for the partition selected in cmbSamples.
/// </summary>
/// <param name="solution">The solution whose estimates are requested.</param>
/// <exception cref="NotSupportedException">The selected entry is none of the known partitions.</exception>
protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
  var selectedSamples = cmbSamples.SelectedItem.ToString();

  if (selectedSamples == TrainingSamples) return solution.EstimatedTrainingValues;
  if (selectedSamples == TestSamples) return solution.EstimatedTestValues;
  if (selectedSamples == AllSamples) return solution.EstimatedValues;
  throw new NotSupportedException();
}
241
/// <summary>
/// Calculates the residuals of the selected type (absolute, squared or relative error) for the
/// pairwise combination of original and estimated values. NaN and infinite residuals are removed.
/// </summary>
/// <param name="originalValues">The target values.</param>
/// <param name="estimatedValues">The estimated values, in the same order as the target values.</param>
/// <returns>A new list containing the finite residuals.</returns>
/// <exception cref="NotSupportedException">The entry selected in residualComboBox is unknown.</exception>
protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
  Func<double, double, double> error;
  switch (residualComboBox.SelectedItem.ToString()) {
    case "Absolute error":
      error = (x, y) => Math.Abs(x - y);
      break;
    case "Squared error":
      error = (x, y) => (x - y) * (x - y);
      break;
    case "Relative error":
      // The relative error is undefined when the original value is (almost) 0. Such entries are
      // marked with NaN so that the common filter below removes them. Exact predictions
      // (relative error 0) are kept; a former "r > 0" filter dropped them together with the
      // zero-target entries.
      error = (x, y) => x.IsAlmost(0.0) ? double.NaN : Math.Abs((x - y) / x);
      break;
    default:
      throw new NotSupportedException();
  }

  return originalValues.Zip(estimatedValues, error)
    .Where(r => !double.IsNaN(r) && !double.IsInfinity(r))
    .ToList();
}
257
/// <summary>
/// Integrates (1 - y) over x along the points of <paramref name="series"/> using the trapezoidal rule.
/// </summary>
/// <returns>The area over the curve; 0 for a series without points.</returns>
private double CalculateAreaOverCurve(Series series) {
  var points = series.Points;
  if (points.Count < 1) return 0;

  double area = 0.0;
  var previous = points[0];
  for (int i = 1; i < points.Count; i++) {
    var current = points[i];
    double width = current.XValue - previous.XValue;
    double leftHeight = 1 - previous.YValues[0];
    double rightHeight = 1 - current.YValues[0];

    area += (leftHeight + rightHeight) * width / 2;
    previous = current;
  }

  return area;
}
272
/// <summary>
/// Redraws the chart for the newly selected samples; re-dispatches itself through Invoke when required.
/// </summary>
protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
    return;
  }
  UpdateChart();
}
[7043]277
// Opens the solution stored in the Tag of a series when its legend item is double-clicked.
private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
  var hit = chart.HitTest(e.X, e.Y);
  if (hit.ChartElementType == ChartElementType.LegendItem) {
    var mainForm = MainFormManager.MainForm;
    mainForm.ShowContent((IRegressionSolution)hit.Series.Tag);
  }
}
284
/// <summary>
/// Lazily creates the baseline solutions that are shown for comparison: first the constant
/// baseline, then the linear regression baseline. A baseline that could not be created (null)
/// is skipped.
/// </summary>
protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
  // Each baseline is built exactly once and the checked instance is the one that is returned.
  // Building it a second time for the yield would repeat the model fitting and could even
  // produce a different (possibly null) result than the one that was checked.
  var constantSolution = CreateConstantSolution();
  if (constantSolution != null) yield return constantSolution;

  var linearRegressionSolution = CreateLinearSolution();
  if (linearRegressionSolution != null) yield return linearRegressionSolution;
}
292
// Creates a solution that always predicts the average training target;
// returns null when the problem data has no training samples.
private IRegressionSolution CreateConstantSolution() {
  var problemData = ProblemData;
  if (!problemData.TrainingIndices.Any()) return null;

  var targetVariable = problemData.TargetVariable;
  var trainingTargets = problemData.Dataset.GetDoubleValues(targetVariable, problemData.TrainingIndices);
  var constantModel = new ConstantModel(trainingTargets.Average(), targetVariable);

  var baseline = constantModel.CreateRegressionSolution(problemData);
  baseline.Name = "Baseline (constant)";
  return baseline;
}
// Creates a linear regression solution on a clone of the problem data. If the creation fails with
// a NotSupportedException or an ArgumentException, an error dialog is shown and null is returned.
private IRegressionSolution CreateLinearSolution() {
  IRegressionSolution baseline = null;
  try {
    var clonedData = (IRegressionProblemData)ProblemData.Clone();
    var linearSolution = LinearRegression.CreateSolution(clonedData, out _, out _);
    linearSolution.Name = "Baseline (linear)";
    baseline = linearSolution;
  } catch (Exception e) when (e is NotSupportedException || e is ArgumentException) {
    ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
  }
  return baseline;
}
[7700]314
// Shows a hand cursor while the mouse is over a legend item and the default cursor otherwise.
private void chart_MouseMove(object sender, MouseEventArgs e) {
  var hit = chart.HitTest(e.X, e.Y);
  bool overLegendItem = hit.ChartElementType == ChartElementType.LegendItem;
  Cursor = overLegendItem ? Cursors.Hand : Cursors.Default;
}
[12493]323
// Adds a clone of the dropped regression solution (or of the regression solution wrapped in a
// dropped result) to the compared solutions and redraws the chart.
private void chart_DragDrop(object sender, DragEventArgs e) {
  if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

  var dropped = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);

  if (dropped is IRegressionSolution droppedSolution) {
    solutions.Add((IRegressionSolution)droppedSolution.Clone());
  } else if (dropped is IResult droppedResult && droppedResult.Value is IRegressionSolution) {
    solutions.Add((IRegressionSolution)droppedResult.Value.Clone());
  }

  // the chart is redrawn even if the dropped data was of neither kind
  UpdateChart();
}
340
// Allows a copy drop only when the view is not read-only and the dragged data is a regression
// solution or a result that holds one; every other drag is rejected.
private void chart_DragEnter(object sender, DragEventArgs e) {
  e.Effect = DragDropEffects.None;
  if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

  var dragged = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
  if (ReadOnly) return;

  if (dragged is IRegressionSolution
      || (dragged is IResult draggedResult && draggedResult.Value is IRegressionSolution)) {
    e.Effect = DragDropEffects.Copy;
  }
}
354
// Redraws the chart for the newly selected residual type.
private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) => UpdateChart();
[4417]358  }
359}
Note: See TracBrowser for help on using the repository browser.