Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceLinear.cs @ 8471

Last change on this file since 8471 was 8455, checked in by gkronber, 12 years ago

#1902: changed the calculation of gradients for covariance functions to reduce array allocations

File size: 3.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Common;
24using HeuristicLab.Core;
25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
26
27namespace HeuristicLab.Algorithms.DataAnalysis {
28  [StorableClass]
29  [Item(Name = "CovarianceLinear", Description = "Linear covariance function for Gaussian processes.")]
30  public class CovarianceLinear : Item, ICovarianceFunction {
31    [Storable]
32    private double[,] x;
33    [Storable]
34    private double[,] xt;
35
36    private double[,] k;
37    private bool symmetric;
38
39    public int GetNumberOfParameters(int numberOfVariables) {
40      return 0;
41    }
42    [StorableConstructor]
43    protected CovarianceLinear(bool deserializing) : base(deserializing) { }
44    protected CovarianceLinear(CovarianceLinear original, Cloner cloner)
45      : base(original, cloner) {
46      if (original.x != null) {
47        this.x = new double[original.x.GetLength(0), original.x.GetLength(1)];
48        Array.Copy(original.x, this.x, x.Length);
49
50        this.xt = new double[original.xt.GetLength(0), original.xt.GetLength(1)];
51        Array.Copy(original.xt, this.xt, xt.Length);
52
53        this.k = new double[original.k.GetLength(0), original.k.GetLength(1)];
54        Array.Copy(original.k, this.k, k.Length);
55      }
56      this.symmetric = original.symmetric;
57    }
58    public CovarianceLinear()
59      : base() {
60    }
61
62    public override IDeepCloneable Clone(Cloner cloner) {
63      return new CovarianceLinear(this, cloner);
64    }
65
66    public void SetParameter(double[] hyp) {
67      if (hyp.Length > 0) throw new ArgumentException("No hyperparameters are allowed for the linear covariance function.");
68      k = null;
69    }
70
71    public void SetData(double[,] x) {
72      SetData(x, x);
73      this.symmetric = true;
74    }
75
76    public void SetData(double[,] x, double[,] xt) {
77      this.x = x;
78      this.xt = xt;
79      this.symmetric = false;
80
81      k = null;
82    }
83
84    public double GetCovariance(int i, int j) {
85      if (k == null) CalculateInnerProduct();
86      return k[i, j];
87    }
88
89    public double GetGradient(int i, int j, int k) {
90      throw new NotSupportedException("CovarianceLinear does not have hyperparameters.");
91    }
92
93
94    private void CalculateInnerProduct() {
95      if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
96      int rows = x.GetLength(0);
97      int cols = xt.GetLength(0);
98      k = new double[rows, cols];
99      if (symmetric) {
100        for (int i = 0; i < rows; i++) {
101          for (int j = i; j < cols; j++) {
102            k[i, j] = Util.ScalarProd(Util.GetRow(x, i),
103                                      Util.GetRow(x, j));
104            k[j, i] = k[i, j];
105          }
106        }
107      } else {
108        for (int i = 0; i < rows; i++) {
109          for (int j = 0; j < cols; j++) {
110            k[i, j] = Util.ScalarProd(Util.GetRow(x, i),
111                                      Util.GetRow(xt, j));
112          }
113        }
114      }
115    }
116  }
117}
Note: See TracBrowser for help on using the repository browser.