Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ChangeDatasetOfRegressionModel/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 8032

Last change on this file since 8032 was 8032, checked in by sforsten, 12 years ago

#1758:

  • merged trunk r7330:8022 into branch
  • importButton shows a message, if clicked and it can't be used
File size: 10.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.MainForm;
28using HeuristicLab.MainForm.WindowsForms;
29
namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// View that plots an error characteristics curve for an <see cref="IRegressionSolution"/>:
  /// the x-axis is the absolute error, the y-axis is the fraction of samples whose
  /// absolute residual does not exceed that error. A constant model that predicts the
  /// average training target is plotted as the "Baseline" series for comparison.
  /// </summary>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    // Baseline solution created by the most recent UpdateChart call (null if none could be built).
    private IRegressionSolution constantModel;
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    /// <summary>
    /// Fills the partition selector and sets the initial axis ranges, titles and cursor intervals.
    /// </summary>
    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      var area = chart.ChartAreas[0];
      area.AxisX.Title = "Absolute Error";
      area.AxisX.Minimum = 0.0;
      area.AxisX.Maximum = 1.0;
      area.AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      area.CursorX.Interval = 0.01;

      area.AxisY.Title = "Number of Samples";
      area.AxisY.Minimum = 0.0;
      area.AxisY.Maximum = 1.0;
      area.AxisY.MajorGrid.Interval = 0.2;
      area.CursorY.Interval = 0.01;
    }

    /// <summary>The displayed regression solution (typed view of the base content).</summary>
    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>Problem data of the current content, or null when no content is set.</summary>
    public IRegressionProblemData ProblemData {
      get {
        if (Content == null) return null;
        return Content.ProblemData;
      }
    }

    /// <summary>Subscribes to the content's model and problem data change events.</summary>
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }

    /// <summary>Removes the subscriptions made in <see cref="RegisterContentEvents"/>.</summary>
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    /// <summary>Redraws the chart when the model changes, marshalling to the UI thread if required.</summary>
    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) {
        Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
        return;
      }
      UpdateChart();
    }

    /// <summary>Redraws the chart when the problem data changes, marshalling to the UI thread if required.</summary>
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) {
        Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
        return;
      }
      UpdateChart();
    }

    /// <summary>Redraws the chart after the content has been replaced.</summary>
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds all series for the currently selected partition: the "Baseline" series
    /// (when a baseline model can be built) followed by the series of the content.
    /// The chart is left empty when there is no content or the selected partition has no samples.
    /// </summary>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();
      if (Content == null) return;

      var originalValues = GetOriginalValues().ToList();
      // An empty partition yields no residuals; there is nothing to scale the axis by or to draw.
      if (originalValues.Count == 0) return;

      if (ProblemData.TrainingIndizes.Any()) {
        constantModel = CreateConstantModel();
        var baselineResiduals = GetResiduals(originalValues, GetEstimatedValues(constantModel));
        baselineResiduals.Sort();
        ScaleErrorAxis(baselineResiduals);

        Series baselineSeries = new Series("Baseline");
        baselineSeries.ChartType = SeriesChartType.FastLine;
        UpdateSeries(baselineResiduals, baselineSeries);
        baselineSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(baselineSeries);
        baselineSeries.Tag = constantModel;
        chart.Series.Add(baselineSeries);
      } else {
        // Without training samples the training average is undefined, so no baseline
        // can be built; scale the error axis by the solution's own residuals instead.
        constantModel = null;
        var solutionResiduals = GetResiduals(originalValues, GetEstimatedValues(Content));
        solutionResiduals.Sort();
        ScaleErrorAxis(solutionResiduals);
      }

      AddRegressionSolution(Content);
    }

    /// <summary>
    /// Sets the x-axis maximum and the x-cursor interval from ascending-sorted residuals.
    /// Falls back to a range of one unit above the axis minimum and an interval of 0.01
    /// when the residuals do not give usable (finite, increasing) values.
    /// </summary>
    /// <param name="sortedResiduals">Residuals sorted ascending; may be empty.</param>
    private void ScaleErrorAxis(List<double> sortedResiduals) {
      var area = chart.ChartAreas[0];
      double smallest = double.NaN;
      double largest = double.NaN;
      if (sortedResiduals.Count > 0) {
        // List<double>.Sort orders NaN before every other value, so NaN can only
        // appear in 'largest' when every residual is NaN.
        smallest = sortedResiduals[0];
        largest = sortedResiduals[sortedResiduals.Count - 1];
      }

      double maximum = Math.Ceiling(largest);
      // A non-finite maximum, or one that does not exceed the minimum (e.g. all residuals
      // are zero), is not a usable axis range.
      if (double.IsNaN(maximum) || double.IsInfinity(maximum) || maximum <= area.AxisX.Minimum)
        maximum = area.AxisX.Minimum + 1.0;
      area.AxisX.Maximum = maximum;

      double interval = smallest / 100;
      if (double.IsNaN(interval) || double.IsInfinity(interval))
        interval = 0.01;
      area.CursorX.Interval = interval;
    }

    /// <summary>
    /// Adds a curve for <paramref name="solution"/> to the chart, tagged with the solution.
    /// Does nothing if a series with the solution's name already exists.
    /// </summary>
    /// <param name="solution">Solution whose residuals on the selected partition are plotted.</param>
    protected void AddRegressionSolution(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      Series solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
      UpdateSeries(residuals, solutionSeries);
      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative curve of the residuals.
    /// The curve starts at (0, 0), gets one point per residual at (residual, (index + 1) / count),
    /// is cut off at the current x-axis maximum, and is extended to the axis maximum at y = 1
    /// when all residuals lie below it.
    /// </summary>
    /// <param name="residuals">Absolute residuals; the list is sorted in place.</param>
    /// <param name="series">Series to fill; left empty if there are no residuals or all are NaN.</param>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (residuals.Count == 0 || residuals.All(double.IsNaN)) return;

      double axisMaximum = chart.ChartAreas[0].AxisX.Maximum;
      int count = residuals.Count;

      series.Points.AddXY(0, 0);
      for (int i = 0; i < count; i++) {
        if (residuals[i] > axisMaximum) {
          // First residual beyond the visible range: close the curve at the axis maximum
          // with the fraction of samples seen so far and stop.
          series.Points.Add(CreateCurvePoint(axisMaximum, ((double)i) / count));
          break;
        }
        series.Points.Add(CreateCurvePoint(residuals[i], ((double)i + 1) / count));
      }

      if (series.Points.Last().XValue < axisMaximum)
        series.Points.Add(CreateCurvePoint(axisMaximum, 1));
    }

    /// <summary>Creates a curve point with the tooltip text used for all points of this view.</summary>
    private static DataPoint CreateCurvePoint(double error, double fraction) {
      var point = new DataPoint();
      point.XValue = error;
      point.YValues[0] = fraction;
      point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
      return point;
    }

    /// <summary>Returns the target variable values of the partition selected in the combo box.</summary>
    /// <exception cref="NotSupportedException">The selected item is not one of the known partitions.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
        case TestSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
        case AllSamples:
          return ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>Returns the estimated values of <paramref name="solution"/> for the selected partition.</summary>
    /// <exception cref="NotSupportedException">The selected item is not one of the known partitions.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          return solution.EstimatedTrainingValues;
        case TestSamples:
          return solution.EstimatedTestValues;
        case AllSamples:
          return solution.EstimatedValues;
        default:
          throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Returns the average training target repeated once per element of <paramref name="originalValues"/>.
    /// Not called from this file; kept for derived classes.
    /// </summary>
    protected IEnumerable<double> GetbaselineEstimatedValues(IEnumerable<double> originalValues) {
      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).Average();
      return Enumerable.Repeat(averageTrainingTarget, originalValues.Count());
    }

    /// <summary>Returns the pairwise absolute differences of original and estimated values.</summary>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
    }

    /// <summary>
    /// Integrates (1 - y) over x across the points of the series using the trapezoidal rule.
    /// Returns 0 for a series without points.
    /// </summary>
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 1) return 0;

      double area = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        DataPoint previous = series.Points[i - 1];
        DataPoint current = series.Points[i];
        double width = current.XValue - previous.XValue;
        double y1 = 1 - previous.YValues[0];
        double y2 = 1 - current.YValues[0];

        area += (y1 + y2) * width / 2;
      }

      return area;
    }

    /// <summary>Redraws the chart when another partition is selected.</summary>
    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) {
        Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
        return;
      }
      UpdateChart();
    }

    #region Baseline
    /// <summary>Opens the solution attached to a legend item when that item is double-clicked.</summary>
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;

      // Legend items without a series, or series without an attached solution, have nothing to show.
      var solution = result.Series == null ? null : result.Series.Tag as IRegressionSolution;
      if (solution == null) return;

      MainFormManager.MainForm.ShowContent(solution);
    }

    /// <summary>
    /// Creates the baseline solution: a constant model that predicts the average
    /// training target, named "Baseline". Requires a non-empty training partition.
    /// </summary>
    private IRegressionSolution CreateConstantModel() {
      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).Average();
      var solution = new ConstantRegressionModel(averageTrainingTarget).CreateRegressionSolution(ProblemData);
      solution.Name = "Baseline";
      return solution;
    }
    #endregion

    /// <summary>Shows a hand cursor while the mouse hovers over a legend item.</summary>
    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      Cursor = result.ChartElementType == ChartElementType.LegendItem ? Cursors.Hand : Cursors.Default;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.