source: branches/HeuristicLab.TimeSeries/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceSEiso.cs @ 8477

Last change on this file since 8477 was 8477, checked in by mkommend, 9 years ago

#1081:

  • Added autoregressive target variable Symbol
  • Merged trunk changes into the branch.
File size: 4.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
27
28namespace HeuristicLab.Algorithms.DataAnalysis {
29  [StorableClass]
30  [Item(Name = "CovarianceSEiso",
31    Description = "Isotropic squared exponential covariance function for Gaussian processes.")]
32  public class CovarianceSEiso : Item, ICovarianceFunction {
33    [Storable]
34    private double[,] x;
35    [Storable]
36    private double[,] xt;
37    [Storable]
38    private double sf2;
39    public double Scale { get { return sf2; } }
40    [Storable]
41    private double l;
42    public double Length { get { return l; } }
43    [Storable]
44    private bool symmetric;
45    private double[,] sd;
46
47    [StorableConstructor]
48    protected CovarianceSEiso(bool deserializing)
49      : base(deserializing) {
50    }
51
52    protected CovarianceSEiso(CovarianceSEiso original, Cloner cloner)
53      : base(original, cloner) {
54      if (original.x != null) {
55        this.x = new double[original.x.GetLength(0), original.x.GetLength(1)];
56        Array.Copy(original.x, this.x, x.Length);
57
58        this.xt = new double[original.xt.GetLength(0), original.xt.GetLength(1)];
59        Array.Copy(original.xt, this.xt, xt.Length);
60
61        this.sd = new double[original.sd.GetLength(0), original.sd.GetLength(1)];
62        Array.Copy(original.sd, this.sd, sd.Length);
63        this.sf2 = original.sf2;
64      }
65      this.sf2 = original.sf2;
66      this.l = original.l;
67      this.symmetric = original.symmetric;
68    }
69
70    public CovarianceSEiso()
71      : base() {
72    }
73
74    public override IDeepCloneable Clone(Cloner cloner) {
75      return new CovarianceSEiso(this, cloner);
76    }
77
78    public int GetNumberOfParameters(int numberOfVariables) {
79      return 2;
80    }
81
82    public void SetParameter(double[] hyp) {
83      this.l = Math.Exp(hyp[0]);
84      this.sf2 = Math.Exp(2 * hyp[1]);
85      sd = null;
86    }
87    public void SetData(double[,] x) {
88      SetData(x, x);
89      this.symmetric = true;
90    }
91
92
93    public void SetData(double[,] x, double[,] xt) {
94      this.symmetric = false;
95      this.x = x;
96      this.xt = xt;
97      sd = null;
98    }
99
100    public double GetCovariance(int i, int j) {
101      if (sd == null) CalculateSquaredDistances();
102      return sf2 * Math.Exp(-sd[i, j] / 2.0);
103    }
104
105    public double GetGradient(int i, int j, int k) {
106      switch (k) {
107        case 0: return sf2 * Math.Exp(-sd[i, j] / 2.0) * sd[i, j];
108        case 1: return 2.0 * sf2 * Math.Exp(-sd[i, j] / 2.0);
109        default: throw new ArgumentException("CovarianceSEiso has two hyperparameters", "k");
110      }
111    }
112
113    private void CalculateSquaredDistances() {
114      if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
115      int rows = x.GetLength(0);
116      int cols = xt.GetLength(0);
117      sd = new double[rows, cols];
118      double lInv = 1.0 / l;
119      if (symmetric) {
120        for (int i = 0; i < rows; i++) {
121          for (int j = i; j < rows; j++) {
122            sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv));
123            sd[j, i] = sd[i, j];
124          }
125        }
126      } else {
127        for (int i = 0; i < rows; i++) {
128          for (int j = 0; j < cols; j++) {
129            sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv));
130          }
131        }
132      }
133    }
134  }
135}
Note: See TracBrowser for help on using the repository browser.