Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 7700

Last change on this file since 7700 was 7700, checked in by sforsten, 12 years ago

#1811:

  • added toggle of series
  • changed existing MouseDown event to MouseDoubleClick
  • renamed "Mean Model" to "Baseline". For consistency, some variable names have also been renamed
File size: 11.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Drawing;
25using System.Linq;
26using System.Windows.Forms;
27using System.Windows.Forms.DataVisualization.Charting;
28using HeuristicLab.MainForm;
29using HeuristicLab.MainForm.WindowsForms;
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Plots an error characteristics curve for a regression solution: for each absolute-error value on the
  /// x-axis, the fraction of samples whose absolute residual does not exceed it. A constant "Baseline" model
  /// (predicting the average training target) is always plotted alongside the solution for comparison.
  /// Clicking a legend entry toggles its series; double-clicking the Baseline legend entry opens that model.
  /// </summary>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    // Constant model predicting the average training target; recreated on every UpdateChart.
    private IRegressionSolution constantModel;
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    // Sorted absolute residuals per series name; used to restore a series' points after it was toggled off.
    protected Dictionary<string, List<double>> seriesResiduals = new Dictionary<string, List<double>>();

    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Title = "Absolute Error";
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      chart.ChartAreas[0].AxisX.Maximum = 1.0;
      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      chart.ChartAreas[0].CursorX.Interval = 0.01;

      chart.ChartAreas[0].AxisY.Title = "Number of Samples";
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].CursorY.Interval = 0.01;
    }

    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>The problem data of the current Content, or null when there is no Content.</summary>
    public IRegressionProblemData ProblemData {
      get {
        if (Content == null) return null;
        return Content.ProblemData;
      }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    // Both content events marshal to the UI thread before redrawing.
    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
      else UpdateChart();
    }
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
      else UpdateChart();
    }
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the chart from scratch: clears all series, cached residuals and annotations, then plots
    /// the Baseline model and the current Content for the partition selected in cmbSamples.
    /// Nothing is plotted when there is no Content, when the training partition is empty (the baseline
    /// average is undefined) or when the selected partition contains no samples.
    /// </summary>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      seriesResiduals.Clear();
      chart.Annotations.Clear();
      if (Content == null) return;
      // The baseline is the training-target average; Enumerable.Average throws on an empty sequence.
      if (!ProblemData.TrainingIndizes.Any()) return;

      var originalValues = GetOriginalValues().ToList();
      // E.g. "Test" selected while the test partition is empty: Last()/First() below would throw.
      if (originalValues.Count == 0) return;

      constantModel = CreateConstantModel();
      var baselineEstimatedValues = GetEstimatedValues(constantModel);
      var baselineResiduals = GetResiduals(originalValues, baselineEstimatedValues);
      if (baselineResiduals.Count == 0) return;

      // The x-axis range is derived from the baseline: it spans up to its largest residual (rounded up).
      baselineResiduals.Sort();
      chart.ChartAreas[0].AxisX.Maximum = Math.Ceiling(baselineResiduals.Last());
      // NOTE(review): First() is the smallest residual after sorting, so this interval can be 0 — confirm intent.
      chart.ChartAreas[0].CursorX.Interval = baselineResiduals.First() / 100;

      Series baselineSeries = new Series("Baseline");
      baselineSeries.ChartType = SeriesChartType.FastLine;
      seriesResiduals[baselineSeries.Name] = baselineResiduals;
      UpdateSeries(baselineResiduals, baselineSeries);
      baselineSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(baselineSeries);
      baselineSeries.Tag = constantModel;
      chart.Series.Add(baselineSeries);

      AddRegressionSolution(Content);
    }

    /// <summary>
    /// Adds a series for <paramref name="solution"/> (tagged with the solution) unless a series with the
    /// same name already exists. Residuals are computed on the partition selected in cmbSamples.
    /// </summary>
    protected void AddRegressionSolution(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      Series solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
      seriesResiduals[solutionSeries.Name] = residuals;
      UpdateSeries(residuals, solutionSeries);
      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative curve of <paramref name="residuals"/>.
    /// Sorts <paramref name="residuals"/> in place. The curve starts at (0, 0); each residual r at sorted index i
    /// yields (r, (i + 1) / n). Residuals beyond the x-axis maximum are clipped into one point at the maximum,
    /// and the curve is extended to (maximum, 1) otherwise. Leaves the series empty when there are no residuals
    /// or all of them are NaN.
    /// </summary>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (!residuals.Any() || residuals.All(double.IsNaN)) return;

      double maxError = chart.ChartAreas[0].AxisX.Maximum;
      series.Points.AddXY(0, 0);
      for (int i = 0; i < residuals.Count; i++) {
        if (residuals[i] > maxError) {
          // Only the i residuals before this one are <= maxError; close the curve at the axis edge.
          series.Points.Add(CreatePoint(maxError, ((double)i) / residuals.Count));
          break;
        }
        series.Points.Add(CreatePoint(residuals[i], ((double)i + 1) / residuals.Count));
      }

      // Extend the curve horizontally to the right edge of the chart.
      if (series.Points.Last().XValue < maxError) {
        series.Points.Add(CreatePoint(maxError, 1));
      }
    }

    // Creates a curve point whose tooltip shows its error and sample fraction.
    private static DataPoint CreatePoint(double error, double fraction) {
      var point = new DataPoint();
      point.XValue = error;
      point.YValues[0] = fraction;
      point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
      return point;
    }

    /// <summary>Target values of the partition currently selected in cmbSamples.</summary>
    /// <exception cref="NotSupportedException">The selected item is not a known partition name.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      IEnumerable<double> originalValues;
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
          break;
        case TestSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
          break;
        case AllSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
          break;
        default:
          throw new NotSupportedException();
      }
      return originalValues;
    }

    /// <summary>Estimated values of <paramref name="solution"/> for the partition currently selected in cmbSamples.</summary>
    /// <exception cref="NotSupportedException">The selected item is not a known partition name.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      IEnumerable<double> estimatedValues;
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          estimatedValues = solution.EstimatedTrainingValues;
          break;
        case TestSamples:
          estimatedValues = solution.EstimatedTestValues;
          break;
        case AllSamples:
          estimatedValues = solution.EstimatedValues;
          break;
        default:
          throw new NotSupportedException();
      }
      return estimatedValues;
    }

    /// <summary>
    /// Returns the average training target repeated once per element of <paramref name="originalValues"/>.
    /// </summary>
    protected IEnumerable<double> GetbaselineEstimatedValues(IEnumerable<double> originalValues) {
      return Enumerable.Repeat(GetAverageTrainingTarget(), originalValues.Count());
    }

    // Average of the target variable over the training partition (throws if the partition is empty).
    private double GetAverageTrainingTarget() {
      return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).Average();
    }

    /// <summary>Pairwise absolute differences; the result is as long as the shorter input.</summary>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
    }

    // Trapezoidal-rule integral of (1 - y) over the series' points, i.e. the area between the curve and y = 1.
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 1) return 0;

      double area = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = 1 - series.Points[i - 1].YValues[0];
        double y2 = 1 - series.Points[i].YValues[0];

        area += (y1 + y2) * width / 2;
      }

      return area;
    }

    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
      else UpdateChart();
    }

    #region events
    // Double-clicking the Baseline legend entry opens the baseline solution; other entries are ignored.
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;
      if (result.Series == null || constantModel == null) return;
      if (result.Series.Name != constantModel.Name) return;

      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
    }
    // Shows a hand cursor over legend items to hint that they are clickable.
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem)
        Cursor = Cursors.Hand;
      else
        Cursor = Cursors.Default;
    }
    // Clicking a legend item toggles the visibility of its series.
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType == ChartElementType.LegendItem && result.Series != null) {
        ToggleSeriesData(result.Series);
      }
    }
    // Greys out the legend text of series that are currently hidden (have no points).
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      foreach (LegendItem legend in e.LegendItems) {
        legend.Cells[1].ForeColor = this.chart.Series[legend.SeriesName].Points.Count == 0 ? Color.Gray : Color.Black;
      }
    }
    #endregion

    // Builds a solution for a constant model that always predicts the average training target.
    private IRegressionSolution CreateConstantModel() {
      var solution = new ConstantRegressionModel(GetAverageTrainingTarget()).CreateRegressionSolution(ProblemData);
      solution.Name = "Baseline";
      return solution;
    }

    /// <summary>
    /// Hides <paramref name="series"/> by removing its points, unless it is the only series still shown.
    /// A hidden series is restored from its cached residuals.
    /// </summary>
    private void ToggleSeriesData(Series series) {
      if (series.Points.Count > 0) {  // series is currently shown
        if (this.chart.Series.Any(s => s != series && s.Points.Count > 0)) {
          ClearPointsQuick(series.Points);
        }
      } else if (Content != null) {
        List<double> residuals;
        if (seriesResiduals.TryGetValue(series.Name, out residuals)) {
          UpdateSeries(residuals, series);
          chart.Legends[series.Legend].ForeColor = Color.Black;
          chart.Refresh();
        }
      }
    }

    // Removes points from the end with updates suspended; faster than DataPointCollection.Clear()
    // as per http://stackoverflow.com/questions/5744930/datapointcollection-clear-performance
    private static void ClearPointsQuick(DataPointCollection points) {
      points.SuspendUpdates();
      while (points.Count > 0)
        points.RemoveAt(points.Count - 1);
      points.ResumeUpdates();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.