1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 |
|
---|
namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
  /// <summary>
  /// Neural-network covariance function with a single (isotropic) length-scale,
  /// following the parameterization of <c>covNNone</c> from the GPML toolbox:
  /// <code>
  /// k(x, z) = sf2 * asin(S(x, z) / (sqrt(l2 + S(x, x)) * sqrt(l2 + S(z, z))))
  /// S(a, b) = 1 + a·b,   l2 = exp(2 * hyp[0]),   sf2 = exp(2 * hyp[1])
  /// </code>
  /// Rows of the input matrices are samples, columns are features.
  /// </summary>
  public class CovarianceNNOne : ICovarianceFunction {
    private double[,] x;
    private double[,] xt;
    private double sf2;  // signal variance, exp(2 * hyp[1])
    private double l2;   // squared length-scale, exp(2 * hyp[0])

    // Input-only caches. None of them depends on the hyperparameters, so they are
    // invalidated by SetMatrix but survive SetHyperparamter.
    private double[,] S;  // S[i, j] = 1 + x_i·xt_j
    private double[] sx;  // sx[i] = 1 + x_i·x_i
    private double[] sz;  // sz[j] = 1 + xt_j·xt_j (same array as sx when x == xt)

    /// <summary>Number of hyperparameters: log length-scale and log signal standard deviation.</summary>
    public int NumberOfParameters {
      get { return 2; }
    }

    /// <summary>Sets the inputs for the symmetric case (covariances between rows of <paramref name="x"/>).</summary>
    public void SetMatrix(double[,] x) {
      SetMatrix(x, x);
    }

    /// <summary>
    /// Sets the inputs for covariances between the rows of <paramref name="x"/> and the rows of
    /// <paramref name="xt"/>. Passing the same array instance twice selects the symmetric case.
    /// </summary>
    public void SetMatrix(double[,] x, double[,] xt) {
      this.x = x;
      this.xt = xt;
      // Cached products belong to the previous inputs.
      S = null;
      sx = null;
      sz = null;
    }

    /// <summary>
    /// Sets the hyperparameters: <c>hyp[0]</c> is the log length-scale and
    /// <c>hyp[1]</c> the log signal standard deviation.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="hyp"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="hyp"/> does not have exactly two entries.</exception>
    public void SetHyperparamter(double[] hyp) {
      if (hyp == null) throw new ArgumentNullException("hyp");
      if (hyp.Length != 2) throw new ArgumentException("Expected exactly two hyperparameters.", "hyp");
      this.l2 = Math.Exp(2 * hyp[0]);
      this.sf2 = Math.Exp(2 * hyp[1]);
      // S, sx and sz do not depend on l2 or sf2 (the length-scale normalization is
      // applied per entry on demand), so the caches stay valid.
    }

    /// <summary>Returns k(x_i, xt_j).</summary>
    /// <exception cref="InvalidOperationException">
    /// No inputs were set, or the two input matrices have different column counts.
    /// </exception>
    public double GetCovariance(int i, int j) {
      EnsureCrossProducts();
      EnsureSquaredNorms();
      return sf2 * Math.Asin(NormalizedProduct(i, j));
    }

    /// <summary>Returns k(x_i, x_i) for every row i. Only valid in the symmetric case.</summary>
    /// <exception cref="InvalidOperationException">The covariance is not symmetric, or no inputs were set.</exception>
    public double[] GetDiagonalCovariances() {
      if (x != xt) throw new InvalidOperationException();
      EnsureSquaredNorms();
      int rows = x.GetLength(0);
      var k = new double[rows];
      for (int i = 0; i < rows; i++) {
        // On the diagonal S[i, i] == sx[i], so the normalized product reduces to sx / (sx + l2).
        k[i] = sf2 * Math.Asin(sx[i] / (sx[i] + l2));
      }
      return k;
    }

    /// <summary>
    /// Returns the partial derivatives of k(x_i, xt_j) with respect to hyp[0] (element 0)
    /// and hyp[1] (element 1).
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// No inputs were set, or the two input matrices have different column counts.
    /// </exception>
    public double[] GetDerivatives(int i, int j) {
      EnsureCrossProducts();
      EnsureSquaredNorms();
      double k = NormalizedProduct(i, j);
      // Only the two entries needed for (i, j) are evaluated instead of building
      // full per-row arrays on every call.
      double v = (sx[i] / (l2 + sx[i]) + sz[j] / (l2 + sz[j])) / 2;
      double[] dhyp = new double[NumberOfParameters];
      dhyp[0] = -2 * sf2 * (k - k * v) / Math.Sqrt(1 - k * k);
      dhyp[1] = 2.0 * sf2 * Math.Asin(k);
      return dhyp;
    }

    // Argument of asin for entry (i, j): S[i, j] / (sqrt(l2 + sx[i]) * sqrt(l2 + sz[j])).
    // The denominator is per entry; a single scalar over all rows would be incorrect.
    private double NormalizedProduct(int i, int j) {
      return S[i, j] / (Math.Sqrt(l2 + sx[i]) * Math.Sqrt(l2 + sz[j]));
    }

    // Lazily computes sx and sz (1 + squared row norms) for the current inputs.
    private void EnsureSquaredNorms() {
      if (sx != null) return;
      EnsureInputs();
      sx = OnePlusSquaredRowNorms(x);
      sz = x == xt ? sx : OnePlusSquaredRowNorms(xt);
    }

    // Lazily computes S[i, j] = 1 + x_i·xt_j, computing only the upper triangle when x and xt
    // are the same array.
    private void EnsureCrossProducts() {
      if (S != null) return;
      EnsureInputs();
      if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
      int rows = x.GetLength(0);
      int cols = xt.GetLength(0);
      bool symmetric = x == xt;
      var products = new double[rows, cols];
      for (int i = 0; i < rows; i++) {
        // The non-symmetric case must visit every column, and cols may differ from rows.
        for (int j = symmetric ? i : 0; j < cols; j++) {
          double s = 1 + RowDot(x, i, xt, j);
          products[i, j] = s;
          if (symmetric) products[j, i] = s;
        }
      }
      S = products;
    }

    // Fails with a descriptive error instead of a NullReferenceException when SetMatrix was not called.
    private void EnsureInputs() {
      if (x == null || xt == null)
        throw new InvalidOperationException("SetMatrix must be called before covariances can be evaluated.");
    }

    // Returns 1 + (row r · row r) for every row r of m.
    private static double[] OnePlusSquaredRowNorms(double[,] m) {
      int rows = m.GetLength(0);
      var result = new double[rows];
      for (int r = 0; r < rows; r++) {
        result[r] = 1 + RowDot(m, r, m, r);
      }
      return result;
    }

    // Dot product of row i of a with row j of b; assumes both have a.GetLength(1) columns.
    private static double RowDot(double[,] a, int i, double[,] b, int j) {
      int dim = a.GetLength(1);
      double sum = 0.0;
      for (int c = 0; c < dim; c++) {
        sum += a[i, c] * b[j, c];
      }
      return sum;
    }
  }
}
|
---|