Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Hive.Azure/HeuristicLab.Problems.DataAnalysis.Views/3.4/Regression/RegressionSolutionErrorCharacteristicsCurveView.cs @ 7669

Last change on this file since 7669 was 7270, checked in by spimming, 13 years ago

#1680:

  • merged changes from trunk into branch
File size: 9.8 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Windows.Forms;
using System.Windows.Forms.DataVisualization.Charting;
using HeuristicLab.MainForm;
using HeuristicLab.MainForm.WindowsForms;

namespace HeuristicLab.Problems.DataAnalysis.Views {
  /// <summary>
  /// Error characteristics curve of an <see cref="IRegressionSolution"/>: for every error value on the x axis,
  /// the y axis shows the fraction of the selected samples whose residual does not exceed that value.
  /// A "Mean Model" that always predicts the average training target is plotted as a baseline.
  /// </summary>
  [View("Error Characteristics Curve")]
  [Content(typeof(IRegressionSolution))]
  public partial class RegressionSolutionErrorCharacteristicsCurveView : DataAnalysisSolutionEvaluationView {
    // Baseline solution created by the most recent UpdateChart call. It is null when no chart is shown,
    // and also when the training partition is empty (the mean model cannot be built then).
    private IRegressionSolution constantModel;
    protected const string TrainingSamples = "Training";
    protected const string TestSamples = "Test";
    protected const string AllSamples = "All Samples";

    /// <summary>
    /// Creates the view, fills the partition selector and configures the chart axes.
    /// </summary>
    public RegressionSolutionErrorCharacteristicsCurveView()
      : base() {
      InitializeComponent();

      cmbSamples.Items.Add(TrainingSamples);
      cmbSamples.Items.Add(TestSamples);
      cmbSamples.Items.Add(AllSamples);

      cmbSamples.SelectedIndex = 0;

      chart.CustomizeAllChartAreas();
      chart.ChartAreas[0].AxisX.Title = "Absolute Error";
      chart.ChartAreas[0].AxisX.Minimum = 0.0;
      chart.ChartAreas[0].AxisX.Maximum = 1.0;
      chart.ChartAreas[0].AxisX.IntervalAutoMode = IntervalAutoMode.VariableCount;
      chart.ChartAreas[0].CursorX.Interval = 0.01;

      // The plotted y values are fractions of the selected samples (see UpdateSeries), hence the fixed [0, 1] range.
      chart.ChartAreas[0].AxisY.Title = "Number of Samples";
      chart.ChartAreas[0].AxisY.Minimum = 0.0;
      chart.ChartAreas[0].AxisY.Maximum = 1.0;
      chart.ChartAreas[0].AxisY.MajorGrid.Interval = 0.2;
      chart.ChartAreas[0].CursorY.Interval = 0.01;
    }

    /// <summary>The regression solution displayed by this view.</summary>
    public new IRegressionSolution Content {
      get { return (IRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    /// <summary>The problem data of <see cref="Content"/>, or null when no content is set.</summary>
    public IRegressionProblemData ProblemData {
      get {
        if (Content == null) return null;
        return Content.ProblemData;
      }
    }

    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    /// <summary>Redraws the chart on the UI thread when the solution's model changes.</summary>
    protected virtual void Content_ModelChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ModelChanged, sender, e);
      else UpdateChart();
    }
    /// <summary>Redraws the chart on the UI thread when the solution's problem data changes.</summary>
    protected virtual void Content_ProblemDataChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)Content_ProblemDataChanged, sender, e);
      else {
        UpdateChart();
      }
    }
    protected override void OnContentChanged() {
      base.OnContentChanged();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the chart for the partition selected in the combo box. It draws the mean-model baseline
    /// (when it can be built), followed by the curve of <see cref="Content"/>.
    /// </summary>
    /// <remarks>
    /// Leaves the chart empty when there is no content or when the selected partition contains no rows.
    /// The mean model needs at least one training row. Without one, only the solution's curve is drawn,
    /// and the error axis is scaled from that curve.
    /// </remarks>
    protected virtual void UpdateChart() {
      chart.Series.Clear();
      chart.Annotations.Clear();
      constantModel = null;
      if (Content == null) return;

      var originalValues = GetOriginalValues().ToList();
      // An empty partition (e.g. no test rows) has nothing to plot. Without this check, the axis scaling
      // below would call Last()/First() on an empty list and throw.
      if (originalValues.Count == 0) return;

      // CreateConstantModel averages the training targets, and Average() throws on an empty sequence.
      if (ProblemData.TrainingIndizes.Any()) {
        constantModel = CreateConstantModel();
        var meanModelEstimatedValues = GetEstimatedValues(constantModel);
        var meanModelResiduals = GetResiduals(originalValues, meanModelEstimatedValues);

        meanModelResiduals.Sort();
        // The axis range must be set before UpdateSeries, which clips points against AxisX.Maximum.
        SetErrorAxisRange(meanModelResiduals);

        Series meanModelSeries = new Series("Mean Model");
        meanModelSeries.ChartType = SeriesChartType.FastLine;
        UpdateSeries(meanModelResiduals, meanModelSeries);
        meanModelSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(meanModelSeries);
        meanModelSeries.Tag = constantModel;
        chart.Series.Add(meanModelSeries);
      } else {
        // No baseline available: scale the error axis from the solution's own residuals instead.
        var solutionResiduals = GetResiduals(originalValues, GetEstimatedValues(Content));
        solutionResiduals.Sort();
        SetErrorAxisRange(solutionResiduals);
      }

      AddRegressionSolution(Content);
    }

    /// <summary>
    /// Sets the x-axis maximum to the largest residual rounded up to the next integer. Sets the x-cursor
    /// step to one hundredth of the smallest residual. Does nothing for an empty list.
    /// </summary>
    /// <param name="sortedResiduals">Residuals sorted in ascending order.</param>
    private void SetErrorAxisRange(List<double> sortedResiduals) {
      if (sortedResiduals.Count == 0) return;
      chart.ChartAreas[0].AxisX.Maximum = Math.Ceiling(sortedResiduals[sortedResiduals.Count - 1]);
      chart.ChartAreas[0].CursorX.Interval = sortedResiduals[0] / 100;
    }

    /// <summary>
    /// Adds a curve for <paramref name="solution"/> on the currently selected partition.
    /// Does nothing if a series with the solution's name is already present.
    /// </summary>
    /// <param name="solution">Solution whose residuals are plotted; it is stored in the series' Tag.</param>
    protected void AddRegressionSolution(IRegressionSolution solution) {
      if (chart.Series.Any(s => s.Name == solution.Name)) return;

      Series solutionSeries = new Series(solution.Name);
      solutionSeries.Tag = solution;
      solutionSeries.ChartType = SeriesChartType.FastLine;
      var residuals = GetResiduals(GetOriginalValues(), GetEstimatedValues(solution));
      UpdateSeries(residuals, solutionSeries);
      solutionSeries.ToolTip = "Area over Curve: " + CalculateAreaOverCurve(solutionSeries);
      chart.Series.Add(solutionSeries);
    }

    /// <summary>
    /// Replaces the points of <paramref name="series"/> with the cumulative curve of <paramref name="residuals"/>.
    /// </summary>
    /// <param name="residuals">Residuals to plot; this method sorts the list in place.</param>
    /// <param name="series">Series whose points are replaced.</param>
    /// <remarks>
    /// <list type="bullet">
    /// <item>The curve starts at (0, 0), and the i-th sorted residual r_i (0-based) adds the point (r_i, (i + 1) / n).</item>
    /// <item>At the first residual that exceeds the current x-axis maximum, one point (maximum, i / n) is added
    /// and the remaining residuals are skipped.</item>
    /// <item>If the curve ends left of the axis maximum, it is extended to (maximum, 1).</item>
    /// <item>The series is left empty when there are no residuals or all of them are NaN.</item>
    /// </list>
    /// </remarks>
    protected void UpdateSeries(List<double> residuals, Series series) {
      series.Points.Clear();
      residuals.Sort();
      if (!residuals.Any() || residuals.All(double.IsNaN)) return;

      series.Points.AddXY(0, 0);
      for (int i = 0; i < residuals.Count; i++) {
        var point = new DataPoint();
        if (residuals[i] > chart.ChartAreas[0].AxisX.Maximum) {
          // Clip at the visible range; only the i residuals before this one are within it.
          point.XValue = chart.ChartAreas[0].AxisX.Maximum;
          point.YValues[0] = ((double)i) / residuals.Count;
          point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
          series.Points.Add(point);
          break;
        }

        point.XValue = residuals[i];
        point.YValues[0] = ((double)i + 1) / residuals.Count;
        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
        series.Points.Add(point);
      }

      if (series.Points.Last().XValue < chart.ChartAreas[0].AxisX.Maximum) {
        var point = new DataPoint();
        point.XValue = chart.ChartAreas[0].AxisX.Maximum;
        point.YValues[0] = 1;
        point.ToolTip = "Error: " + point.XValue + "\n" + "Samples: " + point.YValues[0];
        series.Points.Add(point);
      }
    }

    /// <summary>
    /// Returns the target-variable values of the partition selected in the combo box.
    /// </summary>
    /// <exception cref="NotSupportedException">The selected item is not one of the known partitions.</exception>
    protected IEnumerable<double> GetOriginalValues() {
      IEnumerable<double> originalValues;
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes);
          break;
        case TestSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TestIndizes);
          break;
        case AllSamples:
          originalValues = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable);
          break;
        default:
          throw new NotSupportedException();
      }
      return originalValues;
    }

    /// <summary>
    /// Returns the estimated values of <paramref name="solution"/> for the partition selected in the combo box.
    /// </summary>
    /// <exception cref="NotSupportedException">The selected item is not one of the known partitions.</exception>
    protected IEnumerable<double> GetEstimatedValues(IRegressionSolution solution) {
      IEnumerable<double> estimatedValues;
      switch (cmbSamples.SelectedItem.ToString()) {
        case TrainingSamples:
          estimatedValues = solution.EstimatedTrainingValues;
          break;
        case TestSamples:
          estimatedValues = solution.EstimatedTestValues;
          break;
        case AllSamples:
          estimatedValues = solution.EstimatedValues;
          break;
        default:
          throw new NotSupportedException();
      }
      return estimatedValues;
    }

    /// <summary>
    /// Returns the average training target, repeated once for each element of <paramref name="originalValues"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The training partition is empty (thrown by Average()).</exception>
    protected IEnumerable<double> GetMeanModelEstimatedValues(IEnumerable<double> originalValues) {
      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).Average();
      return Enumerable.Repeat(averageTrainingTarget, originalValues.Count());
    }

    /// <summary>
    /// Pairs the two sequences element-wise and returns the absolute differences.
    /// The result is truncated to the shorter of the two sequences.
    /// </summary>
    protected virtual List<double> GetResiduals(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      return originalValues.Zip(estimatedValues, (x, y) => Math.Abs(x - y)).ToList();
    }

    /// <summary>
    /// Area between the plotted curve and the line y = 1, computed with the trapezoidal rule over consecutive points.
    /// Returns 0 when the series has fewer than two points.
    /// </summary>
    private double CalculateAreaOverCurve(Series series) {
      if (series.Points.Count < 2) return 0;

      double auc = 0.0;
      for (int i = 1; i < series.Points.Count; i++) {
        double width = series.Points[i].XValue - series.Points[i - 1].XValue;
        double y1 = 1 - series.Points[i - 1].YValues[0];
        double y2 = 1 - series.Points[i].YValues[0];

        auc += (y1 + y2) * width / 2;
      }

      return auc;
    }

    /// <summary>Redraws the chart (on the UI thread) for the newly selected partition.</summary>
    protected void cmbSamples_SelectedIndexChanged(object sender, EventArgs e) {
      if (InvokeRequired) Invoke((Action<object, EventArgs>)cmbSamples_SelectedIndexChanged, sender, e);
      else UpdateChart();
    }

    #region Mean Model
    /// <summary>
    /// Opens the mean-model solution in its own view when its legend entry is double-clicked.
    /// </summary>
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      if (e.Clicks < 2) return;
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (result.ChartElementType != ChartElementType.LegendItem) return;
      // No mean model exists when there is no content or the training partition is empty.
      if (constantModel == null || result.Series == null) return;
      if (result.Series.Name != constantModel.Name) return;

      MainFormManager.MainForm.ShowContent((IRegressionSolution)result.Series.Tag);
    }

    /// <summary>
    /// Builds a solution named "Mean Model" that always predicts the average target of the training partition.
    /// </summary>
    /// <exception cref="InvalidOperationException">The training partition is empty (thrown by Average()).</exception>
    private IRegressionSolution CreateConstantModel() {
      double averageTrainingTarget = ProblemData.Dataset.GetDoubleValues(ProblemData.TargetVariable, ProblemData.TrainingIndizes).Average();
      var solution = new ConstantRegressionModel(averageTrainingTarget).CreateRegressionSolution(ProblemData);
      solution.Name = "Mean Model";
      return solution;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.