Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/14/12 13:25:17 (12 years ago)
Author:
gkronber
Message:

#1902 changed interface for covariance functions to improve readability, fixed several bugs in the covariance functions and in the line chart for Gaussian process models.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessModel.cs

    r8475 r8484  
    3939    public double NegativeLogLikelihood {
    4040      get { return negativeLogLikelihood; }
     41    }
     42
     43    [Storable]
     44    private double[] hyperparameterGradients;
     45    public double[] HyperparameterGradients {
     46      get {
     47        var copy = new double[hyperparameterGradients.Length];
     48        Array.Copy(hyperparameterGradients, copy, copy.Length);
     49        return copy;
     50      }
    4151    }
    4252
     
    125135
    126136      meanFunction.SetData(x);
    127       covarianceFunction.SetData(x);
    128137
    129138      // calculate means and covariances
     
    131140      for (int i = 0; i < n; i++) {
    132141        for (int j = i; j < n; j++) {
    133           l[j, i] = covarianceFunction.GetCovariance(i, j) / sqrSigmaNoise;
     142          l[j, i] = covarianceFunction.GetCovariance(x, i, j) / sqrSigmaNoise;
    134143          if (j == i) l[j, i] += 1.0;
    135144        }
     
    153162        alpha[i] = alpha[i] / sqrSigmaNoise;
    154163      negativeLogLikelihood = 0.5 * Util.ScalarProd(ym, alpha) + diagSum + (n / 2.0) * Math.Log(2.0 * Math.PI * sqrSigmaNoise);
    155     }
    156 
    157     public double[] GetHyperparameterGradients() {
     164
    158165      // derivatives
    159       int n = x.GetLength(0);
    160166      int nAllowedVariables = x.GetLength(1);
    161167
    162       int info;
    163168      alglib.matinvreport matInvRep;
    164169      double[,] lCopy = new double[l.GetLength(0), l.GetLength(1)];
     
    183188      if (covGradients.Length > 0) {
    184189        for (int i = 0; i < n; i++) {
     190          for (int j = 0; j < i; j++) {
     191            var g = covarianceFunction.GetGradient(x, i, j).ToArray();
     192            for (int k = 0; k < covGradients.Length; k++) {
     193              covGradients[k] += lCopy[i, j] * g[k];
     194            }
     195          }
     196
     197          var gDiag = covarianceFunction.GetGradient(x, i, i).ToArray();
    185198          for (int k = 0; k < covGradients.Length; k++) {
    186             for (int j = 0; j < i; j++) {
    187               covGradients[k] += lCopy[i, j] * covarianceFunction.GetGradient(i, j, k);
    188             }
    189             covGradients[k] += 0.5 * lCopy[i, i] * covarianceFunction.GetGradient(i, i, k);
     199            // diag
     200            covGradients[k] += 0.5 * lCopy[i, i] * gDiag[k];
    190201          }
    191202        }
    192203      }
    193204
    194       return
     205      hyperparameterGradients =
    195206        meanGradients
    196207        .Concat(covGradients)
    197208        .Concat(new double[] { noiseGradient }).ToArray();
     209
    198210    }
    199211
     
    219231      int newN = newX.GetLength(0);
    220232      int n = x.GetLength(0);
    221       // var predMean = new double[newN];
    222       // predVar = new double[newN];
    223 
    224 
    225 
    226       // var kss = new double[newN];
    227233      var Ks = new double[newN, n];
    228       //double[,] sWKs = new double[n, newN];
    229       // double[,] v;
    230 
    231 
    232       // for stddev
    233       //covarianceFunction.SetParameter(covHyp, newX);
    234       //kss = covarianceFunction.GetDiagonalCovariances();
    235 
    236       covarianceFunction.SetData(x, newX);
    237234      meanFunction.SetData(newX);
    238235      var ms = meanFunction.GetMean(newX);
    239236      for (int i = 0; i < newN; i++) {
    240237        for (int j = 0; j < n; j++) {
    241           Ks[i, j] = covarianceFunction.GetCovariance(j, i);
    242           //sWKs[j, i] = Ks[i, j] / Math.Sqrt(sqrSigmaNoise);
    243         }
    244       }
    245 
    246       // for stddev
    247       // alglib.rmatrixsolvem(l, n, sWKs, newN, true, out info, out denseSolveRep, out v);
     238          Ks[i, j] = covarianceFunction.GetCrossCovariance(x, newX, j, i);
     239        }
     240      }
    248241
    249242      return Enumerable.Range(0, newN)
    250243        .Select(i => ms[i] + Util.ScalarProd(Util.GetRow(Ks, i), alpha));
    251       //for (int i = 0; i < newN; i++) {
    252       //  // predMean[i] = ms[i] + prod(GetRow(Ks, i), alpha);
    253       //  // var sumV2 = prod(GetCol(v, i), GetCol(v, i));
    254       //  // predVar[i] = kss[i] - sumV2;
    255       //}
    256 
    257244    }
    258245
     
    266253
    267254      // for stddev
    268       covarianceFunction.SetData(newX);
    269255      for (int i = 0; i < newN; i++)
    270         kss[i] = covarianceFunction.GetCovariance(i, i);
    271 
    272       covarianceFunction.SetData(x, newX);
     256        kss[i] = covarianceFunction.GetCovariance(newX, i, i);
     257
    273258      for (int i = 0; i < newN; i++) {
    274259        for (int j = 0; j < n; j++) {
    275           sWKs[j, i] = covarianceFunction.GetCovariance(j, i) / Math.Sqrt(sqrSigmaNoise);
     260          sWKs[j, i] = covarianceFunction.GetCrossCovariance(x, newX, j, i) / Math.Sqrt(sqrSigmaNoise);
    276261        }
    277262      }
    278263
    279264      // for stddev
    280       int info;
    281       alglib.densesolverreport denseSolveRep;
    282       double[,] v;
    283 
    284       alglib.rmatrixsolvem(l, n, sWKs, newN, false, out info, out denseSolveRep, out v);
     265      alglib.ablas.rmatrixlefttrsm(n, newN, l, 0, 0, false, false, 0, ref sWKs, 0, 0);
    285266
    286267      for (int i = 0; i < newN; i++) {
    287         var sumV = Util.ScalarProd(Util.GetCol(v, i), Util.GetCol(v, i));
     268        var sumV = Util.ScalarProd(Util.GetCol(sWKs, i), Util.GetCol(sWKs, i));
    288269        kss[i] -= sumV;
    289270        if (kss[i] < 0) kss[i] = 0;
Note: See TracChangeset for help on using the changeset viewer.