
source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 12946

Last change on this file since 12946 was 12946, checked in by gkronber, 9 years ago

#1967: added synthesis of covariance functions for Gaussian Process regression as a BasicProblem<SymbolicExpressionTreeEncoding>

File size: 15.3 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.Instances;


namespace HeuristicLab.Algorithms.DataAnalysis {
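  // Searches for a covariance function structure for Gaussian process regression: candidate
  // structures are symbolic expression trees built from Sum/Product nodes over a fixed set of
  // base kernels, and a tree's fitness is the best log likelihood found when optimizing the
  // remaining continuous hyperparameters (see Evaluate below).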
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableClass]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters: the discrete kernel
      // parameters (d, v, q, degree) are fixed here so that each grammar terminal maps to one
      // shared covariance instance with only continuous hyperparameters left to optimize
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion


    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    public override bool Maximization {
      get { return true; } // returns the log likelihood (instead of the negative log likelihood as in GPR)
    }

    public GaussianProcessCovarianceOptimizationProblem()
      : base() {
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters[RestartsParameterName].Hidden = true;
      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });
      base.Encoding = new SymbolicExpressionTreeEncoding(g, 10, 5);
    }


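    // Fitness of a candidate covariance structure: translate the tree into a covariance function,
    // then run several conjugate gradient restarts over the continuous hyperparameters (mean,
    // kernel, noise) and return the best log likelihood found across all restarts.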
    public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {

      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);

      // allocate hyperparameters
      var hyperParameters = new double[meanFunction.GetNumberOfParameters(nVars) + covarianceFunction.GetNumberOfParameters(nVars) + 1]; // mean + cov + noise
      double[] bestHyperParameters = new double[hyperParameters.Length];
      var bestObjValue = new double[1] { double.MinValue };

      // data that is necessary for the objective function
      var data = Tuple.Create(ds, targetVariable, allowedInputVariables, trainingRows, (IMeanFunction)meanFunction, covarianceFunction, bestObjValue);

      for (int t = 0; t < Restarts; t++) {
        var prevBest = bestObjValue[0];
        var prevBestHyperParameters = new double[hyperParameters.Length];
        Array.Copy(bestHyperParameters, prevBestHyperParameters, bestHyperParameters.Length);

        // initialize hyperparameters
        hyperParameters[0] = ds.GetDoubleValues(targetVariable).Average(); // mean const

        for (int i = 0; i < covarianceFunction.GetNumberOfParameters(nVars); i++) {
          hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
        }
        hyperParameters[hyperParameters.Length - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // use alglib conjugate gradient (mincg) for hyper-parameter optimization ...
        double epsg = 0;
        double epsf = 0.00001;
        double epsx = 0;
        double stpmax = 1;
        int maxits = ConstantOptIterations;
        alglib.mincgstate state;
        alglib.mincgreport rep;

        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, data);

        alglib.mincgresults(state, out bestHyperParameters, out rep);

        if (rep.terminationtype < 0) {
          // error -> restore previous best quality
          bestObjValue[0] = prevBest;
          Array.Copy(prevBestHyperParameters, bestHyperParameters, prevBestHyperParameters.Length);
        }
      }

      return bestObjValue[0];
    }

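    // Callback for alglib.mincgoptimize: x is the current hyperparameter vector, func receives the
    // objective value and grad its gradient w.r.t. x. The objective is the standard GP negative log
    // marginal likelihood, NLL = 1/2 y' K^-1 y + 1/2 log|K| + n/2 log(2 pi), where K is the
    // covariance matrix over the training inputs (incl. noise) under the current hyperparameters.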
    public void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      // we want to optimize the model likelihood by changing the hyperparameters and also return the gradient for each hyperparameter
      var data = (Tuple<IDataset, string, string[], int[], IMeanFunction, ICovarianceFunction, double[]>)obj;
      var ds = data.Item1;
      var targetVariable = data.Item2;
      var allowedInputVariables = data.Item3;
      var trainingRows = data.Item4;
      var meanFunction = data.Item5;
      var covarianceFunction = data.Item6;
      var bestObjValue = data.Item7;
      var hyperParameters = x; // the decision variable vector

      try {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, hyperParameters, meanFunction, covarianceFunction);

        func = model.NegativeLogLikelihood; // mincg minimizes, so we return the negative log likelihood
        bestObjValue[0] = Math.Max(bestObjValue[0], -func); // problem itself is a maximization problem
        var gradients = model.HyperparameterGradients;
        Array.Copy(gradients, grad, gradients.Length);
      } catch (Exception) {
        // building the GaussianProcessModel might fail, in this case we return the worst possible objective value
        func = 1.0E+300;
        Array.Clear(grad, 0, grad.Length);
      }
    }

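    // Recursively translates the symbolic expression tree into a nested covariance function:
    // Sum/Product nodes become CovarianceSum/CovarianceProduct, terminal symbols map to the
    // shared covariance instances created in the static constructor.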
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        // covFunction is cloned by the model so we can reuse instances of terminal covariance functions
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", node.Symbol.Name));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      OnProblemDataChanged();
    }

    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}
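
For context, a minimal usage sketch (not part of this file): the problem implements IProblemInstanceConsumer<IRegressionProblemData>, so it can be loaded with data and assigned to an evolutionary algorithm. The GeneticAlgorithm class below comes from the HeuristicLab.Algorithms.GeneticAlgorithm plugin, and `regressionData` is an assumed, pre-existing IRegressionProblemData instance.

  // minimal sketch, assuming the GeneticAlgorithm plugin is referenced
  var problem = new GaussianProcessCovarianceOptimizationProblem();
  problem.Load(regressionData);        // `regressionData` is hypothetical: any IRegressionProblemData
  problem.Restarts = 10;               // random restarts per fitness evaluation
  problem.ConstantOptIterations = 50;  // CG iterations per restart

  var ga = new HeuristicLab.Algorithms.GeneticAlgorithm.GeneticAlgorithm();
  ga.Problem = problem;
  ga.Start();                          // evolves covariance structures; the best tree maximizes the log likelihood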