source: branches/2521_ProblemRefactoring/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 17270

Last change on this file since 17270 was 17270, checked in by abeham, 3 years ago

#2521: worked on removing virtual from Maximization for single-objective problems

File size: 19.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Problems.DataAnalysis;
32using HeuristicLab.Problems.Instances;
33
34
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Genetic-programming problem that evolves the structure of a Gaussian process covariance function.
  /// Trees consist of binary "Sum"/"Product" nodes over elementary covariance functions. Each tree is scored
  /// by the log likelihood of a GP model (constant mean + evolved covariance + noise) whose hyperparameters
  /// are fitted with alglib's conjugate gradient minimizer on the training partition of <see cref="ProblemData"/>.
  /// </summary>
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableType("A3EA7CE7-78FA-48FF-9DD5-FBE5AB770A99")]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IStatefulItem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    // Shared instances of the terminal covariance functions returned by TreeToCovarianceFunction.
    // Reusing them relies on the models cloning the covariance function they are given.
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

    // Selects the entry of param.ValidValues whose value equals val (throws if there is not exactly one).
    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion

    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    /// <summary>The regression data; only its training partition is used for fitting hyperparameters.</summary>
    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    /// <summary>Maximum number of conjugate gradient iterations per restart.</summary>
    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    /// <summary>Number of randomly initialized hyperparameter optimization runs per evaluated tree.</summary>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    // problem stores a few variables for information exchange from Evaluate() to Analyze()
    private readonly object problemStateLocker = new object();
    [Storable]
    private double bestQ;
    [Storable]
    private double[] bestHyperParameters;
    [Storable]
    private IMeanFunction meanFunc;
    [Storable]
    private ICovarianceFunction covFunc;

    public GaussianProcessCovarianceOptimizationProblem() : base(new SymbolicExpressionTreeEncoding()) {
      Maximization = true; // return log likelihood (instead of negative log likelihood as in GPR)
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters[RestartsParameterName].Hidden = true;

      // Start from an empty best-so-far state (bestQ = -inf) so that the first evaluation is always
      // recorded, even when InitializeState() has not been called yet. With the default bestQ = 0 a
      // (typically negative) log likelihood would never be stored.
      ClearState();

      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });

      Encoding.TreeLength = 10;
      Encoding.TreeDepth = 5;
      // the grammar parameter is temporarily unlocked to install the covariance grammar
      Encoding.GrammarParameter.ReadOnly = false;
      Encoding.Grammar = g;
      Encoding.GrammarParameter.ReadOnly = true;
    }

    public void InitializeState() { ClearState(); }

    /// <summary>Forgets the best model recorded by previous evaluations.</summary>
    public void ClearState() {
      meanFunc = null;
      covFunc = null;
      bestQ = double.NegativeInfinity;
      bestHyperParameters = null;
    }

    private readonly object syncRoot = new object();

    /// <summary>
    /// Scores a covariance-function tree. The hyperparameters of a GP model (constant mean, the covariance
    /// function encoded by <paramref name="tree"/>, and a noise term) are fitted with <see cref="Restarts"/>
    /// randomly initialized runs of alglib's conjugate gradient minimizer on the training rows.
    /// </summary>
    /// <returns>
    /// The largest log likelihood found, or <see cref="double.MinValue"/> if no valid likelihood was obtained.
    /// </returns>
    /// <remarks>
    /// Does not produce the same result for the same seed when using the parallel engine: concurrent
    /// evaluations draw from <paramref name="random"/> in a nondeterministic order.
    /// </remarks>
    public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);
      var nCovParameters = covarianceFunction.GetNumberOfParameters(nVars);

      // allocate hyperparameters: mean + cov + noise
      var hyperParameters = new double[meanFunction.GetNumberOfParameters(nVars) + nCovParameters + 1];

      // The mean const starts at the average target value of the training partition. Only training rows
      // are used, to match the objective (and avoid looking at test data). It is the same for every restart.
      var targetMean = ds.GetDoubleValues(targetVariable, trainingRows).Average();

      // data for the objective function; it also records the best point visited across all restarts
      var data = new ObjectiveFunctionData(ds, targetVariable, allowedInputVariables, trainingRows,
        meanFunction, covarianceFunction, hyperParameters.Length);

      for (int t = 0; t < Restarts; t++) {
        // remember the best state before this restart so it can be restored if the optimizer fails
        var prevBest = data.BestObjValue;
        var prevBestHyperParameters = (double[])data.BestHyperParameters.Clone();

        // initialize hyperparameters
        hyperParameters[0] = targetMean; // mean const

        // Evaluate might be called concurrently therefore access to random has to be synchronized.
        // However, results of multiple runs with the same seed will be different when using the parallel engine.
        lock (syncRoot) {
          for (int i = 0; i < nCovParameters; i++) {
            hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
          }
        }
        hyperParameters[hyperParameters.Length - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // use alglib.mincg (nonlinear conjugate gradient) for hyper-parameter optimization ...
        double epsg = 0;
        double epsf = 0.00001;
        double epsx = 0;
        double stpmax = 1;
        int maxits = ConstantOptIterations;
        alglib.mincgstate state;
        alglib.mincgreport rep;
        double[] finalHyperParameters; // only needed for the termination report; the best point is tracked in data

        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, data);

        alglib.mincgresults(state, out finalHyperParameters, out rep);

        if (rep.terminationtype < 0) {
          // error -> restore previous best quality and the matching hyperparameters
          data.BestObjValue = prevBest;
          Array.Copy(prevBestHyperParameters, data.BestHyperParameters, prevBestHyperParameters.Length);
        }
      }

      // quality and hyperparameters now always belong to the same evaluated point
      UpdateBestSoFar(data.BestObjValue, data.BestHyperParameters, meanFunction, covarianceFunction);

      return data.BestObjValue;
    }

    // updates the overall best quality and overall best model for Analyze()
    private void UpdateBestSoFar(double bestQ, double[] bestHyperParameters, IMeanFunction meanFunc, ICovarianceFunction covFunc) {
      lock (problemStateLocker) {
        if (bestQ > this.bestQ) {
          this.bestQ = bestQ;
          this.bestHyperParameters = new double[bestHyperParameters.Length];
          Array.Copy(bestHyperParameters, this.bestHyperParameters, this.bestHyperParameters.Length);
          this.meanFunc = meanFunc;
          this.covFunc = covFunc;
        }
      }
    }

    /// <summary>
    /// Publishes the best quality, tree and GP regression solution. The results are updated whenever the best
    /// quality of <paramref name="qualities"/> exceeds the stored "Best Solution Quality".
    /// </summary>
    public override void Analyze(ISymbolicExpressionTree[] trees, double[] qualities, ResultCollection results, IRandom random) {
      if (!results.ContainsKey("Best Solution Quality")) {
        results.Add(new Result("Best Solution Quality", typeof(DoubleValue)));
      }
      if (!results.ContainsKey("Best Tree")) {
        results.Add(new Result("Best Tree", typeof(ISymbolicExpressionTree)));
      }
      if (!results.ContainsKey("Best Solution")) {
        results.Add(new Result("Best Solution", typeof(GaussianProcessRegressionSolution)));
      }

      var bestQuality = qualities.Max();

      if (results["Best Solution Quality"].Value == null || bestQuality > ((DoubleValue)results["Best Solution Quality"].Value).Value) {
        var bestIdx = Array.IndexOf(qualities, bestQuality);
        var bestClone = (ISymbolicExpressionTree)trees[bestIdx].Clone();
        results["Best Tree"].Value = bestClone;
        results["Best Solution Quality"].Value = new DoubleValue(bestQuality);
        // the solution is built from the best model recorded by Evaluate() (see UpdateBestSoFar)
        results["Best Solution"].Value = CreateSolution();
      }
    }

    // Builds a regression solution from the best recorded model; returns null when no model has been recorded yet.
    private IItem CreateSolution() {
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var trainingRows = problemData.TrainingIndices.ToArray();

      lock (problemStateLocker) {
        if (meanFunc == null || covFunc == null || bestHyperParameters == null) return null;
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, bestHyperParameters, (IMeanFunction)meanFunc.Clone(), (ICovarianceFunction)covFunc.Clone());
        model.FixParameters();
        return model.CreateRegressionSolution((IRegressionProblemData)ProblemData.Clone());
      }
    }

    // Inputs of the alglib objective callback plus the best point it has visited.
    // A fresh instance is created per Evaluate() call, so it is not shared between concurrent evaluations.
    private sealed class ObjectiveFunctionData {
      public readonly IDataset Dataset;
      public readonly string TargetVariable;
      public readonly string[] AllowedInputVariables;
      public readonly int[] TrainingRows;
      public readonly IMeanFunction MeanFunction;
      public readonly ICovarianceFunction CovarianceFunction;
      // largest log likelihood seen so far and the hyperparameter vector that produced it
      public double BestObjValue;
      public readonly double[] BestHyperParameters;

      public ObjectiveFunctionData(IDataset dataset, string targetVariable, string[] allowedInputVariables, int[] trainingRows,
        IMeanFunction meanFunction, ICovarianceFunction covarianceFunction, int nHyperParameters) {
        Dataset = dataset;
        TargetVariable = targetVariable;
        AllowedInputVariables = allowedInputVariables;
        TrainingRows = trainingRows;
        MeanFunction = meanFunction;
        CovarianceFunction = covarianceFunction;
        BestObjValue = double.MinValue;
        BestHyperParameters = new double[nHyperParameters];
      }
    }

    // alglib callback: returns the negative log likelihood and its gradient for hyperparameters x,
    // and records x whenever it yields a new best (largest) log likelihood.
    private void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      // we want to optimize the model likelihood by changing the hyperparameters and also return the gradient for each hyperparameter
      var data = (ObjectiveFunctionData)obj;
      var hyperParameters = x; // the decision variable vector

      try {
        var model = new GaussianProcessModel(data.Dataset, data.TargetVariable, data.AllowedInputVariables, data.TrainingRows, hyperParameters, data.MeanFunction, data.CovarianceFunction);

        func = model.NegativeLogLikelihood; // mincgoptimize, so we return negative likelihood
        var logLikelihood = -func; // problem itself is a maximization problem
        // '>' ignores NaN (Math.Max would propagate it and poison the best value)
        if (logLikelihood > data.BestObjValue) {
          data.BestObjValue = logLikelihood;
          Array.Copy(hyperParameters, data.BestHyperParameters, data.BestHyperParameters.Length);
        }
        var gradients = model.HyperparameterGradients;
        Array.Copy(gradients, grad, gradients.Length);
      } catch (ArgumentException) {
        // building the GaussianProcessModel might fail, in this case we return the worst possible objective value
        func = 1.0E+300;
        Array.Clear(grad, 0, grad.Length);
      }
    }

    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

    // Recursively maps Sum/Product nodes to composite covariance functions and terminals to the shared instances.
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        // covFunction is cloned by the model so we can reuse instances of terminal covariance functions
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", node.Symbol.Name));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(StorableConstructorFlag _) : base(_) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
      bestQ = original.bestQ;
      meanFunc = cloner.Clone(original.meanFunc);
      covFunc = cloner.Clone(original.covFunc);
      // check the ORIGINAL's array; this instance's field is still null at this point
      if (original.bestHyperParameters != null) {
        bestHyperParameters = new double[original.bestHyperParameters.Length];
        Array.Copy(original.bestHyperParameters, bestHyperParameters, bestHyperParameters.Length);
      }
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    /// <summary>Replaces the problem data and raises <see cref="ProblemDataChanged"/>.</summary>
    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      OnProblemDataChanged();
    }

    /// <summary>Returns the current problem data (not a copy).</summary>
    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.