source: branches/3087_Ceres_Integration/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 17991

Last change on this file since 17991 was 17991, checked in by gkronber, 6 months ago

#3128: first dump of exploratory work-in-progress code to make sure the working copy is not lost.

File size: 14.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.Algorithms.DataAnalysis;
28using HeuristicLab.Common;
29using HeuristicLab.MainForm;
30using HeuristicLab.Optimization;
31using HeuristicLab.PluginInfrastructure;
32
33namespace HeuristicLab.Problems.DataAnalysis.Views {
34  [View("Error Characteristics Curve")]
35  [Content(typeof(IRegressionSolution))]
36  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
37    protected const string TrainingSamples = "Training";
38    protected const string TestSamples = "Test";
39    protected const string AllSamples = "All Samples";
40
41    public RegressionSolutionErrorCharacteristicsCurveView()
42      : base() {
43      InitializeComponent();
44
45      cmbSamples.Items.Add(TrainingSamples);
46      cmbSamples.Items.Add(TestSamples);
47      cmbSamples.Items.Add(AllSamples);
48
49      cmbSamples.SelectedIndex = 0;
50
51      residualComboBox.SelectedIndex = 0;
52
53      chart.CustomizeAllChartAreas();
54      chart.ChartAreas[0].AxisX.Title = residualComboBox.SelectedItem.ToString();
55      chart.ChartAreas[0].AxisX.Minimum = 0.0;
56      chart.ChartAreas[0].AxisX.Maximum = 0.0;
57      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
58      chart.ChartAreas[0].CursorX.Interval = 0.01;
59
60      chart.ChartAreas[0].AxisY.Title = "Ratio of Residuals";
61      chart.ChartAreas[0].AxisY.Minimum = 0.0;
62      chart.ChartAreas[0].AxisY.Maximum = 1.0;
63      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
64      chart.ChartAreas[0].CursorY.Interval = 0.01;
65    }
66
    // The view holds one regression solution as its content but also keeps several other regression solutions for comparison.
    // The following invariant must hold:
    // (!Solutions.Any() && Content == null) ||
    // (Solutions.First() == Content && Solutions.All(s => s.ProblemData.TargetVariable == Content.ProblemData.TargetVariable))
71
72    public new IRegressionSolution Content {
73      get { return (IRegressionSolution)base.Content; }
74      set { base.Content = value; }
75    }
76
77    private readonly List<IRegressionSolution> solutions = new List<IRegressionSolution>();
78    public IEnumerable<IRegressionSolution> Solutions {
79      get { return solutions.AsEnumerable(); }
80    }
81
82    public IRegressionProblemData ProblemData {
83      get {
84        if (Content == null) return null;
85        return Content.ProblemData;
86      }
87    }
88
89    protected override void RegisterContentEvents() {
90      base.RegisterContentEvents();
91      Content.ModelChanged += new EventHandler(Content_ModelChanged);
92      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
93    }
94    protected override void DeregisterContentEvents() {
95      base.DeregisterContentEvents();
96      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
97      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
98    }
99
100    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
101      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
102      else {
103        // recalculate baseline solutions (for symbolic regression models the features used in the model might have changed)
104        solutions.Clear(); // remove all
105        solutions.Add(Content); // re-add the first solution
106        solutions.AddRange(CreateBaselineSolutions());
107        UpdateChart();
108      }
109    }
110    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
111      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
112      else {
113        // recalculate baseline solutions
114        solutions.Clear(); // remove all
115        solutions.Add(Content); // re-add the first solution
116        solutions.AddRange(CreateBaselineSolutions());
117        UpdateChart();
118      }
119    }
120    protected override void OnContentChanged() {
121      base.OnContentChanged();
122      // the content object is always stored as the first element in solutions
123      solutions.Clear();
124      ReadOnly = Content == null;
125      if (Content != null) {
126        // recalculate all solutions
127        solutions.Add(Content);
128        if (ProblemData.TrainingIndices.Any()) {
129          foreach (var sol in CreateBaselineSolutions())
130            solutions.Add(sol);
131          // more solutions can be added by drag&drop
132        }
133      }
134      UpdateChart();
135    }
136
137    protected virtual void UpdateChart() {
138      chart.Series.Clear();
139      chart.Annotations.Clear();
140      chart.ChartAreas[0].AxisX.Maximum = 0.0;
141      chart.ChartAreas[0].CursorX.Interval = 0.01;
142
143      if (Content == null) return;
144      if (cmbSamples.SelectedItem.ToString() == TrainingSamples && !ProblemData.TrainingIndices.Any()) return;
145      if (cmbSamples.SelectedItem.ToString() == TestSamples && !ProblemData.TestIndices.Any()) return;
146
147      foreach (var sol in Solutions) {
148        AddSeries(sol);
149      }
150
151      chart.ChartAreas[0].AxisX.Title = string.Format("{0} ({1})", residualComboBox.SelectedItem, Content.ProblemData.TargetVariable);
152    }
153
154    protected void AddSeries(IRegressionSolution solution) {
155      if (chart.Series.Any(s => s.Name == solution.Name)) return;
156
157      Series solutionSeries = new Series(solution.Name);
158      solutionSeries.Tag = solution;
159      solutionSeries.ChartType = SeriesChartType.FastLine;
160      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
161
162      var maxValue = residuals.Max();
163      if (maxValue >= chart.ChartAreas[0].AxisX.Maximum) {
164        double scale = Math.Pow(10, Math.Floor(Math.Log10(maxValue)));
165        var maximum = scale * (1 + (int)(maxValue / scale));
166        chart.ChartAreas[0].AxisX.Maximum = maximum;
167        chart.ChartAreas[0].CursorX.Interval = residuals.Min() / 100;
168      }
169
170      UpdateSeries(residuals, solutionSeries);
171
172      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
173      solutionSeries.LegendToolTip = "Double-click to open model";
174      chart.Series.Add(solutionSeries);
175    }
176
177    protected void UpdateSeries(List<double> residuals, Series series) {
178      series.Points.Clear();
179      residuals.Sort();
180      if (!residuals.Any() || residuals.All(double.IsNaN)) return;
181
182      series.Points.AddXY(0, 0);
183      for (int i = 0; i < residuals.Count; i++) {
184        var point = new DataPoint();
185        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
186          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
187          point.YValues[0] = ((double)i) / residuals.Count;
188          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
189          series.Points.Add(point);
190          break;
191        }
192
193        point.XValue = residuals[i];
194        point.YValues[0] = ((double)i + 1) / residuals.Count;
195        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
196        series.Points.Add(point);
197      }
198
199      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
200        var point = new DataPoint();
201        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
202        point.YValues[0] = 1;
203        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
204        series.Points.Add(point);
205      }
206    }
207
208    protected IEnumerable<double> GetOriginalValues() {
209      IEnumerable<double> originalValues;
210      switch (cmbSamples.SelectedItem.ToString()) {
211        case TrainingSamples:
212          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices);
213          break;
214        case TestSamples:
215          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndices);
216          break;
217        case AllSamples:
218          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
219          break;
220        default:
221          throw new NotSupportedException();
222      }
223      return originalValues;
224    }
225
226    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
227      IEnumerable<double> estimatedValues;
228      switch (cmbSamples.SelectedItem.ToString()) {
229        case TrainingSamples:
230          estimatedValues = solution.EstimatedTrainingValues;
231          break;
232        case TestSamples:
233          estimatedValues = solution.EstimatedTestValues;
234          break;
235        case AllSamples:
236          estimatedValues = solution.EstimatedValues;
237          break;
238        default:
239          throw new NotSupportedException();
240      }
241      return estimatedValues;
242    }
243
244    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
245      switch (residualComboBox.SelectedItem.ToString()) {
246        case "Absolute error":
247          return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y))
248            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
249        case "Squared error":
250          return originalValues.Zip(estimatedValues, (x, y) => (x - y) * (x - y))
251            .Where(r => !double.IsNaN(r) && !double.IsInfinity(r)).ToList();
252        case "Relative error":
253          return originalValues.Zip(estimatedValues, (x, y) => x.IsAlmost(0.0) ? -1 : Math.Abs((x - y) / x))
254            .Where(r => r > 0 && !double.IsNaN(r) && !double.IsInfinity(r)) // remove entries where the original value is 0
255            .ToList();
256        default: throw new NotSupportedException();
257      }
258    }
259
260    private double CalculateAreaOverCurve(Series series) {
261      if (series.Points.Count < 1) return 0;
262
263      double auc = 0.0;
264      for (int i = 1; i < series.Points.Count; i++) {
265        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
266        double y1 = 1 - series.Points[i - 1].YValues[0];
267        double y2 = 1 - series.Points[i].YValues[0];
268
269        auc += (y1 + y2) * width / 2;
270      }
271
272      return auc;
273    }
274
275    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
276      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
277      else UpdateChart();
278    }
279
280    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
281      HitTestResult result = chart.HitTest(e.X, e.Y);
282      if (result.ChartElementType != ChartElementType.LegendItem) return;
283
284      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
285    }
286
287    protected virtual IEnumerable<IRegressionSolution> CreateBaselineSolutions() {
288      var constantSolution = CreateConstantSolution();
289      if (constantSolution != null) yield return CreateConstantSolution();
290
291      var linearRegressionSolution = CreateLinearSolution();
292      if (linearRegressionSolution != null) yield return CreateLinearSolution();
293    }
294
295    private IRegressionSolution CreateConstantSolution() {
296      if (!ProblemData.TrainingIndices.Any()) return null;
297
298      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndices).Average();
299      var model = new ConstantModel(averageTrainingTarget, ProblemData.TargetVariable);
300      var solution = model.CreateRegressionSolution(ProblemData);
301      solution.Name = "Baseline (constant)";
302      return solution;
303    }
304    private IRegressionSolution CreateLinearSolution() {
305      try {
306        var solution = LinearRegression.CreateSolution((IRegressionProblemData)ProblemData.Clone(), out _, out _, out _);
307        solution.Name = "Baseline (linear)";
308        return solution;
309      } catch (NotSupportedException e) {
310        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
311      } catch (ArgumentException e) {
312        ErrorHandling.ShowErrorDialog("Could not create a linear regression solution.", e);
313      }
314      return null;
315    }
316
317    private void chart_MouseMove(object sender, MouseEventArgs e) {
318      HitTestResult result = chart.HitTest(e.X, e.Y);
319      if (result.ChartElementType == ChartElementType.LegendItem) {
320        Cursor = Cursors.Hand;
321      } else {
322        Cursor = Cursors.Default;
323      }
324    }
325
326    private void chart_DragDrop(object sender, DragEventArgs e) {
327      if (e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) {
328
329        var data = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
330        var dataAsRegressionSolution = data as IRegressionSolution;
331        var dataAsResult = data as IResult;
332
333        if (dataAsRegressionSolution != null) {
334          solutions.Add((IRegressionSolution)dataAsRegressionSolution.Clone());
335        } else if (dataAsResult != null && dataAsResult.Value is IRegressionSolution) {
336          solutions.Add((IRegressionSolution)dataAsResult.Value.Clone());
337        }
338
339        UpdateChart();
340      }
341    }
342
343    private void chart_DragEnter(object sender, DragEventArgs e) {
344      e.Effect = DragDropEffects.None;
345      if (!e.Data.GetDataPresent(HeuristicLab.Common.Constants.DragDropDataFormat)) return;
346
347      var data = e.Data.GetData(HeuristicLab.Common.Constants.DragDropDataFormat);
348      var dataAsRegressionSolution = data as IRegressionSolution;
349      var dataAsResult = data as IResult;
350
351      if (!ReadOnly &&
352        (dataAsRegressionSolution != null || (dataAsResult != null && dataAsResult.Value is IRegressionSolution))) {
353        e.Effect = DragDropEffects.Copy;
354      }
355    }
356
357    private void residualComboBox_SelectedIndexChanged(object sender, EventArgs e) {
358      UpdateChart();
359    }
360  }
361}
Note: See TracBrowser for help on using the repository browser.