Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovariancePeriodic.cs @ 8368

Last change on this file since 8368 was 8323, checked in by gkronber, 12 years ago

#1902 initial import of Gaussian process regression algorithm

File size: 2.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4
namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
  /// <summary>
  /// Periodic covariance function
  ///   k(a, b) = sf2 * exp(-2 * (sin(pi * |a - b| / p) / l)^2)
  /// where |a - b| is the Euclidean distance between two input rows,
  /// p is a fixed period supplied at construction, and the two
  /// hyperparameters are hyp = [log(l), log(sf)] (sf2 = sf^2).
  /// </summary>
  public class CovariancePeriodic : ICovarianceFunction {
    // Input matrices; each row is one input point, columns are its coordinates.
    private double[,] x;
    private double[,] xt;
    // Signal variance: exp(2 * hyp[1]).
    private double sf2;
    // Length scale: exp(hyp[0]).
    private double l;
    // Cached Euclidean (not squared) distances sd[i, j] = |x_i - xt_j|.
    // They depend only on x and xt, so the cache is reset only by SetMatrix.
    private double[,] sd;
    // Fixed period; it is not one of the hyperparameters.
    private readonly double p;

    /// <summary>Number of hyperparameters expected by SetHyperparamter: log(l) and log(sf).</summary>
    public int NumberOfParameters {
      get { return 2; }
    }

    /// <summary>Creates a periodic covariance function with the fixed period <paramref name="p"/>.</summary>
    /// <param name="p">The period. It is not validated; p == 0 makes every covariance NaN.</param>
    public CovariancePeriodic(double p) {
      this.p = p;
    }

    /// <summary>Uses the same matrix for both arguments (symmetric covariance matrix).</summary>
    public void SetMatrix(double[,] x) {
      SetMatrix(x, x);
    }

    /// <summary>
    /// Sets the two input matrices and discards cached distances. Call this again
    /// if the contents of either array are modified in place.
    /// </summary>
    public void SetMatrix(double[,] x, double[,] xt) {
      this.x = x;
      this.xt = xt;
      sd = null;
    }

    /// <summary>Sets the hyperparameters from their log-space values.</summary>
    /// <param name="hyp">hyp[0] = log(l), hyp[1] = log(sf).</param>
    /// <exception cref="ArgumentNullException"><paramref name="hyp"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="hyp"/> does not have exactly two elements.</exception>
    public void SetHyperparamter(double[] hyp) {
      if (hyp == null) throw new ArgumentNullException("hyp");
      if (hyp.Length != NumberOfParameters)
        throw new ArgumentException("CovariancePeriodic expects exactly " + NumberOfParameters + " hyperparameters.", "hyp");
      this.l = Math.Exp(hyp[0]);
      this.sf2 = Math.Exp(2 * hyp[1]);
      // The distance cache is intentionally kept: distances do not depend on l or sf2.
    }

    /// <summary>Returns k(x_i, xt_j).</summary>
    public double GetCovariance(int i, int j) {
      if (sd == null) CalculateSquaredDistances();
      return sf2 * Math.Exp(-2.0 * ScaledSquaredSine(sd[i, j]));
    }

    /// <summary>
    /// Returns k(x_i, x_i) for every row i. This is valid only when both matrices
    /// are the same instance (see SetMatrix(double[,])).
    /// </summary>
    /// <exception cref="InvalidOperationException">The matrices differ or were never set.</exception>
    public double[] GetDiagonalCovariances() {
      if (x != xt) throw new InvalidOperationException();
      if (x == null) throw new InvalidOperationException("SetMatrix must be called before covariances can be computed.");
      int rows = x.GetLength(0);
      var cov = new double[rows];
      for (int i = 0; i < rows; i++) {
        cov[i] = sf2 * Math.Exp(-2.0 * ScaledSquaredSine(Distance(x, i, xt, i)));
      }
      return cov;
    }

    /// <summary>
    /// Returns the partial derivatives of k(x_i, xt_j) with respect to the
    /// log-space hyperparameters: [d/d log(l), d/d log(sf)].
    /// </summary>
    public double[] GetDerivatives(int i, int j) {
      // Previously this read the cache without filling it, so calling it
      // before GetCovariance threw a NullReferenceException.
      if (sd == null) CalculateSquaredDistances();
      double k = ScaledSquaredSine(sd[i, j]);
      double cov = sf2 * Math.Exp(-2.0 * k);
      var res = new double[2];
      res[0] = 4.0 * cov * k;
      res[1] = 2.0 * cov;
      return res;
    }

    // Returns (sin(pi * r / p) / l)^2 for a distance r; shared by every covariance formula above.
    private double ScaledSquaredSine(double r) {
      double s = Math.Sin(Math.PI * r / p) / l;
      return s * s;
    }

    // Fills the distance cache sd (rows of x by rows of xt) with Euclidean distances.
    private void CalculateSquaredDistances() {
      if (x == null || xt == null) throw new InvalidOperationException("SetMatrix must be called before covariances can be computed.");
      if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
      int rows = x.GetLength(0);
      int cols = xt.GetLength(0);
      var dist = new double[rows, cols];
      if (x == xt) {
        // Symmetric case: compute the upper triangle (including the diagonal) and mirror it.
        for (int i = 0; i < rows; i++) {
          for (int j = i; j < rows; j++) {
            double d = Distance(x, i, xt, j);
            dist[i, j] = d;
            dist[j, i] = d;
          }
        }
      } else {
        // Cross-covariance: x and xt may have different row counts, so every pair is computed.
        for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
            dist[i, j] = Distance(x, i, xt, j);
          }
        }
      }
      // Publish only after the matrix is complete, so a failure leaves the cache empty.
      sd = dist;
    }

    // Euclidean distance between row i of a and row j of b, over their shared columns.
    private static double Distance(double[,] a, int i, double[,] b, int j) {
      int cols = Math.Min(a.GetLength(1), b.GetLength(1));
      double sum = 0.0;
      for (int c = 0; c < cols; c++) {
        double diff = a[i, c] - b[j, c];
        sum += diff * diff;
      }
      return Math.Sqrt(sum);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.