Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 13234

Last change on this file was in revision 13234, checked in by gkronber, 8 years ago

#1967: synchronized access to PRNG in GaussianProcessCovarianceOptimizationProblem.Evaluate() and added a comment

File size: 19.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.Instances;
33
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Symbolic-expression-tree problem that searches for a Gaussian process covariance function,
  /// composed as sums and products of base kernels, which maximizes the log likelihood of a
  /// GP regression model on the configured problem data. Each tree evaluation runs several
  /// random restarts of an ALGLIB conjugate-gradient hyperparameter optimization.
  /// </summary>
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableClass]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IStatefulItem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    // Shared, pre-configured terminal covariance functions. These instances are reused for every
    // tree (see TreeToCovarianceFunction); this is safe because GaussianProcessModel clones the
    // covariance function before use.
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

    // Selects the valid value with the given integer from a constrained value parameter
    // (throws if no such value exists, which would indicate a programming error above).
    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion

    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    // The regression data set on which candidate covariance functions are evaluated.
    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    // Maximum number of CG iterations per hyperparameter optimization run.
    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    // Number of random restarts of the hyperparameter optimization per tree evaluation.
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    public override bool Maximization {
      get { return true; } // return log likelihood (instead of negative log likelihood as in GPR
    }

    // problem stores a few variables for information exchange from Evaluate() to Analyze()
    private readonly object problemStateLocker = new object();
    [Storable]
    private double bestQ; // best log likelihood observed so far (across all Evaluate() calls)
    [Storable]
    private double[] bestHyperParameters; // hyperparameter vector belonging to bestQ
    [Storable]
    private IMeanFunction meanFunc; // mean function belonging to bestQ
    [Storable]
    private ICovarianceFunction covFunc; // covariance function belonging to bestQ

    public GaussianProcessCovarianceOptimizationProblem()
      : base() {
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters[RestartsParameterName].Hidden = true; // use the name constant consistently instead of a magic string
      // Grammar: binary Sum/Product combinators over a fixed set of terminal kernels.
      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });
      base.Encoding = new SymbolicExpressionTreeEncoding(g, 10, 5);
    }

    public void InitializeState() { ClearState(); }
    public void ClearState() {
      meanFunc = null;
      covFunc = null;
      bestQ = double.NegativeInfinity;
      bestHyperParameters = null;
    }

    // Does not produce the same result for the same seed when using parallel engine (see below)!
    /// <summary>
    /// Evaluates a covariance-function tree: runs <see cref="Restarts"/> random restarts of an
    /// ALGLIB CG hyperparameter optimization and returns the best log likelihood found.
    /// </summary>
    public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);

      // allocate hyperparameters
      var hyperParameters = new double[meanFunction.GetNumberOfParameters(nVars) + covarianceFunction.GetNumberOfParameters(nVars) + 1]; // mean + cov + noise
      // NOTE: renamed from "bestHyperParameters" to avoid shadowing the storable field of the same name.
      double[] bestParamsSoFar = new double[hyperParameters.Length];
      var bestObjValue = new double[1] { double.MinValue };

      // data that is necessary for the objective function
      var data = Tuple.Create(ds, targetVariable, allowedInputVariables, trainingRows, (IMeanFunction)meanFunction, covarianceFunction, bestObjValue);

      for (int t = 0; t < Restarts; t++) {
        var prevBest = bestObjValue[0];
        var prevBestHyperParameters = new double[hyperParameters.Length];
        Array.Copy(bestParamsSoFar, prevBestHyperParameters, bestParamsSoFar.Length);

        // initialize hyperparameters
        hyperParameters[0] = ds.GetDoubleValues(targetVariable).Average(); // mean const

        // Evaluate might be called concurrently therefore access to random has to be synchronized.
        // However, results of multiple runs with the same seed will be different when using the parallel engine.
        lock (random) {
          for (int i = 0; i < covarianceFunction.GetNumberOfParameters(nVars); i++) {
            hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
          }
        }
        hyperParameters[hyperParameters.Length - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // use alglib.bfgs for hyper-parameter optimization ...
        double epsg = 0;
        double epsf = 0.00001;
        double epsx = 0;
        double stpmax = 1;
        int maxits = ConstantOptIterations;
        alglib.mincgstate state;
        alglib.mincgreport rep;

        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, data);

        alglib.mincgresults(state, out bestParamsSoFar, out rep);

        if (rep.terminationtype < 0) {
          // error -> restore previous best quality
          bestObjValue[0] = prevBest;
          Array.Copy(prevBestHyperParameters, bestParamsSoFar, prevBestHyperParameters.Length);
        }
      }

      UpdateBestSoFar(bestObjValue[0], bestParamsSoFar, meanFunction, covarianceFunction);

      return bestObjValue[0];
    }

    // updates the overall best quality and overall best model for Analyze()
    private void UpdateBestSoFar(double bestQ, double[] bestHyperParameters, IMeanFunction meanFunc, ICovarianceFunction covFunc) {
      lock (problemStateLocker) {
        if (bestQ > this.bestQ) {
          this.bestQ = bestQ;
          this.bestHyperParameters = new double[bestHyperParameters.Length];
          Array.Copy(bestHyperParameters, this.bestHyperParameters, this.bestHyperParameters.Length);
          this.meanFunc = meanFunc;
          this.covFunc = covFunc;
        }
      }
    }

    /// <summary>
    /// Publishes the best tree, its quality, and the corresponding GP regression solution
    /// (built from the state captured by <see cref="UpdateBestSoFar"/>) into the results.
    /// </summary>
    public override void Analyze(ISymbolicExpressionTree[] trees, double[] qualities, ResultCollection results, IRandom random) {
      if (!results.ContainsKey("Best Solution Quality")) {
        results.Add(new Result("Best Solution Quality", typeof(DoubleValue)));
      }
      if (!results.ContainsKey("Best Tree")) {
        results.Add(new Result("Best Tree", typeof(ISymbolicExpressionTree)));
      }
      if (!results.ContainsKey("Best Solution")) {
        results.Add(new Result("Best Solution", typeof(GaussianProcessRegressionSolution)));
      }

      var bestQuality = qualities.Max();

      if (results["Best Solution Quality"].Value == null || bestQuality > ((DoubleValue)results["Best Solution Quality"].Value).Value) {
        var bestIdx = Array.IndexOf(qualities, bestQuality);
        var bestClone = (ISymbolicExpressionTree)trees[bestIdx].Clone();
        results["Best Tree"].Value = bestClone;
        results["Best Solution Quality"].Value = new DoubleValue(bestQuality);
        results["Best Solution"].Value = CreateSolution();
      }
    }

    // Builds a regression solution from the overall best mean/covariance/hyperparameters.
    // Assumes UpdateBestSoFar has been called at least once (meanFunc/covFunc non-null).
    private IItem CreateSolution() {
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var trainingRows = problemData.TrainingIndices.ToArray();

      lock (problemStateLocker) {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, bestHyperParameters, (IMeanFunction)meanFunc.Clone(), (ICovarianceFunction)covFunc.Clone());
        model.FixParameters();
        return model.CreateRegressionSolution((IRegressionProblemData)ProblemData.Clone());
      }
    }

    // ALGLIB objective callback: x = hyperparameter vector, func = negative log likelihood,
    // grad = gradient output buffer, obj = the data tuple prepared in Evaluate().
    private void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      // we want to optimize the model likelihood by changing the hyperparameters and also return the gradient for each hyperparameter
      var data = (Tuple<IDataset, string, string[], int[], IMeanFunction, ICovarianceFunction, double[]>)obj;
      var ds = data.Item1;
      var targetVariable = data.Item2;
      var allowedInputVariables = data.Item3;
      var trainingRows = data.Item4;
      var meanFunction = data.Item5;
      var covarianceFunction = data.Item6;
      var bestObjValue = data.Item7;
      var hyperParameters = x; // the decision variable vector

      try {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, hyperParameters, meanFunction, covarianceFunction);

        func = model.NegativeLogLikelihood; // mincgoptimize, so we return negative likelihood
        bestObjValue[0] = Math.Max(bestObjValue[0], -func); // problem itself is a maximization problem
        var gradients = model.HyperparameterGradients;
        Array.Copy(gradients, grad, gradients.Length);
      }
      catch (ArgumentException) {
        // building the GaussianProcessModel might fail, in this case we return the worst possible objective value
        func = 1.0E+300;
        Array.Clear(grad, 0, grad.Length);
      }
    }

    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

    // Recursively translates a symbolic expression (sub)tree into a covariance function.
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        // covFunction is cloned by the model so we can reuse instances of terminal covariance functions
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", node.Symbol.Name));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
      bestQ = original.bestQ;
      meanFunc = cloner.Clone(original.meanFunc);
      covFunc = cloner.Clone(original.covFunc);
      // BUGFIX: must test the *original's* field. The previous check ("bestHyperParameters != null")
      // read this clone's own field, which is always null here, so the vector was never copied.
      if (original.bestHyperParameters != null) {
        bestHyperParameters = new double[original.bestHyperParameters.Length];
        Array.Copy(original.bestHyperParameters, bestHyperParameters, bestHyperParameters.Length);
      }
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      OnProblemDataChanged();
    }

    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.