Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/14/12 13:25:17 (12 years ago)
Author:
gkronber
Message:

#1902 changed interface for covariance functions to improve readability, fixed several bugs in the covariance functions and in the line chart for Gaussian process models.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceRQiso.cs

    r8473 r8484  
    2121
    2222using System;
     23using System.Collections.Generic;
    2324using System.Linq;
    2425using HeuristicLab.Common;
     
    3233  public class CovarianceRQiso : Item, ICovarianceFunction {
    3334    [Storable]
    34     private double[,] x;
    35     [Storable]
    36     private double[,] xt;
    37     [Storable]
    3835    private double sf2;
    3936    public double Scale { get { return sf2; } }
     
    4441    private double alpha;
    4542    public double Shape { get { return alpha; } }
    46     [Storable]
    47     private bool symmetric;
    48     private double[,] d2;
    4943
    5044    [StorableConstructor]
     
    5549    protected CovarianceRQiso(CovarianceRQiso original, Cloner cloner)
    5650      : base(original, cloner) {
    57       if (original.x != null) {
    58         this.x = new double[original.x.GetLength(0), original.x.GetLength(1)];
    59         Array.Copy(original.x, this.x, x.Length);
    60 
    61         this.xt = new double[original.xt.GetLength(0), original.xt.GetLength(1)];
    62         Array.Copy(original.xt, this.xt, xt.Length);
    63 
    64         this.d2 = new double[original.d2.GetLength(0), original.d2.GetLength(1)];
    65         Array.Copy(original.d2, this.d2, d2.Length);
    66         this.sf2 = original.sf2;
    67       }
    6851      this.sf2 = original.sf2;
    6952      this.l = original.l;
    7053      this.alpha = original.alpha;
    71       this.symmetric = original.symmetric;
    7254    }
    7355
     
    8567
    8668    public void SetParameter(double[] hyp) {
     69      if (hyp.Length != 3) throw new ArgumentException("CovarianceRQiso has three hyperparameters", "k");
    8770      this.l = Math.Exp(hyp[0]);
    8871      this.sf2 = Math.Exp(2 * hyp[1]);
    8972      this.alpha = Math.Exp(hyp[2]);
    90       d2 = null;
    91     }
    92     public void SetData(double[,] x) {
    93       SetData(x, x);
    94       this.symmetric = true;
    9573    }
    9674
    9775
    98     public void SetData(double[,] x, double[,] xt) {
    99       this.symmetric = false;
    100       this.x = x;
    101       this.xt = xt;
    102       d2 = null;
     76    public double GetCovariance(double[,] x, int i, int j) {
     77      double lInv = 1.0 / l;
     78      double d = i == j
     79                   ? 0.0
     80                   : Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(x, j).Select(e => e * lInv));
     81      return sf2 * Math.Pow(1 + 0.5 * d / alpha, -alpha);
    10382    }
    10483
    105     public double GetCovariance(int i, int j) {
    106       if (d2 == null) CalculateSquaredDistances();
    107       return sf2 * Math.Pow(1 + 0.5 * d2[i, j] / alpha, -alpha);
     84    public IEnumerable<double> GetGradient(double[,] x, int i, int j) {
     85      double lInv = 1.0 / l;
     86      double d = i == j
     87                   ? 0.0
     88                   : Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(x, j).Select(e => e * lInv));
     89
     90      double b = 1 + 0.5 * d / alpha;
     91      yield return sf2 * Math.Pow(b, -alpha - 1) * d;
     92      yield return 2 * sf2 * Math.Pow(b, -alpha);
     93      yield return sf2 * Math.Pow(b, -alpha) * (0.5 * d / b - alpha * Math.Log(b));
    10894    }
    10995
    110     public double GetGradient(int i, int j, int k) {
    111       switch (k) {
    112         case 0: return sf2 * Math.Pow(1 + 0.5 * d2[i, j] / alpha, -alpha - 1) * d2[i, j];
    113         case 1: return 2 * sf2 * Math.Pow((1 + 0.5 * d2[i, j] / alpha), (-alpha));
    114         case 2: {
    115             double g = (1 + 0.5 * d2[i, j] / alpha);
    116             g = sf2 * Math.Pow(g, -alpha) * (0.5 * d2[i, j] / g - alpha * Math.Log(g));
    117             return g;
    118           }
    119         default: throw new ArgumentException("CovarianceRQiso has three hyperparameters", "k");
    120       }
    121     }
    122 
    123     private void CalculateSquaredDistances() {
    124       if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
    125       int rows = x.GetLength(0);
    126       int cols = xt.GetLength(0);
    127       d2 = new double[rows, cols];
     96    public double GetCrossCovariance(double[,] x, double[,] xt, int i, int j) {
    12897      double lInv = 1.0 / l;
    129       if (symmetric) {
    130         for (int i = 0; i < rows; i++) {
    131           for (int j = i; j < rows; j++) {
    132             d2[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv));
    133             d2[j, i] = d2[i, j];
    134           }
    135         }
    136       } else {
    137         for (int i = 0; i < rows; i++) {
    138           for (int j = 0; j < cols; j++) {
    139             d2[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv));
    140           }
    141         }
    142       }
     98      double d = Util.SqrDist(Util.GetRow(x, i).Select(e => e * lInv), Util.GetRow(xt, j).Select(e => e * lInv));
     99      return sf2 * Math.Pow(1 + 0.5 * d / alpha, -alpha);
    143100    }
    144101  }
Note: See TracChangeset for help on using the changeset viewer.