source: branches/2839_HiveProjectManagement/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 16057

Last change on this file since 16057 was 16057, checked in by jkarder, 15 months ago

#2839:

File size: 19.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.Instances;
33
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Genetic-programming problem that searches for a covariance function for Gaussian process regression.
  /// A solution is a symbolic expression tree whose inner nodes are "Sum" and "Product" (two operands each)
  /// and whose leaves are terminal covariance functions. A tree is evaluated by optimizing the hyperparameters
  /// of a <see cref="GaussianProcessModel"/> (constant mean, the tree's covariance function, noise) on the
  /// training rows; its quality is the best log likelihood found, which is maximized.
  /// </summary>
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableClass]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IStatefulItem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    // Template instances of the terminal covariance functions, configured once in the static constructor.
    // TreeToCovarianceFunction hands out clones of these templates, never the templates themselves.
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

    // Sets param to the single element of its ValidValues whose value equals val
    // (Enumerable.Single throws if there is not exactly one such element).
    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion

    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    /// <summary>Maximum number of conjugate-gradient iterations per restart.</summary>
    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    /// <summary>Number of randomly initialized hyperparameter optimizations per evaluated tree.</summary>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>
    /// Always true: Evaluate returns the log likelihood (instead of the negative log likelihood as in GPR).
    /// </summary>
    public override bool Maximization {
      get { return true; }
    }

    // problem stores a few variables for information exchange from Evaluate() to Analyze();
    // all reads/writes done by UpdateBestSoFar and CreateSolution happen under problemStateLocker
    private readonly object problemStateLocker = new object();
    [Storable]
    private double bestQ;
    [Storable]
    private double[] bestHyperParameters;
    [Storable]
    private IMeanFunction meanFunc;
    [Storable]
    private ICovarianceFunction covFunc;

    public GaussianProcessCovarianceOptimizationProblem()
      : base() {
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters[RestartsParameterName].Hidden = true;
      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });
      base.Encoding = new SymbolicExpressionTreeEncoding(g, 10, 5);

      // Start with an empty best-so-far state (bestQ = -infinity). Without this, bestQ would be 0 until
      // InitializeState is called and negative log likelihoods would never be recorded by UpdateBestSoFar.
      ClearState();
    }

    public void InitializeState() { ClearState(); }

    /// <summary>Forgets the best-so-far quality, hyperparameters, mean and covariance function.</summary>
    public void ClearState() {
      meanFunc = null;
      covFunc = null;
      bestQ = double.NegativeInfinity;
      bestHyperParameters = null;
    }

    private readonly object syncRoot = new object();

    /// <summary>
    /// Evaluates a covariance-function tree: runs <see cref="Restarts"/> conjugate-gradient optimizations
    /// (alglib mincg, at most <see cref="ConstantOptIterations"/> iterations each) of the Gaussian process
    /// hyperparameters from random starting points and returns the best log likelihood found. The hyperparameter
    /// vector that produced exactly this value is handed to UpdateBestSoFar together with the mean and
    /// covariance function. Returns double.MinValue if no restart could build a valid model.
    /// </summary>
    /// <remarks>
    /// Does not produce the same result for the same seed when using the parallel engine (see below)!
    /// </remarks>
    public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);

      // hyperparameter layout: [mean const | covariance parameters | noise]
      int nCovParams = covarianceFunction.GetNumberOfParameters(nVars);
      int nParams = meanFunction.GetNumberOfParameters(nVars) + nCovParams + 1;

      // The mean constant starts at the target mean of the training rows only, so that test rows
      // do not influence the fitted model. Computed once; it is the same for every restart.
      double initialMean = trainingRows.Length > 0
        ? ds.GetDoubleValues(targetVariable, trainingRows).Average()
        : 0.0;

      // settings for alglib.mincg
      const double epsg = 0;
      const double epsf = 0.00001;
      const double epsx = 0;
      const double stpmax = 1;
      int maxits = ConstantOptIterations;

      double bestQuality = double.MinValue;
      double[] bestHyperParameters = null; // hyperparameters that produced bestQuality

      for (int t = 0; t < Restarts; t++) {
        var hyperParameters = new double[nParams];
        hyperParameters[0] = initialMean; // mean const

        // Evaluate might be called concurrently therefore access to random has to be synchronized.
        // However, results of multiple runs with the same seed will be different when using the parallel engine.
        lock (syncRoot) {
          for (int i = 0; i < nCovParams; i++) {
            hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
          }
        }
        hyperParameters[nParams - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // fresh per restart: it tracks the best likelihood and matching hyperparameters of this restart only
        var context = new ObjectiveContext(ds, targetVariable, allowedInputVariables, trainingRows, meanFunction, covarianceFunction);

        alglib.mincgstate state;
        alglib.mincgreport rep;
        double[] finalHyperParameters;
        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, context);
        alglib.mincgresults(state, out finalHyperParameters, out rep);

        // An optimizer error discards the whole restart; so does a restart in which no model could be built.
        if (rep.terminationtype < 0 || context.BestHyperParameters == null) continue;

        // keep the restart only if it improves on the previous restarts, so that
        // bestQuality and bestHyperParameters always describe the same point
        if (context.BestLogLikelihood > bestQuality) {
          bestQuality = context.BestLogLikelihood;
          bestHyperParameters = context.BestHyperParameters;
        }
      }

      if (bestHyperParameters != null) {
        UpdateBestSoFar(bestQuality, bestHyperParameters, meanFunction, covarianceFunction);
      }

      return bestQuality;
    }

    // updates the overall best quality and overall best model for Analyze()
    private void UpdateBestSoFar(double bestQ, double[] bestHyperParameters, IMeanFunction meanFunc, ICovarianceFunction covFunc) {
      lock (problemStateLocker) {
        if (bestQ > this.bestQ) {
          this.bestQ = bestQ;
          this.bestHyperParameters = (double[])bestHyperParameters.Clone();
          this.meanFunc = meanFunc;
          this.covFunc = covFunc;
        }
      }
    }

    /// <summary>
    /// Maintains the results "Best Solution Quality", "Best Tree" and "Best Solution". When the best quality of
    /// the given population exceeds the stored one, the best tree is cloned and a regression solution is built
    /// from the problem's best-so-far state (the solution is null if no model has been recorded yet).
    /// </summary>
    public override void Analyze(ISymbolicExpressionTree[] trees, double[] qualities, ResultCollection results, IRandom random) {
      if (!results.ContainsKey("Best Solution Quality")) {
        results.Add(new Result("Best Solution Quality", typeof(DoubleValue)));
      }
      if (!results.ContainsKey("Best Tree")) {
        results.Add(new Result("Best Tree", typeof(ISymbolicExpressionTree)));
      }
      if (!results.ContainsKey("Best Solution")) {
        results.Add(new Result("Best Solution", typeof(GaussianProcessRegressionSolution)));
      }

      var bestQuality = qualities.Max();

      if (results["Best Solution Quality"].Value == null || bestQuality > ((DoubleValue)results["Best Solution Quality"].Value).Value) {
        var bestIdx = Array.IndexOf(qualities, bestQuality);
        var bestClone = (ISymbolicExpressionTree)trees[bestIdx].Clone();
        results["Best Tree"].Value = bestClone;
        results["Best Solution Quality"].Value = new DoubleValue(bestQuality);
        results["Best Solution"].Value = CreateSolution();
      }
    }

    // Builds a regression solution from the stored best hyperparameters, mean and covariance function
    // on the current ProblemData; returns null when no successful evaluation has been recorded yet.
    private IItem CreateSolution() {
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var trainingRows = problemData.TrainingIndices.ToArray();

      lock (problemStateLocker) {
        if (meanFunc == null || covFunc == null || bestHyperParameters == null) return null;

        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, bestHyperParameters, (IMeanFunction)meanFunc.Clone(), (ICovarianceFunction)covFunc.Clone());
        model.FixParameters();
        return model.CreateRegressionSolution((IRegressionProblemData)ProblemData.Clone());
      }
    }

    /// <summary>
    /// Inputs of one hyperparameter optimization, passed to <see cref="ObjectiveFunction"/> through alglib's
    /// user-object argument. It also records the best log likelihood seen and a copy of the hyperparameter
    /// vector that produced it, so that a quality and its hyperparameters always belong together.
    /// </summary>
    private sealed class ObjectiveContext {
      public readonly IDataset Dataset;
      public readonly string TargetVariable;
      public readonly string[] AllowedInputVariables;
      public readonly int[] TrainingRows;
      public readonly IMeanFunction MeanFunction;
      public readonly ICovarianceFunction CovarianceFunction;
      public double BestLogLikelihood = double.MinValue;
      public double[] BestHyperParameters; // null until a model could be built

      public ObjectiveContext(IDataset dataset, string targetVariable, string[] allowedInputVariables, int[] trainingRows,
        IMeanFunction meanFunction, ICovarianceFunction covarianceFunction) {
        Dataset = dataset;
        TargetVariable = targetVariable;
        AllowedInputVariables = allowedInputVariables;
        TrainingRows = trainingRows;
        MeanFunction = meanFunction;
        CovarianceFunction = covarianceFunction;
      }
    }

    // Objective for alglib.mincg: returns the negative log likelihood of a GaussianProcessModel built with
    // hyperparameters x, plus its hyperparameter gradients. If building the model throws ArgumentException,
    // returns 1.0E+300 and a zero gradient.
    private void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      var context = (ObjectiveContext)obj;

      try {
        var model = new GaussianProcessModel(context.Dataset, context.TargetVariable, context.AllowedInputVariables,
          context.TrainingRows, x, context.MeanFunction, context.CovarianceFunction);

        func = model.NegativeLogLikelihood; // mincgoptimize minimizes, so we return the negative likelihood

        // the problem itself is a maximization problem; the comparison also ignores NaN likelihoods
        double logLikelihood = -func;
        if (logLikelihood > context.BestLogLikelihood) {
          context.BestLogLikelihood = logLikelihood;
          context.BestHyperParameters = (double[])x.Clone(); // copy: x is a buffer owned by the optimizer
        }

        var gradients = model.HyperparameterGradients;
        Array.Copy(gradients, grad, gradients.Length);
      }
      catch (ArgumentException) {
        // building the GaussianProcessModel might fail, in this case we return the worst possible objective value
        func = 1.0E+300;
        Array.Clear(grad, 0, grad.Length);
      }
    }

    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

    // Recursively translates a tree node: Sum/Product become composite covariance functions of their two
    // subtrees, every other symbol becomes a fresh clone of the matching terminal template.
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        default:
          // Clone the template so that no two nodes of a tree (e.g. Sum(Periodic, Periodic)), no two evaluated
          // trees and no stored best covariance function share the same covariance-function object.
          return (ICovarianceFunction)GetTerminalTemplate(node.Symbol.Name).Clone();
      }
    }

    // Maps a terminal symbol name to its static template instance.
    private static ICovarianceFunction GetTerminalTemplate(string symbolName) {
      switch (symbolName) {
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", symbolName));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
      bestQ = original.bestQ;
      meanFunc = cloner.Clone(original.meanFunc);
      covFunc = cloner.Clone(original.covFunc);
      // check the ORIGINAL's array; this instance's field is still null at this point
      if (original.bestHyperParameters != null) {
        bestHyperParameters = (double[])original.bestHyperParameters.Clone();
      }
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    /// <summary>
    /// Replaces the problem data, discards the best-so-far state (it was fitted to the previous data and
    /// CreateSolution would otherwise combine it with the new data) and raises ProblemDataChanged.
    /// </summary>
    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      ClearState();
      OnProblemDataChanged();
    }

    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.