Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GaussianProcessRegressionSolutionLineChartView.cs @ 13897

Last change on this file since 13897 was 13121, checked in by gkronber, 9 years ago

#2502: calculate the variance for the noisy test data V(y*) instead of the variance for the posterior GP function V(f*)

File size: 15.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21using System;
22using System.Collections.Generic;
23using System.Drawing;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.MainForm;
28using HeuristicLab.Problems.DataAnalysis.Views;
29
30namespace HeuristicLab.Algorithms.DataAnalysis.Views {
31  [View("Line Chart (95% confidence interval)")]
32  [Content(typeof(GaussianProcessRegressionSolution))]
33  public partial class GaussianProcessRegressionSolutionLineChartView : DataAnalysisSolutionEvaluationView {
34    private const string TARGETVARIABLE_SERIES_NAME = "Target Variable";
35    private const string ESTIMATEDVALUES_TRAINING_SERIES_NAME = "Estimated Values (training)";
36    private const string ESTIMATEDVALUES_TEST_SERIES_NAME = "Estimated Values (test)";
37    private const string ESTIMATEDVALUES_ALL_SERIES_NAME = "Estimated Values (all samples)";
38
39    public new GaussianProcessRegressionSolution Content {
40      get { return (GaussianProcessRegressionSolution)base.Content; }
41      set { base.Content = value; }
42    }
43
44    public GaussianProcessRegressionSolutionLineChartView()
45      : base() {
46      InitializeComponent();
47      //configure axis
48      this.chart.CustomizeAllChartAreas();
49      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
50      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
51      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
52      this.chart.ChartAreas[0].CursorX.Interval = 1;
53
54      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
55      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
56      this.chart.ChartAreas[0].CursorY.Interval = 0;
57    }
58
59    private void RedrawChart() {
60      this.chart.Series.Clear();
61      if (Content != null) {
62        this.chart.ChartAreas[0].AxisX.Minimum = 0;
63        this.chart.ChartAreas[0].AxisX.Maximum = Content.ProblemData.Dataset.Rows - 1;
64
65        // training series
66        this.chart.Series.Add(ESTIMATEDVALUES_TRAINING_SERIES_NAME);
67        this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].LegendText = ESTIMATEDVALUES_TRAINING_SERIES_NAME;
68        this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].ChartType = SeriesChartType.Range;
69        this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].EmptyPointStyle.Color = this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Color;
70        var mean = Content.EstimatedTrainingValues.ToArray();
71        var s2 = Content.EstimatedTrainingVariance.ToArray();
72        var lower = mean.Zip(s2, GetLowerConfBound).ToArray();
73        var upper = mean.Zip(s2, GetUpperConfBound).ToArray();
74        this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Points.DataBindXY(Content.ProblemData.TrainingIndices.ToArray(), lower, upper);
75        this.InsertEmptyPoints(this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME]);
76        this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Tag = Content;
77
78        // test series
79        this.chart.Series.Add(ESTIMATEDVALUES_TEST_SERIES_NAME);
80        this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].LegendText = ESTIMATEDVALUES_TEST_SERIES_NAME;
81        this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].ChartType = SeriesChartType.Range;
82
83        mean = Content.EstimatedTestValues.ToArray();
84        s2 = Content.EstimatedTestVariance.ToArray();
85        lower = mean.Zip(s2, GetLowerConfBound).ToArray();
86        upper = mean.Zip(s2, GetUpperConfBound).ToArray();
87        this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].Points.DataBindXY(Content.ProblemData.TestIndices.ToArray(), lower, upper);
88        this.InsertEmptyPoints(this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME]);
89        this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].Tag = Content;
90
91        // series of remaining points
92        int[] allIndices = Enumerable.Range(0, Content.ProblemData.Dataset.Rows).Except(Content.ProblemData.TrainingIndices).Except(Content.ProblemData.TestIndices).ToArray();
93        mean = Content.EstimatedValues.ToArray();
94        s2 = Content.EstimatedVariance.ToArray();
95        lower = mean.Zip(s2, GetLowerConfBound).ToArray();
96        upper = mean.Zip(s2, GetUpperConfBound).ToArray();
97        List<double> allLower = allIndices.Select(index => lower[index]).ToList();
98        List<double> allUpper = allIndices.Select(index => upper[index]).ToList();
99        this.chart.Series.Add(ESTIMATEDVALUES_ALL_SERIES_NAME);
100        this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].LegendText = ESTIMATEDVALUES_ALL_SERIES_NAME;
101        this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].ChartType = SeriesChartType.Range;
102        if (allIndices.Count() > 0) {
103          this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].Points.DataBindXY(allIndices, allLower, allUpper);
104          this.InsertEmptyPoints(this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME]);
105        }
106        this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].Tag = Content;
107
108        // target
109        this.chart.Series.Add(TARGETVARIABLE_SERIES_NAME);
110        this.chart.Series[TARGETVARIABLE_SERIES_NAME].LegendText = Content.ProblemData.TargetVariable;
111        this.chart.Series[TARGETVARIABLE_SERIES_NAME].ChartType = SeriesChartType.FastLine;
112        this.chart.Series[TARGETVARIABLE_SERIES_NAME].Points.DataBindXY(Enumerable.Range(0, Content.ProblemData.Dataset.Rows).ToArray(),
113          Content.ProblemData.Dataset.GetDoubleValues(Content.ProblemData.TargetVariable).ToArray());
114
115        this.ToggleSeriesData(this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME]);
116
117
118        // the series have been added in different order than in the normal line chart
119        // --> adapt coloring;
120        chart.ApplyPaletteColors();
121        this.chart.Palette = ChartColorPalette.None;
122        var s0Color = chart.Series[0].Color;
123        var s1Color = chart.Series[1].Color;
124        var s2Color = chart.Series[2].Color;
125        var s3Color = chart.Series[3].Color;
126        this.chart.PaletteCustomColors = new Color[] { s1Color, s2Color, s3Color, s0Color };
127
128        UpdateCursorInterval();
129        this.UpdateStripLines();
130      }
131    }
132
133    private void InsertEmptyPoints(Series series) {
134      int i = 0;
135      while (i < series.Points.Count - 1) {
136        if (series.Points[i].IsEmpty) {
137          ++i;
138          continue;
139        }
140
141        var p1 = series.Points[i];
142        var p2 = series.Points[i + 1];
143        // check for consecutive indices
144        if ((int)p2.XValue - (int)p1.XValue != 1) {
145          // insert an empty point between p1 and p2 so that the line will be invisible (transparent)
146          var p = new DataPoint((int)((p1.XValue + p2.XValue) / 2), new double[] { 0.0, 0.0 }) { IsEmpty = true };
147          // insert
148          series.Points.Insert(i + 1, p);
149        }
150        ++i;
151      }
152    }
153
154    private void UpdateCursorInterval() {
155      var estimatedValues = this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Points.Select(x => x.YValues[0]).DefaultIfEmpty(1.0);
156      var targetValues = this.chart.Series[TARGETVARIABLE_SERIES_NAME].Points.Select(x => x.YValues[0]).DefaultIfEmpty(1.0);
157      double estimatedValuesRange = estimatedValues.Max() - estimatedValues.Min();
158      double targetValuesRange = targetValues.Max() - targetValues.Min();
159      double interestingValuesRange = Math.Min(Math.Max(targetValuesRange, 1.0), Math.Max(estimatedValuesRange, 1.0));
160      double digits = (int)Math.Log10(interestingValuesRange) - 3;
161      double yZoomInterval = Math.Max(Math.Pow(10, digits), 10E-5);
162      this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval;
163    }
164
165    #region events
166    protected override void RegisterContentEvents() {
167      base.RegisterContentEvents();
168      Content.ModelChanged += new EventHandler(Content_ModelChanged);
169      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
170    }
171    protected override void DeregisterContentEvents() {
172      base.DeregisterContentEvents();
173      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
174      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
175    }
176
177    protected override void OnContentChanged() {
178      base.OnContentChanged();
179      RedrawChart();
180    }
181    private void Content_ProblemDataChanged(object sender, EventArgs e) {
182      RedrawChart();
183    }
184    private void Content_ModelChanged(object sender, EventArgs e) {
185      RedrawChart();
186    }
187
188
189
190    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
191      HitTestResult result = chart.HitTest(e.X, e.Y);
192      if (result.ChartArea != null && (result.ChartElementType == ChartElementType.PlottingArea ||
193                                       result.ChartElementType == ChartElementType.Gridlines) ||
194                                       result.ChartElementType == ChartElementType.StripLines) {
195        foreach (var axis in result.ChartArea.Axes)
196          axis.ScaleView.ZoomReset(int.MaxValue);
197      }
198    }
199    #endregion
200
201    private void UpdateStripLines() {
202      this.chart.ChartAreas[0].AxisX.StripLines.Clear();
203
204      int[] attr = new int[Content.ProblemData.Dataset.Rows + 1]; // add a virtual last row that is again empty to simplify loop further down
205      foreach (var row in Content.ProblemData.TrainingIndices) {
206        attr[row] += 1;
207      }
208      foreach (var row in Content.ProblemData.TestIndices) {
209        attr[row] += 2;
210      }
211      int start = 0;
212      int curAttr = attr[start];
213      for (int row = 0; row < attr.Length; row++) {
214        if (attr[row] != curAttr) {
215          switch (curAttr) {
216            case 0: break;
217            case 1:
218              this.CreateAndAddStripLine("Training", start, row, Color.FromArgb(40, Color.Green), Color.Transparent);
219              break;
220            case 2:
221              this.CreateAndAddStripLine("Test", start, row, Color.FromArgb(40, Color.Red), Color.Transparent);
222              break;
223            case 3:
224              this.CreateAndAddStripLine("Training and Test", start, row, Color.FromArgb(40, Color.Green), Color.FromArgb(40, Color.Red), ChartHatchStyle.WideUpwardDiagonal);
225              break;
226            default:
227              // should not happen
228              break;
229          }
230          curAttr = attr[row];
231          start = row;
232        }
233      }
234    }
235
236    private void CreateAndAddStripLine(string title, int start, int end, Color color, Color secondColor, ChartHatchStyle hatchStyle = ChartHatchStyle.None) {
237      StripLine stripLine = new StripLine();
238      stripLine.BackColor = color;
239      stripLine.BackSecondaryColor = secondColor;
240      stripLine.BackHatchStyle = hatchStyle;
241      stripLine.Text = title;
242      stripLine.Font = new Font("Times New Roman", 12, FontStyle.Bold);
243      // strip range is [start .. end] inclusive, but we evaluate [start..end[ (end is exclusive)
244      // the strip should be by one longer (starting at start - 0.5 and ending at end + 0.5)
245      stripLine.StripWidth = end - start;
246      stripLine.IntervalOffset = start - 0.5; // start slightly to the left of the first point to clearly indicate the first point in the partition
247      this.chart.ChartAreas[0].AxisX.StripLines.Add(stripLine);
248    }
249
250    private void ToggleSeriesData(Series series) {
251      if (series.Points.Count > 0) {  //checks if series is shown
252        if (this.chart.Series.Any(s => s != series && s.Points.Count > 0)) {
253          ClearPointsQuick(series.Points);
254        }
255      } else if (Content != null) {
256
257        IEnumerable<int> indices = null;
258        IEnumerable<double> mean = null;
259        IEnumerable<double> s2 = null;
260        double[] lower = null;
261        double[] upper = null;
262        switch (series.Name) {
263          case ESTIMATEDVALUES_ALL_SERIES_NAME:
264            indices = Enumerable.Range(0, Content.ProblemData.Dataset.Rows).Except(Content.ProblemData.TrainingIndices).Except(Content.ProblemData.TestIndices).ToArray();
265            mean = Content.EstimatedValues.ToArray();
266            s2 = Content.EstimatedVariance.ToArray();
267            lower = mean.Zip(s2, GetLowerConfBound).ToArray();
268            upper = mean.Zip(s2, GetUpperConfBound).ToArray();
269            lower = indices.Select(index => lower[index]).ToArray();
270            upper = indices.Select(index => upper[index]).ToArray();
271            break;
272          case ESTIMATEDVALUES_TRAINING_SERIES_NAME:
273            indices = Content.ProblemData.TrainingIndices.ToArray();
274            mean = Content.EstimatedTrainingValues.ToArray();
275            s2 = Content.EstimatedTrainingVariance.ToArray();
276            lower = mean.Zip(s2, GetLowerConfBound).ToArray();
277            upper = mean.Zip(s2, GetUpperConfBound).ToArray();
278            break;
279          case ESTIMATEDVALUES_TEST_SERIES_NAME:
280            indices = Content.ProblemData.TestIndices.ToArray();
281            mean = Content.EstimatedTestValues.ToArray();
282            s2 = Content.EstimatedTestVariance.ToArray();
283            lower = mean.Zip(s2, GetLowerConfBound).ToArray();
284            upper = mean.Zip(s2, GetUpperConfBound).ToArray();
285            break;
286        }
287        if (indices.Count() > 0) {
288          series.Points.DataBindXY(indices, lower, upper);
289          this.InsertEmptyPoints(series);
290          chart.Legends[series.Legend].ForeColor = Color.Black;
291          UpdateCursorInterval();
292          chart.Refresh();
293        }
294      }
295    }
296
297    private double GetLowerConfBound(double m, double s) {
298      return m - 1.96 * Math.Sqrt(s);
299    }
300
301
302    private double GetUpperConfBound(double m, double s) {
303      return m + 1.96 * Math.Sqrt(s);
304    }
305
306    // workaround as per http://stackoverflow.com/questions/5744930/datapointcollection-clear-performance
307    private static void ClearPointsQuick(DataPointCollection points) {
308      points.SuspendUpdates();
309      while (points.Count > 0)
310        points.RemoveAt(points.Count - 1);
311      points.ResumeUpdates();
312    }
313
314    private void chart_MouseMove(object sender, MouseEventArgs e) {
315      HitTestResult result = chart.HitTest(e.X, e.Y);
316      if (result.ChartElementType == ChartElementType.LegendItem && result.Series.Name != TARGETVARIABLE_SERIES_NAME)
317        Cursor = Cursors.Hand;
318      else
319        Cursor = Cursors.Default;
320    }
321    private void chart_MouseDown(object sender, MouseEventArgs e) {
322      HitTestResult result = chart.HitTest(e.X, e.Y);
323      if (result.ChartElementType == ChartElementType.LegendItem && result.Series.Name != TARGETVARIABLE_SERIES_NAME) {
324        ToggleSeriesData(result.Series);
325      }
326    }
327
328    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
329      if (chart.Series.Count != 4) return;
330      e.LegendItems[0].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
331      e.LegendItems[1].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
332      e.LegendItems[2].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
333      e.LegendItems[3].Cells[1].ForeColor = this.chart.Series[TARGETVARIABLE_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
334    }
335  }
336}
Note: See TracBrowser for help on using the repository browser.