Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/06/12 16:16:28 (12 years ago)
Author:
gkronber
Message:

#1902 added periodic covariance function

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceLinear.cs

    r8416 r8417  
    2727namespace HeuristicLab.Algorithms.DataAnalysis {
    2828  [StorableClass]
    29   [Item(Name = "CovarianceLinear", Description = "Linear covariance function with for Gaussian processes.")]
     29  [Item(Name = "CovarianceLinear", Description = "Linear covariance function for Gaussian processes.")]
    3030  public class CovarianceLinear : Item, ICovarianceFunction {
    3131    private static readonly double[] emptyArray = new double[0];
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovariancePeriodic.cs

    r8323 r8417  
    1 using System;
    2 using System.Collections.Generic;
    3 using System.Linq;
     1#region License Information
     2/* HeuristicLab
     3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     4 *
     5 * This file is part of HeuristicLab.
     6 *
     7 * HeuristicLab is free software: you can redistribute it and/or modify
     8 * it under the terms of the GNU General Public License as published by
     9 * the Free Software Foundation, either version 3 of the License, or
     10 * (at your option) any later version.
     11 *
     12 * HeuristicLab is distributed in the hope that it will be useful,
     13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
     14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     15 * GNU General Public License for more details.
     16 *
     17 * You should have received a copy of the GNU General Public License
     18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
     19 */
     20#endregion
    421
    5 namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
    6   public class CovariancePeriodic : ICovarianceFunction {
     22using System;
     23using HeuristicLab.Common;
     24using HeuristicLab.Core;
     25using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     26
     27namespace HeuristicLab.Algorithms.DataAnalysis {
     28  [StorableClass]
     29  [Item(Name = "CovariancePeriodic", Description = "Periodic covariance function for Gaussian processes.")]
     30  public class CovariancePeriodic : Item, ICovarianceFunction {
     31    [Storable]
    732    private double[,] x;
     33    [Storable]
    834    private double[,] xt;
     35    [Storable]
    936    private double sf2;
     37    [Storable]
    1038    private double l;
    11     private double[,] sd;
     39    [Storable]
    1240    private double p;
    1341
    14     public int NumberOfParameters {
    15       get { return 2; }
     42    private bool symmetric;
     43
     44    private double[,] sd;
     45    public int GetNumberOfParameters(int numberOfVariables) {
     46      return 3;
     47    }
     48    [StorableConstructor]
     49    protected CovariancePeriodic(bool deserializing) : base(deserializing) { }
     50    protected CovariancePeriodic(CovariancePeriodic original, Cloner cloner)
     51      : base(original, cloner) {
     52      if (original.x != null) {
     53        x = new double[original.x.GetLength(0), original.x.GetLength(1)];
     54        Array.Copy(original.x, x, x.Length);
     55        xt = new double[original.xt.GetLength(0), original.xt.GetLength(1)];
     56        Array.Copy(original.xt, xt, xt.Length);
     57      }
     58      sf2 = original.sf2;
     59      l = original.l;
     60      p = original.p;
     61      symmetric = original.symmetric;
     62    }
     63    public CovariancePeriodic()
     64      : base() {
    1665    }
    1766
    18     public CovariancePeriodic(double p) {
    19       this.p = p;
     67    public override IDeepCloneable Clone(Cloner cloner) {
     68      return new CovariancePeriodic(this, cloner);
    2069    }
    2170
    22     public void SetMatrix(double[,] x) {
    23       SetMatrix(x, x);
     71    public void SetParameter(double[] hyp) {
     72      if (hyp.Length != 3) throw new ArgumentException();
     73      this.l = Math.Exp(hyp[0]);
     74      this.p = Math.Exp(hyp[1]);
     75      this.sf2 = Math.Exp(2 * hyp[2]);
     76
     77      sf2 = Math.Min(10E6, sf2); // upper limit for the scale
     78
     79      sd = null;
     80    }
     81    public void SetData(double[,] x) {
     82      SetData(x, x);
     83      this.symmetric = true;
    2484    }
    2585
    26     public void SetMatrix(double[,] x, double[,] xt) {
     86    public void SetData(double[,] x, double[,] xt) {
    2787      this.x = x;
    2888      this.xt = xt;
    29       sd = null;
    30     }
     89      this.symmetric = false;
    3190
    32     public void SetHyperparamter(double[] hyp) {
    33       if (hyp.Length != 2) throw new ArgumentException();
    34       this.l = Math.Exp(hyp[0]);
    35       this.sf2 = Math.Exp(2 * hyp[1]);
    3691      sd = null;
    3792    }
     
    53108      var cov = new double[rows];
    54109      for (int i = 0; i < rows; i++) {
    55         double k = Math.Sqrt(SqrDist(GetRow(x, i), GetRow(xt, i)));
     110        double k = Math.Sqrt(Util.SqrDist(Util.GetRow(x, i), Util.GetRow(xt, i)));
    56111        k = Math.PI * k / p;
    57112        k = Math.Sin(k) / l;
     
    62117    }
    63118
    64     public double[] GetDerivatives(int i, int j) {
     119    public double[] GetGradient(int i, int j) {
    65120
    66       var res = new double[2];
     121      var res = new double[3];
    67122      double k = sd[i, j];
    68123      k = Math.PI * k / p;
    69       k = Math.Sin(k) / l;
    70       k = k * k;
    71       res[0] = 4 * sf2 * Math.Exp(-2 * k) * k;
    72       res[1] = 2 * sf2 * Math.Exp(-2 * k);
     124      {
     125        double newK = Math.Sin(k) / l;
     126        newK = newK * newK;
     127        res[0] = 4 * sf2 * Math.Exp(-2 * newK) * newK;
     128      }
     129      {
     130        double r = Math.Sin(k) / l;
     131        res[1] = 4 * sf2 / l * Math.Exp(-2 * r * r) * r * Math.Cos(k) * k;
     132      }
     133      {
     134        double newK = Math.Sin(k) / l;
     135        newK = newK * newK;
     136        res[2] = 2 * sf2 * Math.Exp(-2 * newK);
     137      }
     138
    73139      return res;
    74140    }
     
    79145      int cols = xt.GetLength(0);
    80146      sd = new double[rows, cols];
    81       bool symmetric = x == xt;
    82       for (int i = 0; i < rows; i++) {
    83         for (int j = i; j < rows; j++) {
    84           sd[i, j] = Math.Sqrt(SqrDist(GetRow(x, i), GetRow(xt, j)));
    85           if (symmetric) {
     147
     148      if (symmetric) {
     149        for (int i = 0; i < rows; i++) {
     150          for (int j = i; j < cols; j++) {
     151            sd[i, j] = Math.Sqrt(Util.SqrDist(Util.GetRow(x, i), Util.GetRow(x, j)));
    86152            sd[j, i] = sd[i, j];
    87           } else {
    88             sd[j, i] = Math.Sqrt(SqrDist(GetRow(x, j), GetRow(xt, i)));
     153          }
     154        }
     155      } else {
     156        for (int i = 0; i < rows; i++) {
     157          for (int j = 0; j < cols; j++) {
     158            sd[i, j] = Math.Sqrt(Util.SqrDist(Util.GetRow(x, i), Util.GetRow(xt, j)));
    89159          }
    90160        }
    91161      }
    92162    }
    93 
    94 
    95     private double SqrDist(IEnumerable<double> x, IEnumerable<double> y) {
    96       var d0 = x.Zip(y, (a, b) => (a - b) * (a - b));
    97       return Math.Max(0, d0.Sum());
    98     }
    99     private static IEnumerable<double> GetRow(double[,] x, int r) {
    100       int cols = x.GetLength(1);
    101       return Enumerable.Range(0, cols).Select(c => x[r, c]);
    102     }
    103163  }
    104164}
Note: See TracChangeset for help on using the changeset viewer.