Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/23/16 16:19:04 (9 years ago)
Author:
mkommend
Message:

#2591: Changed all GP covariance and mean functions to use int[] for column indices instead of IEnumerable<int>. Changed GP utils, GPModel and StudentTProcessModell as well to use fewer iterators and adapted unit tests to new interface.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/StudentTProcessModel.cs

    r13438 r13721  
    171171      try {
    172172        CalculateModel(ds, rows, scaleInputs);
    173       } catch (alglib.alglibexception ae) {
     173      }
     174      catch (alglib.alglibexception ae) {
    174175        // wrap exception so that calling code doesn't have to know about alglib implementation
    175176        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
     
    188189
    189190      int n = x.GetLength(0);
     191      var columns = Enumerable.Range(0, x.GetLength(1)).ToArray();
    190192
    191193      // calculate cholesky decomposed (lower triangular) covariance matrix
    192       var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1)));
     194      var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns);
    193195      this.l = CalculateL(x, cov);
    194196
    195197      // calculate mean
    196       var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, Enumerable.Range(0, x.GetLength(1)));
     198      var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns);
    197199      double[] m = Enumerable.Range(0, x.GetLength(0))
    198200        .Select(r => mean.Mean(x, r))
     
    240242      double[] meanGradients = new double[meanFunction.GetNumberOfParameters(nAllowedVariables)];
    241243      for (int k = 0; k < meanGradients.Length; k++) {
    242         var meanGrad = Enumerable.Range(0, alpha.Length)
    243         .Select(r => mean.Gradient(x, r, k));
    244         meanGradients[k] = -Util.ScalarProd(meanGrad, alpha); //TODO not working yet, try to fix with gradient check
     244        var meanGrad = new double[alpha.Length];
     245        for (int g = 0; g < meanGrad.Length; g++)
     246          meanGrad[g] = mean.Gradient(x, g, k);
     247        meanGradients[k] = -Util.ScalarProd(meanGrad, alpha);//TODO not working yet, try to fix with gradient check
    245248      }
    246249
     
    336339        double[,] newX = GetData(dataset, allowedInputVariables, rows, inputScaling);
    337340        int newN = newX.GetLength(0);
    338 
    339         var Ks = new double[newN, n];
    340         var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, Enumerable.Range(0, newX.GetLength(1)));
     341        var columns = Enumerable.Range(0, newX.GetLength(1)).ToArray();
     342
     343        var Ks = new double[newN][];
     344        var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns);
    341345        var ms = Enumerable.Range(0, newX.GetLength(0))
    342346        .Select(r => mean.Mean(newX, r))
    343347        .ToArray();
    344         var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, newX.GetLength(1)));
     348        var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns);
    345349        for (int i = 0; i < newN; i++) {
     350          Ks[i] = new double[n];
    346351          for (int j = 0; j < n; j++) {
    347             Ks[i, j] = cov.CrossCovariance(x, newX, j, i);
     352            Ks[i][j] = cov.CrossCovariance(x, newX, j, i);
    348353          }
    349354        }
    350355
    351356        return Enumerable.Range(0, newN)
    352           .Select(i => ms[i] + Util.ScalarProd(Util.GetRow(Ks, i), alpha));
    353       } catch (alglib.alglibexception ae) {
     357          .Select(i => ms[i] + Util.ScalarProd(Ks[i], alpha));
     358      }
     359      catch (alglib.alglibexception ae) {
    354360        // wrap exception so that calling code doesn't have to know about alglib implementation
    355361        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
     
    369375        var kss = new double[newN];
    370376        double[,] sWKs = new double[n, newN];
    371         var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1)));
    372        
     377        var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1)).ToArray());
     378
    373379        if (l == null) {
    374380          l = CalculateL(x, cov);
    375381        }
    376        
     382
    377383        // for stddev
    378384        for (int i = 0; i < newN; i++)
    379385          kss[i] = cov.Covariance(newX, i, i);
    380        
     386
    381387        for (int i = 0; i < newN; i++) {
    382388          for (int j = 0; j < n; j++) {
    383             sWKs[j, i] = cov.CrossCovariance(x, newX, j, i) ;
     389            sWKs[j, i] = cov.CrossCovariance(x, newX, j, i);
    384390          }
    385391        }
    386        
     392
    387393        // for stddev
    388394        alglib.ablas.rmatrixlefttrsm(n, newN, l, 0, 0, false, false, 0, ref sWKs, 0, 0);
    389        
     395
    390396        for (int i = 0; i < newN; i++) {
    391           var sumV = Util.ScalarProd(Util.GetCol(sWKs, i), Util.GetCol(sWKs, i));
     397          var col = Util.GetCol(sWKs, i).ToArray();
     398          var sumV = Util.ScalarProd(col, col);
    392399          kss[i] -= sumV;
    393           kss[i] *= (nu + beta -2) / (nu + n - 2);
     400          kss[i] *= (nu + beta - 2) / (nu + n - 2);
    394401          if (kss[i] < 0) kss[i] = 0;
    395402        }
    396403        return kss;
    397       } catch (alglib.alglibexception ae) {
     404      }
     405      catch (alglib.alglibexception ae) {
    398406        // wrap exception so that calling code doesn't have to know about alglib implementation
    399407        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
Note: See TracChangeset for help on using the changeset viewer.