
source: branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 17655

Last change on this file was r17655, checked in by abeham, 4 years ago

#2521: adapted readonly of reference parameters

File size: 19.1 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Linq;
using System.Threading;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.Instances;


namespace HeuristicLab.Algorithms.DataAnalysis {
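  // Optimizes the structure of a Gaussian process covariance function for regression:
  // candidate kernels are encoded as symbolic expression trees (sums and products of
  // base kernels) and each tree is scored by optimizing its hyperparameters on the data.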
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableType("A3EA7CE7-78FA-48FF-9DD5-FBE5AB770A99")]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IStatefulItem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

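    // The grammar's terminal symbols map 1:1 to these shared covariance instances;
    // parameterized kernels (Matern, piecewise polynomial, ...) are created once per
    // fixed configuration. Sharing is safe because the model clones the covariance
    // function (see TreeToCovarianceFunction below).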
    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

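    // ConstrainedValueParameters accept only instances from their ValidValues collection,
    // so the desired setting is looked up by value rather than assigned directly.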
    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion

    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    // problem stores a few variables for information exchange from Evaluate() to Analyze()
    private readonly object problemStateLocker = new object();
    [Storable]
    private double bestQ;
    [Storable]
    private double[] bestHyperParameters;
    [Storable]
    private IMeanFunction meanFunc;
    [Storable]
    private ICovarianceFunction covFunc;

    public GaussianProcessCovarianceOptimizationProblem() : base(new SymbolicExpressionTreeEncoding()) {
      Maximization = true; // return log likelihood (instead of negative log likelihood as in GPR)
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters["Restarts"].Hidden = true;


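      // The grammar admits only binary Sum and Product nodes over the kernel terminals
      // below, so every tree derivable from it encodes a valid composite covariance function.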
      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });

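      // keep composite kernels small; the grammar is locked (read-only) because the terminal
      // symbols must stay in sync with the mapping in TreeToCovarianceFunction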
      Encoding.TreeLength = 10;
      Encoding.TreeDepth = 5;
      Encoding.Grammar = g;
      Encoding.GrammarParameter.ReadOnly = GrammarRefParameter.ReadOnly = true;
    }

    public void InitializeState() { ClearState(); }
    public void ClearState() {
      meanFunc = null;
      covFunc = null;
      bestQ = double.NegativeInfinity;
      bestHyperParameters = null;
    }

    private readonly object syncRoot = new object();
    // Does not produce the same result for the same seed when using parallel engine (see below)!
    public override ISingleObjectiveEvaluationResult Evaluate(ISymbolicExpressionTree tree, IRandom random, CancellationToken cancellationToken) {
      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);

      // allocate hyperparameters
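      // vector layout: [0] mean constant, [1..k] covariance parameters, [k+1] noise; the
      // covariance and noise parameters are on the log scale used by GaussianProcessModel
      // (cf. the noise initialization comment below)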
      var hyperParameters = new double[meanFunction.GetNumberOfParameters(nVars) + covarianceFunction.GetNumberOfParameters(nVars) + 1]; // mean + cov + noise
      double[] bestHyperParameters = new double[hyperParameters.Length];
      var bestObjValue = new double[1] { double.MinValue };

      // data that is necessary for the objective function
      var data = Tuple.Create(ds, targetVariable, allowedInputVariables, trainingRows, (IMeanFunction)meanFunction, covarianceFunction, bestObjValue);

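      // multi-start local search: each restart draws fresh random covariance hyperparameters;
      // the best objective value over all restarts is tracked in bestObjValue[0], which is
      // updated from within ObjectiveFunction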
      for (int t = 0; t < Restarts; t++) {
        var prevBest = bestObjValue[0];
        var prevBestHyperParameters = new double[hyperParameters.Length];
        Array.Copy(bestHyperParameters, prevBestHyperParameters, bestHyperParameters.Length);

        // initialize hyperparameters
        hyperParameters[0] = ds.GetDoubleValues(targetVariable).Average(); // mean const

        // Evaluate might be called concurrently therefore access to random has to be synchronized.
        // However, results of multiple runs with the same seed will be different when using the parallel engine.
        lock (syncRoot) {
          for (int i = 0; i < covarianceFunction.GetNumberOfParameters(nVars); i++) {
            hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
          }
        }
        hyperParameters[hyperParameters.Length - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // use the alglib conjugate gradient optimizer (mincg) for hyper-parameter optimization;
        // epsf stops on small relative improvement of the objective, maxits caps the iterations
        double epsg = 0;
        double epsf = 0.00001;
        double epsx = 0;
        double stpmax = 1;
        int maxits = ConstantOptIterations;
        alglib.mincgstate state;
        alglib.mincgreport rep;

        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, data);

        alglib.mincgresults(state, out bestHyperParameters, out rep);

        if (rep.terminationtype < 0) {
          // error -> restore previous best quality
          bestObjValue[0] = prevBest;
          Array.Copy(prevBestHyperParameters, bestHyperParameters, prevBestHyperParameters.Length);
        }
      }

      UpdateBestSoFar(bestObjValue[0], bestHyperParameters, meanFunction, covarianceFunction);

      var quality = bestObjValue[0];
      return new SingleObjectiveEvaluationResult(quality);
    }

    // updates the overall best quality and overall best model for Analyze()
    private void UpdateBestSoFar(double bestQ, double[] bestHyperParameters, IMeanFunction meanFunc, ICovarianceFunction covFunc) {
      lock (problemStateLocker) {
        if (bestQ > this.bestQ) {
          this.bestQ = bestQ;
          this.bestHyperParameters = new double[bestHyperParameters.Length];
          Array.Copy(bestHyperParameters, this.bestHyperParameters, this.bestHyperParameters.Length);
          this.meanFunc = meanFunc;
          this.covFunc = covFunc;
        }
      }
    }

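    // Analyze only publishes results; the best model itself is captured by UpdateBestSoFar
    // during Evaluate and turned into a regression solution in CreateSolution below.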
    public override void Analyze(ISymbolicExpressionTree[] trees, double[] qualities, ResultCollection results, IRandom random) {
      if (!results.ContainsKey("Best Solution Quality")) {
        results.Add(new Result("Best Solution Quality", typeof(DoubleValue)));
      }
      if (!results.ContainsKey("Best Tree")) {
        results.Add(new Result("Best Tree", typeof(ISymbolicExpressionTree)));
      }
      if (!results.ContainsKey("Best Solution")) {
        results.Add(new Result("Best Solution", typeof(GaussianProcessRegressionSolution)));
      }

      var bestQuality = qualities.Max();

      if (results["Best Solution Quality"].Value == null || bestQuality > ((DoubleValue)results["Best Solution Quality"].Value).Value) {
        var bestIdx = Array.IndexOf(qualities, bestQuality);
        var bestClone = (ISymbolicExpressionTree)trees[bestIdx].Clone();
        results["Best Tree"].Value = bestClone;
        results["Best Solution Quality"].Value = new DoubleValue(bestQuality);
        results["Best Solution"].Value = CreateSolution();
      }
    }

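    // Builds the final Gaussian process model from the best state found so far; FixParameters
    // presumably bakes the optimized hyperparameters into the cloned mean and covariance
    // functions so the solution no longer depends on the raw parameter vector.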
    private IItem CreateSolution() {
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var trainingRows = problemData.TrainingIndices.ToArray();

      lock (problemStateLocker) {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, bestHyperParameters, (IMeanFunction)meanFunc.Clone(), (ICovarianceFunction)covFunc.Clone());
        model.FixParameters();
        return model.CreateRegressionSolution((IRegressionProblemData)ProblemData.Clone());
      }
    }

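    // Gradient callback for alglib.mincgoptimize: x is the current point, func receives the
    // objective value, grad must be filled with the gradient at x, and obj carries the user
    // data passed to mincgoptimize (here the tuple assembled in Evaluate).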
    private void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      // we want to optimize the model likelihood by changing the hyperparameters and also return the gradient for each hyperparameter
      var data = (Tuple<IDataset, string, string[], int[], IMeanFunction, ICovarianceFunction, double[]>)obj;
      var ds = data.Item1;
      var targetVariable = data.Item2;
      var allowedInputVariables = data.Item3;
      var trainingRows = data.Item4;
      var meanFunction = data.Item5;
      var covarianceFunction = data.Item6;
      var bestObjValue = data.Item7;
      var hyperParameters = x; // the decision variable vector

      try {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, hyperParameters, meanFunction, covarianceFunction);

        func = model.NegativeLogLikelihood; // mincgoptimize minimizes, so we return the negative log likelihood
        bestObjValue[0] = Math.Max(bestObjValue[0], -func); // problem itself is a maximization problem
        var gradients = model.HyperparameterGradients;
        Array.Copy(gradients, grad, gradients.Length);
      } catch (ArgumentException) {
        // building the GaussianProcessModel might fail, in this case we return the worst possible objective value
        func = 1.0E+300;
        Array.Clear(grad, 0, grad.Length);
      }
    }

    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

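    // Recursively translates a tree into a covariance function: inner nodes become
    // CovarianceSum/CovarianceProduct, terminal symbols map to the shared static instances.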
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        // covFunction is cloned by the model so we can reuse instances of terminal covariance functions
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", node.Symbol.Name));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(StorableConstructorFlag _) : base(_) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
      bestQ = original.bestQ;
      meanFunc = cloner.Clone(original.meanFunc);
      covFunc = cloner.Clone(original.covFunc);
      if (original.bestHyperParameters != null) {
        bestHyperParameters = new double[original.bestHyperParameters.Length];
        Array.Copy(original.bestHyperParameters, bestHyperParameters, bestHyperParameters.Length);
      }
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      OnProblemDataChanged();
    }

    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}