Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceFunctions/CovarianceSpectralMixture.cs @ 10480

Last change on this file since 10480 was 10480, checked in by gkronber, 10 years ago

#2124 reintegrated branch for spectral mixture kernel

File size: 9.3 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Linq.Expressions;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Spectral mixture covariance function:
  /// k(x, x') = sum_q w_q * prod_v exp(-2 pi^2 tau_v^2 nu_qv) * cos(2 pi tau_v mu_qv),
  /// with tau = x - x' restricted to the selected input variables.
  /// </summary>
  /// <remarks>
  /// Free hyper-parameters are passed in log-space (they are exponentiated when read) in the order
  /// weights (Q values), frequencies (Q * d values), length scales (Q * d values), where d is the
  /// number of selected input variables. Frequencies and length scales are laid out component-major
  /// (index q * d + v). A group whose parameter value is set on this item is fixed and does not
  /// appear in the parameter vector.
  /// </remarks>
  [StorableClass]
  [Item(Name = "CovarianceSpectralMixture",
    Description = "The spectral mixture kernel described in Wilson A. G. and Adams R.P., Gaussian Process Kernels for Pattern Discovery and Extrapolation, ICML 2013.")]
  public sealed class CovarianceSpectralMixture : ParameterizedNamedItem, ICovarianceFunction {
    public const string QParameterName = "Number of components (Q)";
    public const string WeightParameterName = "Weight";
    public const string FrequencyParameterName = "Component frequency (mu)";
    public const string LengthScaleParameterName = "Length scale (nu)";

    public IValueParameter<IntValue> QParameter {
      get { return (IValueParameter<IntValue>)Parameters[QParameterName]; }
    }

    public IValueParameter<DoubleArray> WeightParameter {
      get { return (IValueParameter<DoubleArray>)Parameters[WeightParameterName]; }
    }

    public IValueParameter<DoubleArray> FrequencyParameter {
      get { return (IValueParameter<DoubleArray>)Parameters[FrequencyParameterName]; }
    }

    public IValueParameter<DoubleArray> LengthScaleParameter {
      get { return (IValueParameter<DoubleArray>)Parameters[LengthScaleParameterName]; }
    }

    [StorableConstructor]
    private CovarianceSpectralMixture(bool deserializing)
      : base(deserializing) {
    }

    private CovarianceSpectralMixture(CovarianceSpectralMixture original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>
    /// Creates the kernel with Q = 10 components; weight, frequency and length scale are left
    /// unset (i.e. free hyper-parameters).
    /// </summary>
    public CovarianceSpectralMixture()
      : base() {
      Name = ItemName;
      Description = ItemDescription;
      Parameters.Add(new ValueParameter<IntValue>(QParameterName, "The number of Gaussians (Q) to use for the spectral mixture.", new IntValue(10)));
      Parameters.Add(new OptionalValueParameter<DoubleArray>(WeightParameterName, "The weight of the component w (peak height of the Gaussian in spectrum)."));
      Parameters.Add(new OptionalValueParameter<DoubleArray>(FrequencyParameterName, "The inverse component period parameter mu_q (location of the Gaussian in spectrum)."));
      Parameters.Add(new OptionalValueParameter<DoubleArray>(LengthScaleParameterName, "The length scale parameter (nu_q) (variance of the Gaussian in the spectrum)."));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new CovarianceSpectralMixture(this, cloner);
    }

    /// <summary>
    /// Number of free hyper-parameters: Q weights plus Q * numberOfVariables frequencies and
    /// Q * numberOfVariables length scales, each group counted only if it is not fixed.
    /// </summary>
    public int GetNumberOfParameters(int numberOfVariables) {
      var q = QParameter.Value.Value;
      return
        (WeightParameter.Value != null ? 0 : q) +
        (FrequencyParameter.Value != null ? 0 : q * numberOfVariables) +
        (LengthScaleParameter.Value != null ? 0 : q * numberOfVariables);
    }

    /// <summary>
    /// Decodes <paramref name="p"/> and stores weight, frequency and length scale on this item,
    /// which makes all three groups fixed afterwards.
    /// </summary>
    public void SetParameter(double[] p) {
      double[] weight, frequency, lengthScale;
      GetParameterValues(p, out weight, out frequency, out lengthScale);
      WeightParameter.Value = new DoubleArray(weight);
      FrequencyParameter.Value = new DoubleArray(frequency);
      LengthScaleParameter.Value = new DoubleArray(lengthScale);
    }

    /// <summary>
    /// Splits the log-space parameter vector into weights, frequencies and length scales.
    /// Fixed groups are copied from this item; free groups are read from <paramref name="p"/>
    /// in the order weight, frequency, length scale and exponentiated.
    /// </summary>
    /// <exception cref="ArgumentNullException">If <paramref name="p"/> is null.</exception>
    /// <exception cref="ArgumentException">If the length of <paramref name="p"/> does not match the number of free parameters.</exception>
    private void GetParameterValues(double[] p, out double[] weight, out double[] frequency, out double[] lengthScale) {
      if (p == null) throw new ArgumentNullException("p");
      int c = 0; // read position in p
      int q = QParameter.Value.Value;

      if (WeightParameter.Value != null) {
        weight = WeightParameter.Value.ToArray();
      } else {
        weight = p.Skip(c).Select(Math.Exp).Take(q).ToArray();
        c += q;
      }

      double[] fixedFrequency = FrequencyParameter.Value != null ? FrequencyParameter.Value.ToArray() : null;
      double[] fixedLengthScale = LengthScaleParameter.Value != null ? LengthScaleParameter.Value.ToArray() : null;

      // number of elements of the frequency and length scale arrays (= q * numberOfVariables):
      // taken from a fixed array when available; only if both are free are the remaining
      // values split equally between them
      int n;
      if (fixedFrequency != null) n = fixedFrequency.Length;
      else if (fixedLengthScale != null) n = fixedLengthScale.Length;
      else n = Math.Max(0, p.Length - c) / 2;

      if (fixedFrequency != null) {
        frequency = fixedFrequency;
      } else {
        frequency = p.Skip(c).Select(Math.Exp).Take(n).ToArray();
        c += n;
      }
      if (fixedLengthScale != null) {
        lengthScale = fixedLengthScale;
      } else {
        lengthScale = p.Skip(c).Select(Math.Exp).Take(n).ToArray();
        c += n;
      }
      if (p.Length != c) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for CovarianceSpectralMixture", "p");
    }

    /// <summary>
    /// Creates covariance, cross-covariance and gradient functions for the given log-space
    /// parameter vector, restricted to the columns in <paramref name="columnIndices"/>.
    /// </summary>
    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, IEnumerable<int> columnIndices) {
      // snapshot the configuration together with the decoded values so that the returned
      // functions stay consistent even if the item's parameters change later
      int q = QParameter.Value.Value;
      bool fixedWeight = WeightParameter.Value != null;
      bool fixedFrequency = FrequencyParameter.Value != null;
      bool fixedLengthScale = LengthScaleParameter.Value != null;

      double[] weight, frequency, lengthScale;
      GetParameterValues(p, out weight, out frequency, out lengthScale);
      // materialize once; the column set is enumerated on every covariance evaluation
      int[] columns = columnIndices.ToArray();

      var cov = new ParameterizedCovarianceFunction();
      cov.Covariance = (x, i, j) => GetCovariance(x, x, i, j, q, weight, frequency, lengthScale, columns);
      cov.CrossCovariance = (x, xt, i, j) => GetCovariance(x, xt, i, j, q, weight, frequency, lengthScale, columns);
      cov.CovarianceGradient = (x, i, j) => GetGradient(x, i, j, q, weight, frequency, lengthScale, columns,
                                                        fixedWeight, fixedFrequency, fixedLengthScale);
      return cov;
    }

    /// <summary>
    /// Evaluates the kernel between row <paramref name="i"/> of <paramref name="x"/> and row
    /// <paramref name="j"/> of <paramref name="xt"/>.
    /// </summary>
    private static double GetCovariance(double[,] x, double[,] xt, int i, int j, int maxQ, double[] weight, double[] frequency, double[] lengthScale, int[] columnIndices) {
      // tau = x - x' (only for selected variables)
      double[] tau =
        Util.GetRow(x, i, columnIndices).Zip(Util.GetRow(xt, j, columnIndices), (xi, xj) => xi - xj).ToArray();
      int numberOfVariables = lengthScale.Length / maxQ;
      double k = 0;
      for (int q = 0; q < maxQ; q++) {
        int offset = q * numberOfVariables;
        double kc = weight[q]; // weighted kernel component
        // parameters are indexed by the position within the selected variables, not by column index
        for (int v = 0; v < tau.Length; v++) {
          kc *= f1(tau[v], lengthScale[offset + v]) * f2(tau[v], frequency[offset + v]);
        }
        k += kc;
      }
      return k;
    }

    /// <summary>Gaussian envelope factor exp(-2 pi^2 tau^2 lengthScale).</summary>
    public static double f1(double tau, double lengthScale) {
      return Math.Exp(-2 * Math.PI * Math.PI * tau * tau * lengthScale);
    }

    /// <summary>Periodic factor cos(2 pi tau frequency).</summary>
    public static double f2(double tau, double frequency) {
      return Math.Cos(2 * Math.PI * tau * frequency);
    }

    // order of returned gradients must match the order in GetParameterValues!
    /// <summary>
    /// Gradient of k(x_i, x_j) with respect to the free log-space hyper-parameters: weights,
    /// then frequencies, then length scales; fixed groups are omitted.
    /// </summary>
    private static IEnumerable<double> GetGradient(double[,] x, int i, int j, int maxQ, double[] weight, double[] frequency, double[] lengthScale, int[] columnIndices,
      bool fixedWeight, bool fixedFrequency, bool fixedLengthScale) {
      double[] tau = Util.GetRow(x, i, columnIndices).Zip(Util.GetRow(x, j, columnIndices), (xi, xj) => xi - xj).ToArray();
      int numberOfVariables = lengthScale.Length / maxQ;
      int d = tau.Length;

      // factors[q, v] = f1 * f2 of variable v in component q
      // others[q, v]  = weight[q] * product of factors[q, v'] for all v' != v
      //   (built from prefix/suffix products, which avoids dividing by a possibly zero cosine)
      // component[q]  = weight[q] * product of all factors[q, v]
      var factors = new double[maxQ, d];
      var others = new double[maxQ, d];
      var component = new double[maxQ];
      for (int q = 0; q < maxQ; q++) {
        int offset = q * numberOfVariables;
        for (int v = 0; v < d; v++) {
          factors[q, v] = f1(tau[v], lengthScale[offset + v]) * f2(tau[v], frequency[offset + v]);
        }
        double prefix = weight[q];
        for (int v = 0; v < d; v++) {
          others[q, v] = prefix;
          prefix *= factors[q, v];
        }
        component[q] = prefix;
        double suffix = 1.0;
        for (int v = d - 1; v >= 0; v--) {
          others[q, v] *= suffix;
          suffix *= factors[q, v];
        }
      }

      var gradient = new List<double>();

      if (!fixedWeight) {
        // d k / d log(w_q) = w_q * prod_v factors[q, v]
        for (int q = 0; q < maxQ; q++) {
          gradient.Add(component[q]);
        }
      }

      if (!fixedFrequency) {
        // d/d log(mu) cos(2 pi tau mu) = -2 pi tau mu * sin(2 pi tau mu)
        for (int q = 0; q < maxQ; q++) {
          int offset = q * numberOfVariables;
          for (int v = 0; v < d; v++) {
            double arg = 2 * Math.PI * tau[v] * frequency[offset + v];
            double dFactor = f1(tau[v], lengthScale[offset + v]) * -arg * Math.Sin(arg);
            gradient.Add(others[q, v] * dFactor);
          }
        }
      }

      if (!fixedLengthScale) {
        // d/d log(nu) exp(-2 pi^2 tau^2 nu) = -2 pi^2 tau^2 nu * exp(-2 pi^2 tau^2 nu)
        for (int q = 0; q < maxQ; q++) {
          int offset = q * numberOfVariables;
          for (int v = 0; v < d; v++) {
            double dFactor = -2 * Math.PI * Math.PI * tau[v] * tau[v] * lengthScale[offset + v] * factors[q, v];
            gradient.Add(others[q, v] * dFactor);
          }
        }
      }

      return gradient;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.