
source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/CovarianceFunctions/CovariancePiecewisePolynomial.cs @ 13784

Last change on this file: r13784, checked in by pfleck, 8 years ago

#2591 Made the creation of a GaussianProcessModel faster by avoiding additional iterators during calculation of the hyperparameter gradients.
The gradients of the hyperparameters are now calculated in one sweep and returned as an IList instead of an iterator (with yield return).
This avoids a large number of MoveNext calls on the iterator, especially for covariance functions with many hyperparameters.
In addition, the signature of the CovarianceGradientFunctionDelegate was changed to return an IList instead of an IEnumerable, which avoids unnecessary ToList or ToArray calls.
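
As a rough sketch of the pattern change (hypothetical, simplified signatures; the actual gradient code is the GetGradient method in the file below):

using System.Collections.Generic;

static class GradientSketch {
  // before: gradients produced lazily with yield return;
  // the consumer pays one MoveNext call per element
  public static IEnumerable<double> GradientsLazy(double gLength, double gScale) {
    yield return gLength;
    yield return gScale;
  }

  // after: gradients computed in one sweep and returned as a ready-made IList
  public static IList<double> GradientsEager(double gLength, double gScale) {
    return new List<double>(2) { gLength, gScale };
  }
}

With the IList variant the caller can also index into the gradient vector directly, which matches the changed CovarianceGradientFunctionDelegate signature.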

File size: 6.9 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Algorithms.DataAnalysis {
  [StorableClass]
  [Item(Name = "CovariancePiecewisePolynomial",
    Description = "Piecewise polynomial covariance function with compact support for Gaussian processes.")]
  public sealed class CovariancePiecewisePolynomial : ParameterizedNamedItem, ICovarianceFunction {
    public IValueParameter<DoubleValue> LengthParameter {
      get { return (IValueParameter<DoubleValue>)Parameters["Length"]; }
    }

    public IValueParameter<DoubleValue> ScaleParameter {
      get { return (IValueParameter<DoubleValue>)Parameters["Scale"]; }
    }

    public IConstrainedValueParameter<IntValue> VParameter {
      get { return (IConstrainedValueParameter<IntValue>)Parameters["V"]; }
    }
    private bool HasFixedLengthParameter {
      get { return LengthParameter.Value != null; }
    }
    private bool HasFixedScaleParameter {
      get { return ScaleParameter.Value != null; }
    }

    [StorableConstructor]
    private CovariancePiecewisePolynomial(bool deserializing)
      : base(deserializing) {
    }

    private CovariancePiecewisePolynomial(CovariancePiecewisePolynomial original, Cloner cloner)
      : base(original, cloner) {
    }

    public CovariancePiecewisePolynomial()
      : base() {
      Name = ItemName;
      Description = ItemDescription;

      Parameters.Add(new OptionalValueParameter<DoubleValue>("Length", "The length parameter of the isometric piecewise polynomial covariance function."));
      Parameters.Add(new OptionalValueParameter<DoubleValue>("Scale", "The scale parameter of the piecewise polynomial covariance function."));

      var validValues = new ItemSet<IntValue>(new IntValue[] {
        (IntValue)(new IntValue().AsReadOnly()),
        (IntValue)(new IntValue(1).AsReadOnly()),
        (IntValue)(new IntValue(2).AsReadOnly()),
        (IntValue)(new IntValue(3).AsReadOnly()) });
      Parameters.Add(new ConstrainedValueParameter<IntValue>("V", "The v parameter of the piecewise polynomial function (allowed values 0, 1, 2, 3).", validValues, validValues.First()));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new CovariancePiecewisePolynomial(this, cloner);
    }

    public int GetNumberOfParameters(int numberOfVariables) {
      return
        (HasFixedLengthParameter ? 0 : 1) +
        (HasFixedScaleParameter ? 0 : 1);
    }

    public void SetParameter(double[] p) {
      double @const, scale;
      GetParameterValues(p, out @const, out scale);
      LengthParameter.Value = new DoubleValue(@const);
      ScaleParameter.Value = new DoubleValue(scale);
    }

    private void GetParameterValues(double[] p, out double length, out double scale) {
      // gather parameter values; p contains only the free (non-fixed) parameters,
      // given in log scale (see the Math.Exp transformations below)
      int n = 0;
      if (HasFixedLengthParameter) {
        length = LengthParameter.Value.Value;
      } else {
        length = Math.Exp(p[n]);
        n++;
      }

      if (HasFixedScaleParameter) {
        scale = ScaleParameter.Value.Value;
      } else {
        scale = Math.Exp(2 * p[n]);
        n++;
      }
      if (p.Length != n) throw new ArgumentException("The length of the parameter vector does not match the number of free parameters for CovariancePiecewisePolynomial", "p");
    }

    public ParameterizedCovarianceFunction GetParameterizedCovarianceFunction(double[] p, int[] columnIndices) {
      double length, scale;
      int v = VParameter.Value.Value;
      GetParameterValues(p, out length, out scale);
      var fixedLength = HasFixedLengthParameter;
      var fixedScale = HasFixedScaleParameter;
      // exp = floor(D/2) + v + 1, where D is the number of selected input dimensions
      int exp = (int)Math.Floor(columnIndices.Count() / 2.0) + v + 1;

      // polynomial factor f(r) and its derivative df(r) for the supported degrees v = 0..3
      Func<double, double> f;
      Func<double, double> df;
      switch (v) {
        case 0:
          f = (r) => 1.0;
          df = (r) => 0.0;
          break;
        case 1:
          f = (r) => 1 + (exp + 1) * r;
          df = (r) => exp + 1;
          break;
        case 2:
          f = (r) => 1 + (exp + 2) * r + (exp * exp + 4.0 * exp + 3) / 3.0 * r * r;
          df = (r) => (exp + 2) + 2 * (exp * exp + 4.0 * exp + 3) / 3.0 * r;
          break;
        case 3:
          f = (r) => 1 + (exp + 3) * r + (6.0 * exp * exp + 36 * exp + 45) / 15.0 * r * r +
                     (exp * exp * exp + 9 * exp * exp + 23 * exp + 45) / 15.0 * r * r * r;
          df = (r) => (exp + 3) + 2 * (6.0 * exp * exp + 36 * exp + 45) / 15.0 * r +
                      (exp * exp * exp + 9 * exp * exp + 23 * exp + 45) / 5.0 * r * r;
          break;
        default: throw new ArgumentException();
      }

      // create functions
      // k(x_i, x_j) = scale * max(1 - r, 0)^(exp + v) * f(r), with r the scaled Euclidean distance between rows i and j
      var cov = new ParameterizedCovarianceFunction();
      cov.Covariance = (x, i, j) => {
        double k = Math.Sqrt(Util.SqrDist(x, i, x, j, columnIndices, 1.0 / length));
        return scale * Math.Pow(Math.Max(1 - k, 0), exp + v) * f(k);
      };
      cov.CrossCovariance = (x, xt, i, j) => {
        double k = Math.Sqrt(Util.SqrDist(x, i, xt, j, columnIndices, 1.0 / length));
        return scale * Math.Pow(Math.Max(1 - k, 0), exp + v) * f(k);
      };
      cov.CovarianceGradient = (x, i, j) => GetGradient(x, i, j, length, scale, v, exp, f, df, columnIndices, fixedLength, fixedScale);
      return cov;
    }

    // gradients w.r.t. the free hyperparameters (length first, then scale), computed in one sweep and returned as an IList
    private static IList<double> GetGradient(double[,] x, int i, int j, double length, double scale, int v, double exp, Func<double, double> f, Func<double, double> df, int[] columnIndices,
      bool fixedLength, bool fixedScale) {
      double k = Math.Sqrt(Util.SqrDist(x, i, x, j, columnIndices, 1.0 / length));
      var g = new List<double>(2);
      if (!fixedLength) g.Add(scale * Math.Pow(Math.Max(1.0 - k, 0), exp + v - 1) * k * ((exp + v) * f(k) - Math.Max(1 - k, 0) * df(k)));
      if (!fixedScale) g.Add(2.0 * scale * Math.Pow(Math.Max(1 - k, 0), exp + v) * f(k));
      return g;
    }
  }
}