Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 16189

Last change on this file since 16189 was 15810, checked in by gkronber, 7 years ago

#2383: made some changes while reviewing

File size: 14.1 KB
RevLine 
[4417]1#region License Information
2/* HeuristicLab
[15583]3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[4417]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
[13003]27using HeuristicLab.Algorithms.DataAnalysis;
[12493]28using HeuristicLab.Common;
[4417]29using HeuristicLab.MainForm;
[13003]30using HeuristicLab.Optimization;
[7701]31
[5829]32namespace HeuristicLab.Problems.DataAnalysis.Views {
[6642]33  [View("Error Characteristics Curve")]
34  [Content(typeof(IRegressionSolution))]
35  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
36    protected const string TrainingSamples = "Training";
37    protected const string TestSamples = "Test";
38    protected const string AllSamples = "All Samples";
[4417]39
/// <summary>
/// Creates the view, fills the sample-partition combo box, pre-selects the first
/// entry of both combo boxes and sets up the axes of the first chart area.
/// </summary>
public RegressionSolutionErrorCharacteristicsCurveView()
  : base() {
  InitializeComponent();

  // partitions the user can choose from, in display order
  foreach (var partition in new[] { TrainingSamples, TestSamples, AllSamples }) {
    cmbSamples.Items.Add(partition);
  }
  cmbSamples.SelectedIndex = 0;
  residualComboBox.SelectedIndex = 0;

  chart.CustomizeAllChartAreas();
  var area = chart.ChartAreas[0];

  // x axis: the residual measure; the maximum starts at 0 and is enlarged in AddSeries
  var xAxis = area.AxisX;
  xAxis.Title = residualComboBox.SelectedItem.ToString();
  xAxis.Minimum = 0.0;
  xAxis.Maximum = 0.0;
  xAxis.IntervalAutoMode = IntervalAutoMode.VariableCount;
  area.CursorX.Interval = 0.01;

  // y axis: fixed to the range [0, 1]
  var yAxis = area.AxisY;
  yAxis.Title = "Ratio of Residuals";
  yAxis.Minimum = 0.0;
  yAxis.Maximum = 1.0;
  yAxis.MajorGrid.Interval = 0.2;
  area.CursorY.Interval = 0.01;
}
65
[13003]66    // the view holds one regression solution as content but also contains several other regression solutions for comparison
67    // the following invariants must hold
68    // (Solutions.IsEmpty && Content == null) ||
69    // (Solutions[0] == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.ProblemData.TargetVariable))
70
/// <summary>
/// The displayed regression solution; a typed wrapper around the content of the base view.
/// </summary>
public new IRegressionSolution Content {
  get {
    return (IRegressionSolution)base.Content;
  }
  set {
    base.Content = value;
  }
}
[13003]75
// all solutions shown in the chart
private readonly IList<IRegressionSolution> solutions = new List<IRegressionSolution>();

/// <summary>
/// The charted solutions, exposed as a plain sequence.
/// </summary>
public IEnumerable<IRegressionSolution> Solutions {
  get {
    IEnumerable<IRegressionSolution> sequence = solutions;
    return sequence;
  }
}
80
/// <summary>
/// The problem data of the current content, or null when the view has no content.
/// </summary>
public IRegressionProblemData ProblemData {
  get {
    return Content == null ? null : Content.ProblemData;
  }
}
[4417]87
/// <summary>
/// Subscribes to the model-changed and problem-data-changed events of the content.
/// </summary>
protected override void RegisterContentEvents() {
  base.RegisterContentEvents();
  Content.ModelChanged += Content_ModelChanged;
  Content.ProblemDataChanged += Content_ProblemDataChanged;
}
/// <summary>
/// Removes the subscriptions made in RegisterContentEvents.
/// </summary>
protected override void DeregisterContentEvents() {
  base.DeregisterContentEvents();
  Content.ModelChanged -= Content_ModelChanged;
  Content.ProblemDataChanged -= Content_ProblemDataChanged;
}
98
/// <summary>
/// Reacts to a changed model of the content: rebuilds the list of charted solutions
/// (content first, then freshly created baseline solutions) and redraws the chart.
/// For symbolic regression models the features used in the model might have changed,
/// so the baselines are recalculated as well.
/// </summary>
/// <param name="sender">Event source; forwarded unchanged when the call is re-dispatched.</param>
/// <param name="e">Event arguments; forwarded unchanged when the call is re-dispatched.</param>
protected virtual void Content_ModelChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    // re-dispatch through the control's Invoke
    Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
    return;
  }

  solutions.Clear(); // remove all
  solutions.Add(Content); // re-add the first solution

  // The constant baseline averages the training targets, which throws for an empty
  // training partition; skip the baselines in that case, as OnContentChanged does.
  if (ProblemData.TrainingIndices.Any()) {
    foreach (var sol in CreateBaselineSolutions()) {
      solutions.Add(sol);
    }
  }
  UpdateChart();
}
/// <summary>
/// Reacts to changed problem data of the content: rebuilds the list of charted solutions
/// (content first, then freshly created baseline solutions) and redraws the chart.
/// </summary>
/// <param name="sender">Event source; forwarded unchanged when the call is re-dispatched.</param>
/// <param name="e">Event arguments; forwarded unchanged when the call is re-dispatched.</param>
protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    // re-dispatch through the control's Invoke
    Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
    return;
  }

  solutions.Clear(); // remove all
  solutions.Add(Content); // re-add the first solution

  // The constant baseline averages the training targets, which throws for an empty
  // training partition; skip the baselines in that case, as OnContentChanged does.
  if (ProblemData.TrainingIndices.Any()) {
    foreach (var sol in CreateBaselineSolutions()) {
      solutions.Add(sol);
    }
  }
  UpdateChart();
}
/// <summary>
/// Rebuilds the solution list for the new content and redraws the chart.
/// The content is always stored as the first element; baseline solutions follow when
/// the training partition is not empty. The view is read-only while it has no content.
/// </summary>
protected override void OnContentChanged() {
  base.OnContentChanged();

  solutions.Clear();
  bool hasContent = Content != null;
  ReadOnly = !hasContent;

  if (hasContent) {
    solutions.Add(Content);
    // baselines need training rows; further solutions can be added by drag&drop
    if (ProblemData.TrainingIndices.Any()) {
      foreach (var baseline in CreateBaselineSolutions()) {
        solutions.Add(baseline);
      }
    }
  }

  UpdateChart();
}
141
/// <summary>
/// Clears the chart and draws one series per solution for the selected partition.
/// Nothing is drawn without content or when the selected training/test partition is empty.
/// </summary>
protected virtual void UpdateChart() {
  chart.Series.Clear();
  chart.Annotations.Clear();

  // reset the x axis; AddSeries enlarges it again as series are added
  var area = chart.ChartAreas[0];
  area.AxisX.Maximum = 0.0;
  area.CursorX.Interval = 0.01;

  if (Content == null) return;

  string partition = cmbSamples.SelectedItem.ToString();
  bool emptyPartitionSelected =
    (partition == TrainingSamples && !ProblemData.TrainingIndices.Any()) ||
    (partition == TestSamples && !ProblemData.TestIndices.Any());
  if (emptyPartitionSelected) return;

  foreach (var charted in Solutions) {
    AddSeries(charted);
  }

  area.AxisX.Title = string.Format("{0} ({1})", residualComboBox.SelectedItem, Content.ProblemData.TargetVariable);
}
[4417]158
/// <summary>
/// Adds the error characteristics curve of <paramref name="solution"/> to the chart and
/// enlarges the x axis so that the largest residual of the solution fits.
/// A solution whose name is already used by a charted series is ignored.
/// </summary>
/// <param name="solution">The solution to chart; it is stored in the Tag of the new series.</param>
protected void AddSeries(IRegressionSolution solution) {
  if (chart.Series.Any(s => s.Name == solution.Name)) return;

  Series solutionSeries = new Series(solution.Name);
  solutionSeries.Tag = solution;
  solutionSeries.ChartType = SeriesChartType.FastLine;
  var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));

  // Max()/Min() throw on an empty list (e.g. every residual was filtered out), so the
  // axis is only rescaled when there is data; UpdateSeries copes with the empty list.
  if (residuals.Count > 0) {
    var maxValue = residuals.Max();
    var area = chart.ChartAreas[0];
    // maxValue == 0 is excluded: Log10(0) is -Infinity, which makes scale 0 and the
    // quotient below NaN; casting NaN to int has no defined result.
    if (maxValue > 0 && maxValue >= area.AxisX.Maximum) {
      // round up to the next multiple of the leading power of ten (e.g. 37 -> 40)
      double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
      area.AxisX.Maximum = scale * (1 + (int)(maxValue / scale));
      area.CursorX.Interval = residuals.Min() / 100;
    }
  }

  UpdateSeries(residuals, solutionSeries);

  solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
  solutionSeries.LegendToolTip = "Double-click to open model";
  chart.Series.Add(solutionSeries);
}
[5417]181
/// <summary>
/// Replaces the points of <paramref name="series"/> with the cumulative curve of the
/// given residuals: x is the residual, y the ratio of residuals up to that value.
/// The curve starts at (0, 0), is cut off at the current x axis maximum and is
/// extended to that maximum with y = 1 when all residuals are smaller.
/// </summary>
/// <param name="residuals">Residuals to plot; the list is sorted in place.</param>
/// <param name="series">Series whose points are cleared and rebuilt.</param>
protected void UpdateSeries(List<double> residuals, Series series) {
  series.Points.Clear();
  residuals.Sort();

  int total = residuals.Count;
  if (total == 0 || residuals.All(double.IsNaN)) return;

  var axisX = chart.ChartAreas[0].AxisX;
  series.Points.AddXY(0, 0);

  for (int i = 0; i < total; i++) {
    double residual = residuals[i];
    if (residual > axisX.Maximum) {
      // first residual beyond the axis: close the curve at the axis maximum with the
      // ratio of the residuals seen so far and stop
      series.Points.Add(CreateCurvePoint(axisX.Maximum, ((double)i) / total));
      break;
    }
    series.Points.Add(CreateCurvePoint(residual, ((double)i + 1) / total));
  }

  // all residuals are below the axis maximum: extend the curve to the right border
  if (series.Points.Last().XValue < axisX.Maximum) {
    series.Points.Add(CreateCurvePoint(axisX.Maximum, 1));
  }
}

// Builds one curve point with its tooltip text.
private static DataPoint CreateCurvePoint(double error, double ratio) {
  var point = new DataPoint();
  point.XValue = error;
  point.YValues[0] = ratio;
  point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
  return point;
}
[4417]212
/// <summary>
/// Returns the target variable values of the partition selected in the samples combo box.
/// </summary>
/// <exception cref="NotSupportedException">The selected entry is none of the known partitions.</exception>
protected IEnumerable<double> GetOriginalValues() {
  string partition = cmbSamples.SelectedItem.ToString();

  if (partition == TrainingSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
  }
  if (partition == TestSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
  }
  if (partition == AllSamples) {
    return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
  }
  throw new NotSupportedException();
}
[4417]230
/// <summary>
/// Returns the estimated values of <paramref name="solution"/> for the partition
/// selected in the samples combo box.
/// </summary>
/// <param name="solution">The solution whose estimates are read.</param>
/// <exception cref="NotSupportedException">The selected entry is none of the known partitions.</exception>
protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
  string partition = cmbSamples.SelectedItem.ToString();

  if (partition == TrainingSamples) return solution.EstimatedTrainingValues;
  if (partition == TestSamples) return solution.EstimatedTestValues;
  if (partition == AllSamples) return solution.EstimatedValues;

  throw new NotSupportedException();
}
248
/// <summary>
/// Calculates the residuals between original and estimated values using the measure
/// selected in the residual combo box. NaN and infinite residuals are removed.
/// </summary>
/// <param name="originalValues">Target values.</param>
/// <param name="estimatedValues">Estimates, paired with the target values by position.</param>
/// <returns>A new list with the remaining residuals.</returns>
/// <exception cref="NotSupportedException">The selected residual measure is unknown.</exception>
protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
  IEnumerable<double> residuals;
  switch (residualComboBox.SelectedItem.ToString()) {
    case "Absolute error":
      residuals = originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y));
      break;
    case "Squared error":
      residuals = originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y));
      break;
    case "Relative error":
      // The relative error is undefined where the original value is 0. Those entries are
      // marked with NaN so the filter below removes them. Entries with a residual of
      // exactly 0 (perfect predictions) are kept, as for the other two measures.
      residuals = originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? double.NaN : Math.Abs((x - y) / x));
      break;
    default:
      throw new NotSupportedException();
  }
  return residuals.Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
}
262
/// <summary>
/// Sums the area between the curve of <paramref name="series"/> and the line y = 1
/// using the trapezoidal rule over consecutive points.
/// </summary>
/// <param name="series">Series whose points are read in stored order.</param>
/// <returns>The area over the curve; 0 for a series without points.</returns>
private double CalculateAreaOverCurve(Series series) {
  var points = series.Points;
  if (points.Count < 1) return 0;

  double area = 0.0;
  var previous = points[0];
  for (int i = 1; i < points.Count; i++) {
    var current = points[i];
    double width = current.XValue - previous.XValue;
    double leftGap = 1 - previous.YValues[0];
    double rightGap = 1 - current.YValues[0];

    area += (leftGap + rightGap) * width / 2;
    previous = current;
  }

  return area;
}
277
/// <summary>
/// Redraws the chart when another sample partition is selected.
/// </summary>
protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
  if (InvokeRequired) {
    // re-dispatch through the control's Invoke
    Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
    return;
  }
  UpdateChart();
}
[7043]282
/// <summary>
/// Opens the solution stored in the Tag of a series when its legend item is double-clicked.
/// </summary>
private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
  HitTestResult hit = chart.HitTest(e.X, e.Y);
  if (hit.ChartElementType == ChartElementType.LegendItem) {
    MainFormManager.MainForm.ShowContent((IRegressionSolution)hit.Series.Tag);
  }
}
289
/// <summary>
/// Lazily yields the baseline solutions shown for comparison: first the constant
/// baseline, then the linear baseline. Each one is created only when it is requested.
/// </summary>
protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
  IRegressionSolution constantBaseline = CreateConstantSolution();
  yield return constantBaseline;

  IRegressionSolution linearBaseline = CreateLinearSolution();
  yield return linearBaseline;
}
294
/// <summary>
/// Creates the constant baseline: a model that always returns the average of the
/// target variable over the training rows.
/// </summary>
private IRegressionSolution CreateConstantSolution() {
  var trainingTargets = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
  double meanTarget = trainingTargets.Average();

  var constantModel = new ConstantModel(meanTarget, ProblemData.TargetVariable);
  var baseline = constantModel.CreateRegressionSolution(ProblemData);
  baseline.Name = "Baseline (constant)";
  return baseline;
}
/// <summary>
/// Creates the linear baseline by running linear regression on a clone of the problem data.
/// </summary>
private IRegressionSolution CreateLinearSolution() {
  // the reported errors are not needed here; only the solution is used
  double rmsError, cvRmsError;
  var problemDataClone = (IRegressionProblemData)ProblemData.Clone();
  var baseline = LinearRegression.CreateLinearRegressionSolution(problemDataClone, out rmsError, out cvRmsError);
  baseline.Name = "Baseline (linear)";
  return baseline;
}
[7700]308
/// <summary>
/// Shows the hand cursor while the mouse is over a legend item and the default cursor otherwise.
/// </summary>
private void chart_MouseMove(object sender, MouseEventArgs e) {
  bool overLegendItem = chart.HitTest(e.X, e.Y).ChartElementType == ChartElementType.LegendItem;
  Cursor = overLegendItem ? Cursors.Hand : Cursors.Default;
}
[12493]317
/// <summary>
/// Adds a clone of a dropped regression solution (or of the regression solution held
/// by a dropped result) to the charted solutions and redraws the chart.
/// </summary>
private void chart_DragDrop(object sender, DragEventArgs e) {
  if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

  var dropped = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
  var droppedSolution = dropped as IRegressionSolution;
  var droppedResult = dropped as IResult;

  // NOTE(review): the target variable of the dropped solution is not compared with the
  // one of the content here - confirm whether that is intended.
  if (droppedSolution != null) {
    solutions.Add((IRegressionSolution)droppedSolution.Clone());
  } else if (droppedResult != null && droppedResult.Value is IRegressionSolution) {
    solutions.Add((IRegressionSolution)droppedResult.Value.Clone());
  }

  UpdateChart();
}
334
/// <summary>
/// Offers the copy effect when the dragged data is a regression solution, or a result
/// holding one, and the view is not read-only; otherwise no effect is offered.
/// </summary>
private void chart_DragEnter(object sender, DragEventArgs e) {
  e.Effect = DragDropEffects.None;
  if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;

  var dragged = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
  var draggedSolution = dragged as IRegressionSolution;
  var draggedResult = dragged as IResult;

  if (ReadOnly) return;

  // neither a solution nor a result wrapping a solution: keep the "None" effect
  if (draggedSolution == null && (draggedResult == null || !(draggedResult.Value is IRegressionSolution))) return;

  e.Effect = DragDropEffects.Copy;
}
348
/// <summary>
/// Redraws the chart when another residual measure is selected.
/// </summary>
private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
  // the selected measure is read again while the series are rebuilt
  UpdateChart();
}
[4417]352  }
353}
Note: See TracBrowser for help on using the repository browser.