source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/StudentTProcessModel.cs @ 13784

Last change on this file since 13784 was 13784, checked in by pfleck, 5 years ago

#2591 Made the creation of a GaussianProcessModel faster by avoiding additional iterators during calculation of the hyperparameter gradients.
The gradients of the hyperparameters are now calculated in one sweep and returned as IList, instead of returning an iterator (with yield return).
This avoids a large amount of Move-calls of the iterator, especially for covariance functions with a lot of hyperparameters.
Besides, the signature of the CovarianceGradientFunctionDelegate was changed to return an IList instead of an IEnumerable, avoiding unnecessary ToList or ToArray calls.

File size: 15.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  /// <summary>
32  /// Represents a Gaussian process model.
33  /// </summary>
34  [StorableClass]
35  [Item("StudentTProcessModel", "Represents a Student-t process posterior.")]
36  public sealed class StudentTProcessModel : NamedItem, IGaussianProcessModel {
37    [Storable]
38    private double negativeLogLikelihood;
39    public double NegativeLogLikelihood {
40      get { return negativeLogLikelihood; }
41    }
42
43    [Storable]
44    private double[] hyperparameterGradients;
45    public double[] HyperparameterGradients {
46      get {
47        var copy = new double[hyperparameterGradients.Length];
48        Array.Copy(hyperparameterGradients, copy, copy.Length);
49        return copy;
50      }
51    }
52
    [Storable]
    private ICovarianceFunction covarianceFunction;
    // Covariance (kernel) function of the process; cloned from the constructor argument.
    public ICovarianceFunction CovarianceFunction {
      get { return covarianceFunction; }
    }
    [Storable]
    private IMeanFunction meanFunction;
    // Mean function of the process; cloned from the constructor argument.
    public IMeanFunction MeanFunction {
      get { return meanFunction; }
    }
    [Storable]
    private string targetVariable;
    // Name of the target variable in the training dataset.
    public string TargetVariable {
      get { return targetVariable; }
    }
    [Storable]
    private string[] allowedInputVariables;
    // Names of the input variables used to build the input matrix.
    public string[] AllowedInputVariables {
      get { return allowedInputVariables; }
    }

    // alpha: weight vector obtained from the Cholesky solve in CalculateModel; used in prediction.
    [Storable]
    private double[] alpha;
    // beta: scalar product of the centered targets and alpha (see CalculateModel); used for
    // the likelihood and the predictive variance scaling.
    [Storable]
    private double beta;

    // Always 0 for this model; noise is presumably handled via the covariance
    // function rather than an explicit noise term — TODO confirm against IGaussianProcessModel usage.
    public double SigmaNoise {
      get { return 0; }
    }

    [Storable]
    private double[] meanParameter;
    [Storable]
    private double[] covarianceParameter;
    // Degrees of freedom of the Student-t process; set to exp(hyp) + 2 in the constructor.
    [Storable]
    private double nu;

    private double[,] l; // used to be storable in previous versions (is calculated lazily now)
    private double[,] x; // scaled training dataset, used to be storable in previous versions (is calculated lazily now)
    // BackwardsCompatibility3.4
    #region Backwards compatible code, remove with 3.5
    [Storable(Name = "l")] // restore if available but don't store anymore
    private double[,] l_storable {
      set { this.l = value; }
      get {
        if (trainingDataset == null) return l; // this model has been created with an old version
        else return null; // if the training dataset is available l should not be serialized
      }
    }
    [Storable(Name = "x")] // restore if available but don't store anymore
    private double[,] x_storable {
      set { this.x = value; }
      get {
        if (trainingDataset == null) return x; // this model has been created with an old version
        else return null; // if the training dataset is available x should not be serialized
      }
    }
    #endregion
112
113
    [Storable]
    private IDataset trainingDataset; // it is better to store the original training dataset completely because this is more efficient in persistence
    [Storable]
    private int[] trainingRows; // row indices of the training samples within trainingDataset

    [Storable]
    private Scaling inputScaling; // null when the model was built with scaleInputs == false
122
    // Deserialization constructor (persistence framework only).
    [StorableConstructor]
    private StudentTProcessModel(bool deserializing) : base(deserializing) { }
125    private StudentTProcessModel(StudentTProcessModel original, Cloner cloner)
126      : base(original, cloner) {
127      this.meanFunction = cloner.Clone(original.meanFunction);
128      this.covarianceFunction = cloner.Clone(original.covarianceFunction);
129      if (original.inputScaling != null)
130        this.inputScaling = cloner.Clone(original.inputScaling);
131      this.trainingDataset = cloner.Clone(original.trainingDataset);
132      this.negativeLogLikelihood = original.negativeLogLikelihood;
133      this.targetVariable = original.targetVariable;
134      if (original.meanParameter != null) {
135        this.meanParameter = (double[])original.meanParameter.Clone();
136      }
137      if (original.covarianceParameter != null) {
138        this.covarianceParameter = (double[])original.covarianceParameter.Clone();
139      }
140      nu = original.nu;
141
142      // shallow copies of arrays because they cannot be modified
143      this.trainingRows = original.trainingRows;
144      this.allowedInputVariables = original.allowedInputVariables;
145      this.alpha = original.alpha;
146      this.beta = original.beta;
147      this.l = original.l;
148      this.x = original.x;
149    }
150    public StudentTProcessModel(IDataset ds, string targetVariable, IEnumerable<string> allowedInputVariables, IEnumerable<int> rows,
151      IEnumerable<double> hyp, IMeanFunction meanFunction, ICovarianceFunction covarianceFunction,
152      bool scaleInputs = true)
153      : base() {
154      this.name = ItemName;
155      this.description = ItemDescription;
156      this.meanFunction = (IMeanFunction)meanFunction.Clone();
157      this.covarianceFunction = (ICovarianceFunction)covarianceFunction.Clone();
158      this.targetVariable = targetVariable;
159      this.allowedInputVariables = allowedInputVariables.ToArray();
160
161
162      int nVariables = this.allowedInputVariables.Length;
163      meanParameter = hyp
164        .Take(this.meanFunction.GetNumberOfParameters(nVariables))
165        .ToArray();
166
167      covarianceParameter = hyp.Skip(meanParameter.Length)
168                                             .Take(this.covarianceFunction.GetNumberOfParameters(nVariables))
169                                             .ToArray();
170      nu = Math.Exp(hyp.Skip(meanParameter.Length + covarianceParameter.Length).First()) + 2; //TODO check gradient
171      try {
172        CalculateModel(ds, rows, scaleInputs);
173      }
174      catch (alglib.alglibexception ae) {
175        // wrap exception so that calling code doesn't have to know about alglib implementation
176        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
177      }
178    }
179
    // Fits the model: computes the Cholesky factor of the covariance matrix, the
    // weight vector alpha, the negative log likelihood and the gradients of all
    // hyperparameters. Statement order matters here (l is reused by the solver
    // before lCopy is derived from it), so be careful when refactoring.
    private void CalculateModel(IDataset ds, IEnumerable<int> rows, bool scaleInputs = true) {
      this.trainingDataset = (IDataset)ds.Clone();
      this.trainingRows = rows.ToArray();
      this.inputScaling = scaleInputs ? new Scaling(ds, allowedInputVariables, rows) : null;

      x = GetData(ds, this.allowedInputVariables, this.trainingRows, this.inputScaling);

      IEnumerable<double> y;
      y = ds.GetDoubleValues(targetVariable, rows);

      int n = x.GetLength(0);
      var columns = Enumerable.Range(0, x.GetLength(1)).ToArray();

      // calculate cholesky decomposed (lower triangular) covariance matrix
      var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns);
      this.l = CalculateL(x, cov);

      // calculate mean
      var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns);
      double[] m = Enumerable.Range(0, x.GetLength(0))
        .Select(r => mean.Mean(x, r))
        .ToArray();

      // calculate sum of diagonal elements for likelihood
      // (sum of log(l[i,i]) equals 0.5 * log det of the covariance matrix)
      double diagSum = Enumerable.Range(0, n).Select(i => Math.Log(l[i, i])).Sum();

      // solve for alpha
      double[] ym = y.Zip(m, (a, b) => a - b).ToArray();

      int info;
      alglib.densesolverreport denseSolveRep;

      alglib.spdmatrixcholeskysolve(l, n, false, ym, out info, out denseSolveRep, out alpha);

      // beta = (y - m)^T * alpha; used in the likelihood and the nu gradient below
      beta = Util.ScalarProd(ym, alpha);
      double sign0, sign1;
      // alglib.lngamma returns ln|Gamma(.)| and the sign separately; recombine them
      double lngamma0 = alglib.lngamma(0.5 * (nu + n), out sign0);
      lngamma0 *= sign0;
      double lngamma1 = alglib.lngamma(0.5 * nu, out sign1);
      lngamma1 *= sign1;
      // negative log likelihood of a Student-t process (note: nu > 2 is guaranteed by the constructor)
      negativeLogLikelihood =
        0.5 * n * Math.Log((nu - 2) * Math.PI) +
        diagSum +
        -lngamma0 + lngamma1 +
        //-Math.Log(alglib.gammafunction((n + nu) / 2) / alglib.gammafunction(nu / 2)) +
        0.5 * (nu + n) * Math.Log(1 + beta / (nu - 2));

      // derivatives
      int nAllowedVariables = x.GetLength(1);

      alglib.matinvreport matInvRep;
      double[,] lCopy = new double[l.GetLength(0), l.GetLength(1)];
      Array.Copy(l, lCopy, lCopy.Length);

      // invert the covariance matrix from its Cholesky factor (in place in lCopy)
      alglib.spdmatrixcholeskyinverse(ref lCopy, n, false, out info, out matInvRep);
      double c = (nu + n) / (nu + beta - 2);
      if (info != 1) throw new ArgumentException("Can't invert matrix to calculate gradients.");
      // lCopy := K^-1 - c * alpha * alpha^T (lower triangle only); this is the
      // matrix contracted against the covariance gradients below
      for (int i = 0; i < n; i++) {
        for (int j = 0; j <= i; j++)
          lCopy[i, j] = lCopy[i, j] - c * alpha[i] * alpha[j];
      }

      double[] meanGradients = new double[meanFunction.GetNumberOfParameters(nAllowedVariables)];
      for (int k = 0; k < meanGradients.Length; k++) {
        var meanGrad = new double[alpha.Length];
        for (int g = 0; g < meanGrad.Length; g++)
          meanGrad[g] = mean.Gradient(x, g, k);
        meanGradients[k] = -Util.ScalarProd(meanGrad, alpha);//TODO not working yet, try to fix with gradient check
      }

      // covariance gradients: one sweep over the lower triangle; all parameter
      // gradients for a pair (i, j) are fetched at once (see changeset note)
      double[] covGradients = new double[covarianceFunction.GetNumberOfParameters(nAllowedVariables)];
      if (covGradients.Length > 0) {
        for (int i = 0; i < n; i++) {
          for (int j = 0; j < i; j++) {
            var g = cov.CovarianceGradient(x, i, j);
            for (int k = 0; k < covGradients.Length; k++) {
              covGradients[k] += lCopy[i, j] * g[k];
            }
          }

          var gDiag = cov.CovarianceGradient(x, i, i);
          for (int k = 0; k < covGradients.Length; k++) {
            // diag
            covGradients[k] += 0.5 * lCopy[i, i] * gDiag[k];
          }
        }
      }

      // gradient w.r.t. the degrees-of-freedom hyperparameter (psi = digamma function)
      double nuGradient = 0.5 * n
        - 0.5 * (nu - 2) * alglib.psi((n + nu) / 2) + 0.5 * (nu - 2) * alglib.psi(nu / 2)
        + 0.5 * (nu - 2) * Math.Log(1 + beta / (nu - 2)) - beta * (n + nu) / (2 * (beta + (nu - 2)));

      //nuGradient = (nu-2) * nuGradient;
      // layout must match the hyp vector accepted by the constructor: [mean | covariance | nu]
      hyperparameterGradients =
        meanGradients
        .Concat(covGradients)
        .Concat(new double[] { nuGradient }).ToArray();

    }
279
280    private static double[,] GetData(IDataset ds, IEnumerable<string> allowedInputs, IEnumerable<int> rows, Scaling scaling) {
281      if (scaling != null) {
282        return AlglibUtil.PrepareAndScaleInputMatrix(ds, allowedInputs, rows, scaling);
283      } else {
284        return AlglibUtil.PrepareInputMatrix(ds, allowedInputs, rows);
285      }
286    }
287
288    private static double[,] CalculateL(double[,] x, ParameterizedCovarianceFunction cov) {
289      int n = x.GetLength(0);
290      var l = new double[n, n];
291
292      // calculate covariances
293      for (int i = 0; i < n; i++) {
294        for (int j = i; j < n; j++) {
295          l[j, i] = cov.Covariance(x, i, j);
296        }
297      }
298
299      // cholesky decomposition
300      var res = alglib.trfac.spdmatrixcholesky(ref l, n, false);
301      if (!res) throw new ArgumentException("Matrix is not positive semidefinite");
302      return l;
303    }
304
305
306    public override IDeepCloneable Clone(Cloner cloner) {
307      return new StudentTProcessModel(this, cloner);
308    }
309
310    // is called by the solution creator to set all parameter values of the covariance and mean function
311    // to the optimized values (necessary to make the values visible in the GUI)
312    public void FixParameters() {
313      covarianceFunction.SetParameter(covarianceParameter);
314      meanFunction.SetParameter(meanParameter);
315      covarianceParameter = new double[0];
316      meanParameter = new double[0];
317    }
318
319    #region IRegressionModel Members
320    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
321      return GetEstimatedValuesHelper(dataset, rows);
322    }
323    public GaussianProcessRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
324      return new GaussianProcessRegressionSolution(this, new RegressionProblemData(problemData));
325    }
326    IRegressionSolution IRegressionModel.CreateRegressionSolution(IRegressionProblemData problemData) {
327      return CreateRegressionSolution(problemData);
328    }
329    #endregion
330
331
332    private IEnumerable<double> GetEstimatedValuesHelper(IDataset dataset, IEnumerable<int> rows) {
333      try {
334        if (x == null) {
335          x = GetData(trainingDataset, allowedInputVariables, trainingRows, inputScaling);
336        }
337        int n = x.GetLength(0);
338
339        double[,] newX = GetData(dataset, allowedInputVariables, rows, inputScaling);
340        int newN = newX.GetLength(0);
341        var columns = Enumerable.Range(0, newX.GetLength(1)).ToArray();
342
343        var Ks = new double[newN][];
344        var mean = meanFunction.GetParameterizedMeanFunction(meanParameter, columns);
345        var ms = Enumerable.Range(0, newX.GetLength(0))
346        .Select(r => mean.Mean(newX, r))
347        .ToArray();
348        var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, columns);
349        for (int i = 0; i < newN; i++) {
350          Ks[i] = new double[n];
351          for (int j = 0; j < n; j++) {
352            Ks[i][j] = cov.CrossCovariance(x, newX, j, i);
353          }
354        }
355
356        return Enumerable.Range(0, newN)
357          .Select(i => ms[i] + Util.ScalarProd(Ks[i], alpha));
358      }
359      catch (alglib.alglibexception ae) {
360        // wrap exception so that calling code doesn't have to know about alglib implementation
361        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
362      }
363    }
364
    // Computes the predictive variance for the given rows:
    // kss - v^T v (with v = L^-1 K(X, x*)), scaled by the Student-t factor
    // (nu + beta - 2) / (nu + n - 2), clamped at 0 for numerical safety.
    public IEnumerable<double> GetEstimatedVariance(IDataset dataset, IEnumerable<int> rows) {
      try {
        if (x == null) {
          // training inputs are rebuilt lazily (not serialized anymore, see l_storable/x_storable)
          x = GetData(trainingDataset, allowedInputVariables, trainingRows, inputScaling);
        }
        int n = x.GetLength(0);

        var newX = GetData(dataset, allowedInputVariables, rows, inputScaling);
        int newN = newX.GetLength(0);

        var kss = new double[newN];
        double[,] sWKs = new double[n, newN];
        var cov = covarianceFunction.GetParameterizedCovarianceFunction(covarianceParameter, Enumerable.Range(0, x.GetLength(1)).ToArray());

        if (l == null) {
          // Cholesky factor is also rebuilt lazily
          l = CalculateL(x, cov);
        }

        // for stddev
        // prior variance k(x*, x*) of each query point
        for (int i = 0; i < newN; i++)
          kss[i] = cov.Covariance(newX, i, i);

        // cross covariances K(X, x*), one column per query point
        for (int i = 0; i < newN; i++) {
          for (int j = 0; j < n; j++) {
            sWKs[j, i] = cov.CrossCovariance(x, newX, j, i);
          }
        }

        // for stddev
        // triangular solve: sWKs := L^-1 * sWKs (in place)
        alglib.ablas.rmatrixlefttrsm(n, newN, l, 0, 0, false, false, 0, ref sWKs, 0, 0);

        for (int i = 0; i < newN; i++) {
          var col = Util.GetCol(sWKs, i).ToArray();
          var sumV = Util.ScalarProd(col, col);
          kss[i] -= sumV;
          // Student-t specific scaling of the Gaussian predictive variance
          kss[i] *= (nu + beta - 2) / (nu + n - 2);
          if (kss[i] < 0) kss[i] = 0;
        }
        return kss;
      }
      catch (alglib.alglibexception ae) {
        // wrap exception so that calling code doesn't have to know about alglib implementation
        throw new ArgumentException("There was a problem in the calculation of the Gaussian process model", ae);
      }
    }
410  }
411}
Note: See TracBrowser for help on using the repository browser.