Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/31/12 11:19:24 (12 years ago)
Author:
gkronber
Message:

#1902 added linear mean and covariance function

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4
Files:
2 added
7 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceSum.cs

    r8323 r8366  
    1 using System.Collections.Generic;
     1using System;
    22using System.Linq;
     3using HeuristicLab.Common;
     4using HeuristicLab.Core;
     5using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    36
    47namespace HeuristicLab.Algorithms.DataAnalysis.GaussianProcess {
    5   public class CovarianceSum : ICovarianceFunction {
    6     private IList<ICovarianceFunction> covariances;
     8  [StorableClass]
     9  [Item(Name = "CovarianceSum",
     10    Description = "Sum covariance function for Gaussian processes.")]
     11  public class CovarianceSum : Item, ICovarianceFunction {
     12    [Storable]
     13    private ItemList<ICovarianceFunction> terms;
    714
    8     public int NumberOfParameters {
    9       get { return covariances.Sum(c => c.NumberOfParameters); }
     15    [Storable]
     16    private int numberOfVariables;
     17    public ItemList<ICovarianceFunction> Terms {
     18      get { return terms; }
    1019    }
    1120
    12     public CovarianceSum(IEnumerable<ICovarianceFunction> covariances) {
    13       this.covariances = covariances.ToList();
     21    [StorableConstructor]
     22    protected CovarianceSum(bool deserializing)
     23      : base(deserializing) {
    1424    }
    1525
    16     public void SetMatrix(double[,] x) {
    17       foreach (var covariance in covariances) {
    18         covariance.SetMatrix(x, x);
     26    protected CovarianceSum(CovarianceSum original, Cloner cloner)
     27      : base(original, cloner) {
     28      this.terms = cloner.Clone(terms);
     29    }
     30
     31    public CovarianceSum()
     32      : base() {
     33    }
     34
     35    public override IDeepCloneable Clone(Cloner cloner) {
     36      return new CovarianceSum(this, cloner);
     37    }
     38
     39    public int GetNumberOfParameters(int numberOfVariables) {
     40      this.numberOfVariables = numberOfVariables;
     41      return terms.Select(t => t.GetNumberOfParameters(numberOfVariables)).Sum();
     42    }
     43
     44    public void SetParameter(double[] hyp, double[,] x) {
     45      int offset = 0;
     46      foreach (var t in terms) {
     47        t.SetParameter(hyp.Skip(offset).Take(t.GetNumberOfParameters(numberOfVariables)), x);
     48        offset += numberOfVariables;
    1949      }
    2050    }
    2151
    22     public void SetMatrix(double[,] x, double[,] xt) {
    23       foreach (var covariance in covariances) {
    24         covariance.SetMatrix(x, xt);
    25       }
    26     }
    2752
    28     public void SetHyperparamter(double[] hyp) {
    29       int i = 0;
    30       foreach (var covariance in covariances) {
    31         int n = covariance.NumberOfParameters;
    32         covariance.SetHyperparamter(hyp.Skip(i).Take(n).ToArray());
    33         i += n;
    34       }
     53    public void SetParameter(double[] hyp, double[,] x, double[,] xt) {
     54      this.l = Math.Exp(hyp[0]);
     55      this.sf2 = Math.Exp(2 * hyp[1]);
     56
     57      this.symmetric = false;
     58      this.x = x;
     59      this.xt = xt;
     60      sd = null;
    3561    }
    3662
    3763    public double GetCovariance(int i, int j) {
    38       return covariances.Select(c => c.GetCovariance(i, j)).Sum();
     64      if (sd == null) CalculateSquaredDistances();
     65      return sf2 * Math.Exp(-sd[i, j] / 2.0);
    3966    }
    4067
    4168
    4269    public double[] GetDiagonalCovariances() {
    43       return covariances
    44         .Select(c => c.GetDiagonalCovariances())
    45         .Aggregate((s, d) => s.Zip(d, (a, b) => a + b).ToArray())
    46         .ToArray();
     70      if (x != xt) throw new InvalidOperationException();
     71      int rows = x.GetLength(0);
     72      var sd = new double[rows];
     73      for (int i = 0; i < rows; i++) {
     74        sd[i] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, i).Select(e => e / l));
     75      }
     76      return sd.Select(d => sf2 * Math.Exp(-d / 2.0)).ToArray();
    4777    }
    4878
    49     public double[] GetDerivatives(int i, int j) {
    50       return covariances
    51         .Select(c => c.GetDerivatives(i, j))
    52         .Aggregate(Enumerable.Empty<double>(), (h0, h1) => h0.Concat(h1))
    53         .ToArray();
     79
     80    public double[] GetGradient(int i, int j) {
     81      var res = new double[2];
     82      res[0] = sf2 * Math.Exp(-sd[i, j] / 2.0) * sd[i, j];
     83      res[1] = 2.0 * sf2 * Math.Exp(-sd[i, j] / 2.0);
     84      return res;
     85    }
     86
     87    private void CalculateSquaredDistances() {
     88      if (x.GetLength(1) != xt.GetLength(1)) throw new InvalidOperationException();
     89      int rows = x.GetLength(0);
     90      int cols = xt.GetLength(0);
     91      sd = new double[rows, cols];
     92      if (symmetric) {
     93        for (int i = 0; i < rows; i++) {
     94          for (int j = i; j < rows; j++) {
     95            sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, j).Select(e => e / l));
     96            sd[j, i] = sd[i, j];
     97          }
     98        }
     99      } else {
     100        for (int i = 0; i < rows; i++) {
     101          for (int j = 0; j < cols; j++) {
     102            sd[i, j] = Util.SqrDist(Util.GetRow(x, i).Select(e => e / l), Util.GetRow(xt, j).Select(e => e / l));
     103          }
     104        }
     105      }
    54106    }
    55107  }
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs

    r8323 r8366  
    138138
    139139      // calculate means and covariances
    140       double[] m = meanFunction.GetMean();
     140      double[] m = meanFunction.GetMean(x);
    141141      for (int i = 0; i < n; i++) {
    142142
     
    188188      double[] meanGradients = new double[meanFunction.GetNumberOfParameters(nAllowedVariables)];
    189189      for (int i = 0; i < meanGradients.Length; i++) {
    190         var meanGrad = meanFunction.GetGradients(i);
     190        var meanGrad = meanFunction.GetGradients(i, x);
    191191        meanGradients[i] = -Util.ScalarProd(meanGrad, alpha);
    192192      }
    193193
    194194      double[] covGradients = new double[covarianceFunction.GetNumberOfParameters(nAllowedVariables)];
    195       for (int i = 0; i < n; i++) {
    196         for (int j = 0; j < n; j++) {
    197           var covDeriv = covarianceFunction.GetGradient(i, j);
    198           for (int k = 0; k < covGradients.Length; k++) {
    199             covGradients[k] += q[i, j] * covDeriv[k];
     195      if (covGradients.Length > 0) {
     196        for (int i = 0; i < n; i++) {
     197          for (int j = 0; j < n; j++) {
     198            var covDeriv = covarianceFunction.GetGradient(i, j);
     199            for (int k = 0; k < covGradients.Length; k++) {
     200              covGradients[k] += q[i, j] * covDeriv[k];
     201            }
    200202          }
    201203        }
    202       }
    203       covGradients = covGradients.Select(g => g / 2.0).ToArray();
     204        covGradients = covGradients.Select(g => g / 2.0).ToArray();
     205      }
    204206
    205207      return new double[] { noiseGradient }
     
    246248      covarianceFunction.SetParameter(covHyp, x, newX);
    247249      meanFunction.SetParameter(meanHyp, newX);
    248       var ms = meanFunction.GetMean();
     250      var ms = meanFunction.GetMean(newX);
    249251      for (int i = 0; i < newN; i++) {
    250252
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/IMeanFunction.cs

    r8323 r8366  
    66    int GetNumberOfParameters(int numberOfVariables);
    77    void SetParameter(double[] hyp, double[,] x);
    8     double[] GetMean();
    9     double[] GetGradients(int k);
     8    double[] GetMean(double[,] x);
     9    double[] GetGradients(int k, double[,] x);
    1010  }
    1111}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanConst.cs

    r8323 r8366  
    3434    }
    3535
    36     public double[] GetMean() {
     36    public double[] GetMean(double[,] x) {
    3737      return Enumerable.Repeat(c, n).ToArray();
    3838    }
    3939
    40     public double[] GetGradients(int k) {
     40    public double[] GetGradients(int k, double[,] x) {
    4141      if (k > 0) throw new ArgumentException();
    4242      return Enumerable.Repeat(1.0, n).ToArray();
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/MeanZero.cs

    r8323 r8366  
    2929    }
    3030
    31     public double[] GetMean() {
     31    public double[] GetMean(double[,] x) {
    3232      return Enumerable.Repeat(0.0, n).ToArray();
    3333    }
    3434
    35     public double[] GetGradients(int k) {
     35    public double[] GetGradients(int k, double[,] x) {
    3636      if (k > 0) throw new ArgumentException();
    3737      return Enumerable.Repeat(0.0, n).ToArray();
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/Util.cs

    r8323 r8366  
    2424      return Enumerable.Range(0, cols).Select(c => x[r, c]);
    2525    }
     26    public static IEnumerable<double> GetCol(double[,] x, int c) {
     27      int rows = x.GetLength(0);
     28      return Enumerable.Range(0, rows).Select(r => x[r, c]);
     29    }
    2630  }
    2731}
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/HeuristicLab.Algorithms.DataAnalysis-3.4.csproj

    r8324 r8366  
    122122    </Compile>
    123123    <Compile Include="FixedDataAnalysisAlgorithm.cs" />
     124    <Compile Include="GaussianProcess\CovarianceLinear.cs" />
     125    <Compile Include="GaussianProcess\MeanLinear.cs" />
    124126    <Compile Include="GaussianProcess\Util.cs" />
    125127    <Compile Include="GaussianProcess\MeanZero.cs" />
Note: See TracChangeset for help on using the changeset viewer.