Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis.Views/3.4/GaussianProcessRegressionSolutionLineChartView.cs @ 8473

Last change on this file since 8473 was 8473, checked in by gkronber, 12 years ago

#1902 worked on GPR: added line chart, made parameters of mean and covariance functions readable, removed target variable scaling, moved the noise hyperparameter of the likelihood function to the end of the parameter list, added methods to calculate the predicted variance, removed limits for the scale of covariance functions, introduced exception handling to catch non-SPD or singular covariance matrices, implemented the rational quadratic covariance function, and added a unit test case from the GBML book (however, it does not work, as the book seemingly uses a noise-free likelihood function)

File size: 14.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21using System;
22using System.Collections.Generic;
23using System.Drawing;
24using System.Linq;
25using System.Windows.Forms;
26using System.Windows.Forms.DataVisualization.Charting;
27using HeuristicLab.MainForm;
28using HeuristicLab.Problems.DataAnalysis.Views;
29
namespace HeuristicLab.Algorithms.DataAnalysis.Views {
  /// <summary>
  /// Line chart view for a <see cref="GaussianProcessRegressionSolution"/>.
  /// Plots the target variable over the row index. For the training rows, the test rows and
  /// (on demand) all remaining rows, it plots a range band of
  /// mean +/- CONFIDENCE_INTERVAL_Z_SCORE * sqrt(predicted variance).
  /// Training/test partitions are highlighted with strip lines.
  /// </summary>
  [View("Line Chart 2")]
  [Content(typeof(GaussianProcessRegressionSolution))]
  public partial class GaussianProcessRegressionSolutionLineChartView : DataAnalysisSolutionEvaluationView {
    private const string TARGETVARIABLE_SERIES_NAME = "Target Variable";
    private const string ESTIMATEDVALUES_TRAINING_SERIES_NAME = "Estimated Values (training)";
    private const string ESTIMATEDVALUES_TEST_SERIES_NAME = "Estimated Values (test)";
    private const string ESTIMATEDVALUES_ALL_SERIES_NAME = "Estimated Values (all samples)";

    // Two-sided 95% quantile of the standard normal distribution.
    // The band drawn around the predicted mean is mean +/- z * sqrt(variance).
    private const double CONFIDENCE_INTERVAL_Z_SCORE = 1.96;

    public new GaussianProcessRegressionSolution Content {
      get { return (GaussianProcessRegressionSolution)base.Content; }
      set { base.Content = value; }
    }

    public GaussianProcessRegressionSolutionLineChartView()
      : base() {
      InitializeComponent();
      //configure axis
      this.chart.CustomizeAllChartAreas();
      this.chart.ChartAreas[0].CursorX.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisX.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].AxisX.IsStartedFromZero = true;
      this.chart.ChartAreas[0].CursorX.Interval = 1;

      this.chart.ChartAreas[0].CursorY.IsUserSelectionEnabled = true;
      this.chart.ChartAreas[0].AxisY.ScaleView.Zoomable = true;
      this.chart.ChartAreas[0].CursorY.Interval = 0;
    }

    /// <summary>
    /// Removes all series and, if there is content, rebuilds them.
    /// The target-variable, training and test series are filled.
    /// The "all samples" series is added empty (hidden) and is only filled when toggled via its legend entry.
    /// </summary>
    private void RedrawChart() {
      this.chart.Series.Clear();
      if (Content == null) return;

      var problemData = Content.ProblemData;
      int rows = problemData.Dataset.Rows;
      this.chart.ChartAreas[0].AxisX.Minimum = 0;
      this.chart.ChartAreas[0].AxisX.Maximum = rows - 1;

      // target variable over all rows
      Series targetSeries = this.chart.Series.Add(TARGETVARIABLE_SERIES_NAME);
      targetSeries.LegendText = problemData.TargetVariable;
      targetSeries.ChartType = SeriesChartType.FastLine;
      targetSeries.Points.DataBindXY(Enumerable.Range(0, rows).ToArray(),
        problemData.Dataset.GetDoubleValues(problemData.TargetVariable).ToArray());

      // training series
      Series trainingSeries = this.chart.Series.Add(ESTIMATEDVALUES_TRAINING_SERIES_NAME);
      trainingSeries.LegendText = ESTIMATEDVALUES_TRAINING_SERIES_NAME;
      trainingSeries.ChartType = SeriesChartType.Range;
      // NOTE(review): only the training series sets this; confirm whether the test series should too.
      trainingSeries.EmptyPointStyle.Color = trainingSeries.Color;
      BindConfidenceBand(trainingSeries, problemData.TrainingIndices, Content.EstimatedTrainingValues, Content.EstimatedTrainingVariance);
      trainingSeries.Tag = Content;

      // test series
      Series testSeries = this.chart.Series.Add(ESTIMATEDVALUES_TEST_SERIES_NAME);
      testSeries.LegendText = ESTIMATEDVALUES_TEST_SERIES_NAME;
      testSeries.ChartType = SeriesChartType.Range;
      BindConfidenceBand(testSeries, problemData.TestIndices, Content.EstimatedTestValues, Content.EstimatedTestVariance);
      testSeries.Tag = Content;

      // Rows that are neither training nor test. The series starts hidden (no points), so no
      // predictions for all rows are computed up front. ToggleSeriesData fills it on demand.
      Series remainingSeries = this.chart.Series.Add(ESTIMATEDVALUES_ALL_SERIES_NAME);
      remainingSeries.LegendText = ESTIMATEDVALUES_ALL_SERIES_NAME;
      remainingSeries.ChartType = SeriesChartType.Range;
      remainingSeries.Tag = Content;

      UpdateCursorInterval();
      this.UpdateStripLines();
    }

    /// <summary>
    /// Returns the row indices of the dataset that are neither training nor test rows.
    /// </summary>
    private int[] GetRemainingIndices() {
      var problemData = Content.ProblemData;
      return Enumerable.Range(0, problemData.Dataset.Rows)
        .Except(problemData.TrainingIndices)
        .Except(problemData.TestIndices)
        .ToArray();
    }

    /// <summary>
    /// Computes the lower and upper band bounds as mean +/- CONFIDENCE_INTERVAL_Z_SCORE * sqrt(variance).
    /// Negative variances (e.g. from numerical round-off) are clamped to zero so that no NaN bounds are produced.
    /// The output length is the shorter of the two inputs.
    /// </summary>
    private static void CalculateConfidenceBounds(IEnumerable<double> mean, IEnumerable<double> variance, out double[] lower, out double[] upper) {
      double[] m = mean.ToArray();
      double[] s2 = variance.ToArray();
      int n = Math.Min(m.Length, s2.Length);
      lower = new double[n];
      upper = new double[n];
      for (int i = 0; i < n; i++) {
        double halfWidth = CONFIDENCE_INTERVAL_Z_SCORE * Math.Sqrt(Math.Max(0.0, s2[i]));
        lower[i] = m[i] - halfWidth;
        upper[i] = m[i] + halfWidth;
      }
    }

    /// <summary>
    /// Binds (x = row index, y = lower bound, upper bound) to a range series.
    /// Then inserts empty points wherever consecutive indices are not adjacent.
    /// </summary>
    private void BindConfidenceBand(Series series, IEnumerable<int> indices, IEnumerable<double> mean, IEnumerable<double> variance) {
      double[] lower, upper;
      CalculateConfidenceBounds(mean, variance, out lower, out upper);
      series.Points.DataBindXY(indices.ToArray(), lower, upper);
      this.InsertEmptyPoints(series);
    }

    /// <summary>
    /// Inserts an empty point between every pair of neighbouring points whose X values (truncated to int)
    /// are not exactly one apart. This breaks the drawn band at gaps in the row indices.
    /// </summary>
    private void InsertEmptyPoints(Series series) {
      int i = 0;
      while (i < series.Points.Count - 1) {
        if (series.Points[i].IsEmpty) {
          ++i;
          continue;
        }

        var p1 = series.Points[i];
        var p2 = series.Points[i + 1];
        // check for consecutive indices
        if ((int)p2.XValue - (int)p1.XValue != 1) {
          // insert an empty point between p1 and p2 so that the line will be invisible (transparent)
          var p = new DataPoint((int)((p1.XValue + p2.XValue) / 2), 0.0) { IsEmpty = true };
          series.Points.Insert(i + 1, p);
        }
        ++i;
      }
    }

    /// <summary>
    /// Sets the Y cursor interval from the smaller of the target range and the training-estimate range.
    /// Each range is taken as at least 1.0.
    /// </summary>
    private void UpdateCursorInterval() {
      // Empty points are only gap markers (Y = 0), so they are excluded from the range computation.
      var estimatedValues = this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Points
        .Where(p => !p.IsEmpty)
        .Select(p => p.YValues[0])
        .DefaultIfEmpty(1.0);
      var targetValues = this.chart.Series[TARGETVARIABLE_SERIES_NAME].Points
        .Where(p => !p.IsEmpty)
        .Select(p => p.YValues[0])
        .DefaultIfEmpty(1.0);
      double estimatedValuesRange = estimatedValues.Max() - estimatedValues.Min();
      double targetValuesRange = targetValues.Max() - targetValues.Min();
      double interestingValuesRange = Math.Min(Math.Max(targetValuesRange, 1.0), Math.Max(estimatedValuesRange, 1.0));
      double digits = (int)Math.Log10(interestingValuesRange) - 3;
      double yZoomInterval = Math.Max(Math.Pow(10, digits), 10E-5);
      this.chart.ChartAreas[0].CursorY.Interval = yZoomInterval;
    }

    #region events
    protected override void RegisterContentEvents() {
      base.RegisterContentEvents();
      Content.ModelChanged += new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged += new EventHandler(Content_ProblemDataChanged);
    }
    protected override void DeregisterContentEvents() {
      base.DeregisterContentEvents();
      Content.ModelChanged -= new EventHandler(Content_ModelChanged);
      Content.ProblemDataChanged -= new EventHandler(Content_ProblemDataChanged);
    }

    protected override void OnContentChanged() {
      base.OnContentChanged();
      RedrawChart();
    }
    private void Content_ProblemDataChanged(object sender, EventArgs e) {
      RedrawChart();
    }
    private void Content_ModelChanged(object sender, EventArgs e) {
      RedrawChart();
    }

    /// <summary>
    /// Resets the zoom on all axes when the user double-clicks the plotting area, grid lines or strip lines.
    /// </summary>
    private void Chart_MouseDoubleClick(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      // All element-type checks are grouped under the ChartArea null check. A strip-line hit
      // without a chart area must not reach result.ChartArea.Axes.
      if (result.ChartArea != null && (result.ChartElementType == ChartElementType.PlottingArea ||
                                       result.ChartElementType == ChartElementType.Gridlines ||
                                       result.ChartElementType == ChartElementType.StripLines)) {
        foreach (var axis in result.ChartArea.Axes)
          axis.ScaleView.ZoomReset(int.MaxValue);
      }
    }
    #endregion

    /// <summary>
    /// Rebuilds the X-axis strip lines that mark contiguous runs of training rows, test rows,
    /// and rows that are both.
    /// </summary>
    private void UpdateStripLines() {
      this.chart.ChartAreas[0].AxisX.StripLines.Clear();

      // per row: bit 1 = training, bit 2 = test (3 = both)
      int[] attr = new int[Content.ProblemData.Dataset.Rows + 1]; // add a virtual last row that is again empty to simplify loop further down
      foreach (var row in Content.ProblemData.TrainingIndices) {
        attr[row] += 1;
      }
      foreach (var row in Content.ProblemData.TestIndices) {
        attr[row] += 2;
      }
      int start = 0;
      int curAttr = attr[start];
      for (int row = 0; row < attr.Length; row++) {
        if (attr[row] != curAttr) {
          switch (curAttr) {
            case 0: break;
            case 1:
              this.CreateAndAddStripLine("Training", start, row, Color.FromArgb(40, Color.Green), Color.Transparent);
              break;
            case 2:
              this.CreateAndAddStripLine("Test", start, row, Color.FromArgb(40, Color.Red), Color.Transparent);
              break;
            case 3:
              this.CreateAndAddStripLine("Training and Test", start, row, Color.FromArgb(40, Color.Green), Color.FromArgb(40, Color.Red), ChartHatchStyle.WideUpwardDiagonal);
              break;
            default:
              // should not happen
              break;
          }
          curAttr = attr[row];
          start = row;
        }
      }
    }

    /// <summary>
    /// Adds a titled strip line covering rows [start, end) (end exclusive) to the X axis.
    /// </summary>
    private void CreateAndAddStripLine(string title, int start, int end, Color color, Color secondColor, ChartHatchStyle hatchStyle = ChartHatchStyle.None) {
      StripLine stripLine = new StripLine();
      stripLine.BackColor = color;
      stripLine.BackSecondaryColor = secondColor;
      stripLine.BackHatchStyle = hatchStyle;
      stripLine.Text = title;
      stripLine.Font = new Font("Times New Roman", 12, FontStyle.Bold);
      // Rows start .. end-1 are covered. The strip runs from start - 0.5 to end - 0.5, so each
      // point sits in the middle of a unit-wide cell. The first point of the partition is therefore
      // clearly inside the strip.
      stripLine.StripWidth = end - start;
      stripLine.IntervalOffset = start - 0.5;
      this.chart.ChartAreas[0].AxisX.StripLines.Add(stripLine);
    }

    /// <summary>
    /// Shows or hides an estimated-values series.
    /// A visible series is cleared unless it is the only series with points.
    /// A hidden series is filled with its confidence band.
    /// </summary>
    private void ToggleSeriesData(Series series) {
      if (series.Points.Count > 0) {  // series is currently shown
        // never hide the last visible series
        if (this.chart.Series.Any(s => s != series && s.Points.Count > 0)) {
          ClearPointsQuick(series.Points);
        }
        return;
      }
      if (Content == null) return;

      IEnumerable<int> indices;
      IEnumerable<double> mean;
      IEnumerable<double> variance;
      switch (series.Name) {
        case ESTIMATEDVALUES_ALL_SERIES_NAME:
          int[] remaining = GetRemainingIndices();
          double[] allMean = Content.EstimatedValues.ToArray();
          double[] allVariance = Content.EstimatedVariance.ToArray();
          indices = remaining;
          mean = remaining.Select(index => allMean[index]);
          variance = remaining.Select(index => allVariance[index]);
          break;
        case ESTIMATEDVALUES_TRAINING_SERIES_NAME:
          indices = Content.ProblemData.TrainingIndices;
          mean = Content.EstimatedTrainingValues;
          variance = Content.EstimatedTrainingVariance;
          break;
        case ESTIMATEDVALUES_TEST_SERIES_NAME:
          indices = Content.ProblemData.TestIndices;
          mean = Content.EstimatedTestValues;
          variance = Content.EstimatedTestVariance;
          break;
        default:
          // other series (e.g. the target variable) have no confidence band to toggle
          return;
      }
      BindConfidenceBand(series, indices, mean, variance);
      chart.Legends[series.Legend].ForeColor = Color.Black;
      UpdateCursorInterval();
      chart.Refresh();
    }

    // workaround as per http://stackoverflow.com/questions/5744930/datapointcollection-clear-performance
    private static void ClearPointsQuick(DataPointCollection points) {
      points.SuspendUpdates();
      while (points.Count > 0)
        points.RemoveAt(points.Count - 1);
      points.ResumeUpdates();
    }

    /// <summary>
    /// True when the hit is a legend item of a series that can be toggled (any series except the target variable).
    /// </summary>
    private static bool IsToggleableLegendItem(HitTestResult result) {
      return result.ChartElementType == ChartElementType.LegendItem
             && result.Series != null
             && result.Series.Name != TARGETVARIABLE_SERIES_NAME;
    }

    private void chart_MouseMove(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      Cursor = IsToggleableLegendItem(result) ? Cursors.Hand : Cursors.Default;
    }
    private void chart_MouseDown(object sender, MouseEventArgs e) {
      HitTestResult result = chart.HitTest(e.X, e.Y);
      if (IsToggleableLegendItem(result)) {
        ToggleSeriesData(result.Series);
      }
    }

    /// <summary>
    /// Greys out the legend text of series that currently have no points.
    /// </summary>
    private void chart_CustomizeLegend(object sender, CustomizeLegendEventArgs e) {
      if (chart.Series.Count != 4) return;
      e.LegendItems[0].Cells[1].ForeColor = this.chart.Series[TARGETVARIABLE_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
      e.LegendItems[1].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_TRAINING_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
      e.LegendItems[2].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_TEST_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
      e.LegendItems[3].Cells[1].ForeColor = this.chart.Series[ESTIMATEDVALUES_ALL_SERIES_NAME].Points.Count == 0 ? Color.Gray : Color.Black;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.