Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceNNOne.cs @ 8368

Last change on this file since 8368 was 8323, checked in by gkronber, 12 years ago

#1902 initial import of Gaussian process regression algorithm

File size: 3.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
  /// <summary>
  /// Arcsine ("neural network") covariance function with two hyperparameters.
  /// For a row x_i of <c>x</c> and a row z_j of <c>xt</c> the covariance is
  /// <c>sf2 * asin( (1 + x_i'z_j) / ( sqrt(l2 + 1 + x_i'x_i) * sqrt(l2 + 1 + z_j'z_j) ) )</c>
  /// where <c>l2 = exp(2 * hyp[0])</c> and <c>sf2 = exp(2 * hyp[1])</c>.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the formulas (including the derivative expressions) match the
  /// covNNone function of the GPML toolbox; this class presumably is a port of it.
  /// The data-dependent terms are cached lazily and only discarded by SetMatrix.
  /// </remarks>
  public class CovarianceNNOne : ICovarianceFunction {
    private double[,] x;
    private double[,] xt;
    private double sf2;
    private double l2;
    // S[i, j] = 1 + <x_i, xt_j>; null until first needed
    private double[,] S;
    // sx[i] = 1 + <x_i, x_i>; null until first needed
    private double[] sx;
    // sz[j] = 1 + <xt_j, xt_j>; the same array as sx when x and xt are the same instance
    private double[] sz;

    /// <summary>Number of hyperparameters expected by <see cref="SetHyperparamter"/> (always 2).</summary>
    public int NumberOfParameters {
      get { return 2; }
    }

    /// <summary>Sets the data for the symmetric case (covariances among the rows of <paramref name="x"/>).</summary>
    /// <exception cref="ArgumentNullException"><paramref name="x"/> is null.</exception>
    public void SetMatrix(double[,] x) {
      SetMatrix(x, x);
    }

    /// <summary>
    /// Sets the data for cross covariances between the rows of <paramref name="x"/> and the
    /// rows of <paramref name="xt"/>. Passing the same instance twice selects the symmetric case.
    /// </summary>
    /// <exception cref="ArgumentNullException">One of the matrices is null.</exception>
    public void SetMatrix(double[,] x, double[,] xt) {
      if (x == null) throw new ArgumentNullException("x");
      if (xt == null) throw new ArgumentNullException("xt");
      this.x = x;
      this.xt = xt;
      // all cached terms depend on the data
      S = null;
      sx = null;
      sz = null;
    }

    /// <summary>
    /// Sets the hyperparameters: <c>hyp[0]</c> is the log of the length scale and
    /// <c>hyp[1]</c> the log of the signal standard deviation (both are squared after exp).
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="hyp"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="hyp"/> does not have exactly two elements.</exception>
    public void SetHyperparamter(double[] hyp) {
      if (hyp == null) throw new ArgumentNullException("hyp");
      if (hyp.Length != NumberOfParameters)
        throw new ArgumentException("Exactly two hyperparameters (log length scale, log signal std. dev.) are expected.", "hyp");
      this.l2 = Math.Exp(2 * hyp[0]);
      this.sf2 = Math.Exp(2 * hyp[1]);
      // S, sx and sz are functions of the data only, so they stay valid when
      // the hyperparameters change and are intentionally not discarded here.
    }

    /// <summary>Returns the covariance between row <paramref name="i"/> of x and row <paramref name="j"/> of xt.</summary>
    /// <exception cref="InvalidOperationException">
    /// No data has been set, or x and xt have a different number of columns.
    /// </exception>
    public double GetCovariance(int i, int j) {
      return sf2 * Math.Asin(GetNormalizedProduct(i, j));
    }

    /// <summary>Returns the covariance of each row of x with itself (symmetric case only).</summary>
    /// <exception cref="InvalidOperationException">
    /// No data has been set, or x and xt are different matrices.
    /// </exception>
    public double[] GetDiagonalCovariances() {
      EnsureData();
      if (x != xt) throw new InvalidOperationException("Diagonal covariances are only defined for the symmetric case.");
      EnsureRowNorms();
      int rows = x.GetLength(0);
      var k = new double[rows];
      for (int i = 0; i < rows; i++) {
        k[i] = sf2 * Math.Asin(sx[i] / (sx[i] + l2));
      }
      return k;
    }

    /// <summary>
    /// Returns the partial derivatives of the covariance between row <paramref name="i"/> of x and
    /// row <paramref name="j"/> of xt with respect to <c>hyp[0]</c> and <c>hyp[1]</c> (in this order).
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No data has been set, or x and xt have a different number of columns.
    /// </exception>
    public double[] GetDerivatives(int i, int j) {
      double k = GetNormalizedProduct(i, j);
      // only the two entries that are actually needed are computed (no per-call arrays)
      double vi = sx[i] / (l2 + sx[i]) / 2;
      double vj = sz[j] / (l2 + sz[j]) / 2;
      double v = vi + vj;

      double[] dhyp = new double[NumberOfParameters];
      dhyp[0] = -2 * sf2 * (k - k * v) / Math.Sqrt(1 - k * k);
      dhyp[1] = 2.0 * sf2 * Math.Asin(k);
      return dhyp;
    }

    // Returns S[i, j] / (sqrt(l2 + sx[i]) * sqrt(l2 + sz[j])), i.e. the argument of asin.
    // The normalizer is specific to the pair (i, j); it is not a sum over all rows.
    private double GetNormalizedProduct(int i, int j) {
      EnsureData();
      EnsureVectorProducts();
      EnsureRowNorms();
      double k = S[i, j] / (Math.Sqrt(l2 + sx[i]) * Math.Sqrt(l2 + sz[j]));
      // |k| <= 1 holds mathematically (Cauchy-Schwarz on the augmented rows);
      // the clamp only protects asin against rounding errors.
      if (k > 1.0) return 1.0;
      if (k < -1.0) return -1.0;
      return k;
    }

    // Fails with a descriptive exception when SetMatrix has not been called yet.
    private void EnsureData() {
      if (x == null || xt == null)
        throw new InvalidOperationException("SetMatrix must be called before covariances can be calculated.");
    }

    // Lazily calculates sx and sz.
    private void EnsureRowNorms() {
      if (sx != null && sz != null) return;
      double[] normsX = CalculateRowNorms(x);
      // in the symmetric case both vectors are identical, so the work is shared
      double[] normsZ = x == xt ? normsX : CalculateRowNorms(xt);
      sx = normsX;
      sz = normsZ;
    }

    // Lazily calculates S. The field is assigned only after the matrix has been
    // filled completely, so a failure never leaves a partially initialized cache.
    private void EnsureVectorProducts() {
      if (S != null) return;
      if (x.GetLength(1) != xt.GetLength(1))
        throw new InvalidOperationException("x and xt must have the same number of columns.");
      int rows = x.GetLength(0);
      int cols = xt.GetLength(0);
      var products = new double[rows, cols];
      if (x == xt) {
        // symmetric: calculate the upper triangle and mirror it
        for (int i = 0; i < rows; i++) {
          for (int j = i; j < rows; j++) {
            products[i, j] = 1 + RowProduct(x, i, xt, j);
            products[j, i] = products[i, j];
          }
        }
      } else {
        // cross covariances: the matrix is rows x cols and generally not square
        for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
            products[i, j] = 1 + RowProduct(x, i, xt, j);
          }
        }
      }
      S = products;
    }

    // Returns 1 + <m_r, m_r> for each row r of m.
    private static double[] CalculateRowNorms(double[,] m) {
      var norms = new double[m.GetLength(0)];
      for (int r = 0; r < norms.Length; r++) {
        norms[r] = 1 + RowProduct(m, r, m, r);
      }
      return norms;
    }

    // Inner product of row rowA of a and row rowB of b; callers guarantee equal column counts.
    private static double RowProduct(double[,] a, int rowA, double[,] b, int rowB) {
      int cols = a.GetLength(1);
      double sum = 0.0;
      for (int c = 0; c < cols; c++) {
        sum += a[rowA, c] * b[rowB, c];
      }
      return sum;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.