Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GaussianProcessTuning/GaussianProcessDemo/Form1.cs @ 9779

Last change on this file since 9779 was 9387, checked in by gkronber, 12 years ago

#1967: added CovNN symbol and tree node

File size: 4.6 KB
RevLine 
[9124]1using System;
2using System.Collections.Generic;
3using System.ComponentModel;
4using System.Data;
5using System.Drawing;
6using System.Linq;
7using System.Text;
8using System.Threading.Tasks;
9using System.Windows.Forms;
10using HeuristicLab.Algorithms.DataAnalysis;
11using HeuristicLab.Core;
[9338]12using HeuristicLab.Data;
[9124]13using HeuristicLab.Problems.DataAnalysis;
14using HeuristicLab.Problems.Instances.DataAnalysis;
15using HeuristicLab.Random;
16
namespace GaussianProcessDemo {
  /// <summary>
  /// Interactive Gaussian-process demo. One slider per covariance hyperparameter controls
  /// a sample drawn via <c>Util.SampleGaussianProcess</c> (chart series 0). A
  /// <see cref="GaussianProcessModel"/> is built on that sample and its estimated values
  /// are plotted as chart series 1.
  /// </summary>
  public partial class Form1 : Form {
    // The x grid is ValueGenerator.GenerateSteps(0, 1, 1.0 / NumberOfSteps).
    private const int NumberOfSteps = 200;
    // Number of distinct rows drawn from the sample for the training subset.
    private const int TrainingSampleSize = 10;
    // Slider positions are integers; the hyperparameter value is position / SliderScale.
    private const double SliderScale = 10.0;
    // Slider position range (inclusive).
    private const int SliderMinimum = -50;
    private const int SliderMaximum = 50;
    // The demo uses a single input variable (x); data holds one list per input variable.
    private const int NumberOfInputVariables = 1;

    private readonly IRandom random;
    private readonly ICovarianceFunction covFunction;
    // data[0] is the x grid.
    private List<List<double>> data;
    // Normal draws generated once in InitData and reused on every redraw.
    // Presumably this keeps the sample "the same" while only the hyperparameters change — confirm.
    private double[] alpha;

    public Form1() {
      InitializeComponent();
      random = new MersenneTwister();

      // Covariance = neural-network kernel + noise term.
      var sum = new CovarianceSum();
      sum.Terms.Add(new CovarianceNeuralNetwork());
      sum.Terms.Add(new CovarianceNoise());
      covFunction = sum;

      UpdateSliders();
      InitData();
      UpdateChart();
    }

    /// <summary>
    /// Rebuilds the control panel: the data button followed by one slider per covariance
    /// hyperparameter. Every slider change triggers <see cref="UpdateChart"/>.
    /// </summary>
    private void UpdateSliders() {
      // Controls.Clear() does not dispose the removed controls, so dispose the old sliders
      // explicitly. The data button is re-added below and must not be disposed.
      var oldSliders = flowLayoutPanel1.Controls.OfType<TrackBar>().ToList();
      flowLayoutPanel1.Controls.Clear();
      foreach (var oldSlider in oldSliders) {
        oldSlider.Dispose();
      }
      flowLayoutPanel1.Controls.Add(dataButton);

      int numberOfParameters = covFunction.GetNumberOfParameters(NumberOfInputVariables);
      for (int i = 0; i < numberOfParameters; i++) {
        var slider = new TrackBar {
          Minimum = SliderMinimum,
          Maximum = SliderMaximum,
          Value = 0
        };
        // Subscribe after setting Value so initialization does not trigger a redraw.
        slider.ValueChanged += (sender, args) => UpdateChart();
        flowLayoutPanel1.Controls.Add(slider);
      }
    }

    /// <summary>
    /// Creates the x grid on [0, 1] and one N(0, 1) draw per grid point. The draws are
    /// passed to <c>Util.SampleGaussianProcess</c> on every redraw.
    /// </summary>
    private void InitData() {
      var xs = ValueGenerator.GenerateSteps(0, 1, 1.0 / NumberOfSteps).ToList();
      data = new List<List<double>> { xs };

      // Size alpha from the generated grid rather than assuming NumberOfSteps + 1 points,
      // so the two lengths always agree.
      var normalRand = new NormalDistributedRandom(random, 0, 1);
      alpha = Enumerable.Range(0, xs.Count)
                        .Select(_ => normalRand.NextDouble())
                        .ToArray();
    }

    /// <summary>
    /// Redraws the chart for the current slider values:
    /// <list type="bullet">
    /// <item>series 0 shows the sample;</item>
    /// <item>series 1 shows the estimates of a GP model built on the full sample.</item>
    /// </list>
    /// </summary>
    private void UpdateChart() {
      var hyp = GetSliderValues();
      var xs = data[0];
      var cov = covFunction.GetParameterizedCovarianceFunction(hyp, Enumerable.Range(0, data.Count));
      var y = Util.SampleGaussianProcess(random, cov, data, alpha);
      PlotSeries(0, xs, y);

      var variableNames = new string[] { "y", "x" };

      // NOTE(review): the training subset is built but not yet used by any model or plot.
      var trainingDataSet = CreateTrainingDataset(y, xs, variableNames);

      var fullDataSet = new Dataset(variableNames, new List<List<double>> { y, xs });
      var fullRows = Enumerable.Range(0, xs.Count);
      var model = new GaussianProcessModel(fullDataSet, variableNames[0], variableNames.Skip(1), fullRows, hyp, new MeanZero(),
                                           (ICovarianceFunction)covFunction.Clone());
      var yPred = model.GetEstimatedValues(fullDataSet, fullRows);
      PlotSeries(1, xs, yPred);
    }

    // Draws TrainingSampleSize distinct rows of the sample and returns them as a Dataset
    // with columns (y, x).
    private Dataset CreateTrainingDataset(List<double> y, List<double> xs, string[] variableNames) {
      // Materialize the sampled indices exactly once. If the sampler is lazily evaluated,
      // enumerating it repeatedly would draw a new random subset each time, and y values
      // would be paired with x values from different rows.
      var trainingIndices = RandomEnumerable.SampleRandomWithoutRepetition(Enumerable.Range(0, y.Count), random, TrainingSampleSize)
                                            .ToList();
      var trainingData = new List<List<double>> {
        trainingIndices.Select(i => y[i]).ToList(),
        trainingIndices.Select(i => xs[i]).ToList()
      };
      return new Dataset(variableNames, trainingData);
    }

    // Replaces the points of chart series seriesIndex with the pairs (xs[i], ys[i]),
    // stopping at the shorter of the two sequences.
    private void PlotSeries(int seriesIndex, IEnumerable<double> xs, IEnumerable<double> ys) {
      var points = chart1.Series[seriesIndex].Points;
      points.Clear();
      foreach (var p in xs.Zip(ys, (x, t) => new { x, t })) {
        points.AddXY(p.x, p.t);
      }
    }

    /// <summary>
    /// Reads the hyperparameter vector from the sliders in panel order. Each value is the
    /// slider position divided by <see cref="SliderScale"/>. The values are also echoed to
    /// the console, space-separated.
    /// </summary>
    private double[] GetSliderValues() {
      var hyp = flowLayoutPanel1.Controls.OfType<TrackBar>()
                                .Select(slider => slider.Value / SliderScale)
                                .ToArray();
      Console.WriteLine(string.Concat(hyp.Select(v => v + " ")));
      return hyp;
    }

    // Shows the plotted model estimates in a modal text window.
    private void dataButton_Click(object sender, EventArgs e) {
      // A form shown with ShowDialog is not disposed when it is closed; dispose it explicitly.
      // Disposing the form also disposes the text box it contains.
      using (var dataForm = new Form()) {
        var dataTextField = new TextBox {
          Multiline = true,
          ScrollBars = ScrollBars.Vertical,
          Text = DataToText(),
          Dock = DockStyle.Fill
        };
        dataForm.Controls.Add(dataTextField);
        dataForm.ShowDialog();
      }
    }

    // Serializes chart series 1 (the model estimates) as one "x<TAB>y" line per point.
    // Numbers are formatted with the current culture.
    private string DataToText() {
      var str = new StringBuilder();
      foreach (var p in chart1.Series[1].Points) {
        str.AppendLine(p.XValue + "\t" + p.YValues.First());
      }
      return str.ToString();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.