Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovariancePeriodic.cs @ 8463

Last change on this file since 8463 was 8463, checked in by gkronber, 12 years ago

#1902 improved GPR implementation

File size: 4.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Periodic covariance function for Gaussian processes:
  /// k(a, b) = sf2 * exp(-2 * sin^2(pi * d(a, b) / p) / l^2),
  /// where d(a, b) is the Euclidean distance between an input row of <c>x</c> and one of <c>xt</c>.
  /// Hyperparameters are given in log space: hyp = [log l, log p, log sf] with sf2 = sf^2.
  /// </summary>
  [StorableClass]
  [Item(Name = "CovariancePeriodic", Description = "Periodic covariance function for Gaussian processes.")]
  public class CovariancePeriodic : Item, ICovarianceFunction {
    // Upper limit applied to sf2 in SetParameter. Note: 10E6 == 1.0e7.
    private const double MaxSignalVariance = 10E6;

    [Storable]
    private double[,] x;   // first input matrix; each row is one input point
    [Storable]
    private double[,] xt;  // second input matrix; the same reference as x after SetData(x)
    [Storable]
    private double sf2;    // scale factor sf^2 (clamped to MaxSignalVariance)
    [Storable]
    private double l;      // length scale
    [Storable]
    private double p;      // period

    // True when x and xt are the same data set, so only one triangle of the distance matrix
    // has to be computed. Not persisted: after deserialization it is false and the full matrix
    // is computed instead, which yields the same distances.
    private bool symmetric;

    // Lazily computed Euclidean distances between rows of x (first index) and rows of xt
    // (second index). Depends only on the data; null until first needed after SetData.
    private double[,] sd;

    /// <summary>Returns the number of hyperparameters, which is always 3 (log l, log p, log sf).</summary>
    public int GetNumberOfParameters(int numberOfVariables) {
      return 3;
    }

    [StorableConstructor]
    protected CovariancePeriodic(bool deserializing) : base(deserializing) { }

    /// <summary>Deep-copy constructor used by <see cref="Clone"/>.</summary>
    protected CovariancePeriodic(CovariancePeriodic original, Cloner cloner)
      : base(original, cloner) {
      if (original.x != null) x = (double[,])original.x.Clone();
      if (original.xt != null) {
        // Keep x and xt aliased when they are aliased in the original (symmetric case).
        xt = ReferenceEquals(original.xt, original.x) ? x : (double[,])original.xt.Clone();
      }
      sf2 = original.sf2;
      l = original.l;
      p = original.p;
      symmetric = original.symmetric;
    }

    public CovariancePeriodic()
      : base() {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new CovariancePeriodic(this, cloner);
    }

    /// <summary>
    /// Sets the hyperparameters.
    /// </summary>
    /// <param name="hyp">[log l, log p, log sf]; sf2 is computed as exp(2 * hyp[2]) and clamped to 1e7.</param>
    /// <exception cref="ArgumentNullException"><paramref name="hyp"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="hyp"/> does not have exactly three elements.</exception>
    public void SetParameter(double[] hyp) {
      if (hyp == null) throw new ArgumentNullException("hyp");
      if (hyp.Length != 3) throw new ArgumentException("CovariancePeriodic only has three hyperparameters.", "hyp");
      l = Math.Exp(hyp[0]);
      p = Math.Exp(hyp[1]);
      sf2 = Math.Min(MaxSignalVariance, Math.Exp(2 * hyp[2])); // upper limit for the scale
      // The cached distances depend only on the data, not on the hyperparameters, so they are kept.
    }

    /// <summary>Uses the same matrix for both sides, i.e. the covariance of x with itself.</summary>
    public void SetData(double[,] x) {
      SetData(x, x);
      symmetric = true;
    }

    /// <summary>Sets the two input matrices; invalidates the cached distances.</summary>
    public void SetData(double[,] x, double[,] xt) {
      this.x = x;
      this.xt = xt;
      symmetric = false;
      sd = null;
    }

    /// <summary>Returns the covariance between row <paramref name="i"/> of x and row <paramref name="j"/> of xt.</summary>
    /// <exception cref="InvalidOperationException">SetData has not been called, or x and xt have different column counts.</exception>
    public double GetCovariance(int i, int j) {
      double r = Math.Sin(Angle(i, j)) / l;
      return sf2 * Math.Exp(-2.0 * (r * r));
    }

    /// <summary>
    /// Returns the derivative of the covariance between row <paramref name="i"/> of x and row
    /// <paramref name="j"/> of xt with respect to the log-space hyperparameter
    /// <paramref name="k"/> (0: log l, 1: log p, 2: log sf).
    /// </summary>
    /// <exception cref="ArgumentException"><paramref name="k"/> is not 0, 1 or 2.</exception>
    /// <exception cref="InvalidOperationException">SetData has not been called, or x and xt have different column counts.</exception>
    public double GetGradient(int i, int j, int k) {
      if (k < 0 || k > 2) throw new ArgumentException("CovariancePeriodic only has three hyperparameters.", "k");

      // Previously the distances were read here without the lazy initialization done in
      // GetCovariance, which crashed when GetGradient was called first.
      double v = Angle(i, j);
      double r = Math.Sin(v) / l;
      double cov = sf2 * Math.Exp(-2.0 * (r * r));

      switch (k) {
        case 0:
          // d/d(log l): r^2 scales with l^-2.
          return 4.0 * cov * (r * r);
        case 1:
          // d/d(log p): v scales with p^-1, so dv/d(log p) = -v.
          return 4.0 * cov / l * r * Math.Cos(v) * v;
        default:
          // k == 2, d/d(log sf): sf2 = exp(2 * log sf).
          return 2.0 * cov;
      }
    }

    // Returns v = pi * d(i, j) / p, the argument of the sine in the covariance formula,
    // computing the distance matrix on first use.
    private double Angle(int i, int j) {
      if (sd == null) CalculateDistances();
      return Math.PI * sd[i, j] / p;
    }

    // Fills sd with the Euclidean distances between all rows of x and all rows of xt.
    // The cache is only published once the whole matrix has been computed.
    private void CalculateDistances() {
      if (x == null || xt == null)
        throw new InvalidOperationException("SetData must be called before covariances or gradients can be computed.");
      if (x.GetLength(1) != xt.GetLength(1))
        throw new InvalidOperationException("x and xt must have the same number of columns.");

      int rows = x.GetLength(0);
      int cols = xt.GetLength(0);
      var distances = new double[rows, cols];

      if (symmetric) {
        // x and xt are the same data: compute one triangle and mirror it.
        for (int i = 0; i < rows; i++) {
          for (int j = i; j < cols; j++) {
            distances[i, j] = Math.Sqrt(Util.SqrDist(Util.GetRow(x, i), Util.GetRow(x, j)));
            distances[j, i] = distances[i, j];
          }
        }
      } else {
        for (int i = 0; i < rows; i++) {
          for (int j = 0; j < cols; j++) {
            distances[i, j] = Math.Sqrt(Util.SqrDist(Util.GetRow(x, i), Util.GetRow(xt, j)));
          }
        }
      }

      sd = distances;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.