Free cookie consent management tool by TermsFeed Policy Generator

source: branches/gteufl/HeuristicLab.Algorithms.DataAnalysis/3.4/GaussianProcess/GaussianProcessCovarianceOptimizationProblem.cs @ 13042

Last change on this file since 13042 was 12969, checked in by gkronber, 9 years ago

#2478 merged all changes from trunk to branch before trunk-reintegration

File size: 15.2 KB
RevLine 
[12946]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.Instances;
32
33
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Genetic-programming problem that searches for a Gaussian process covariance function.
  /// Each symbolic expression tree combines terminal covariance functions with the binary
  /// "Sum" and "Product" symbols. The quality of a tree is the best log likelihood obtained by
  /// fitting a Gaussian process model (constant mean, the tree's covariance function and a noise
  /// hyperparameter) to the training rows of <see cref="ProblemData"/>.
  /// </summary>
  [Item("Gaussian Process Covariance Optimization Problem", "")]
  [Creatable(CreatableAttribute.Categories.GeneticProgrammingProblems, Priority = 300)]
  [StorableClass]
  public sealed class GaussianProcessCovarianceOptimizationProblem : SymbolicExpressionTreeProblem, IRegressionProblem, IProblemInstanceConsumer<IRegressionProblemData>, IProblemInstanceExporter<IRegressionProblemData> {
    #region static variables and ctor
    // Pre-configured terminal covariance functions; TreeToCovarianceFunction returns these
    // shared instances for the corresponding terminal symbols of every tree.
    private static readonly CovarianceMaternIso maternIso1;
    private static readonly CovarianceMaternIso maternIso3;
    private static readonly CovarianceMaternIso maternIso5;
    private static readonly CovariancePiecewisePolynomial piecewisePoly0;
    private static readonly CovariancePiecewisePolynomial piecewisePoly1;
    private static readonly CovariancePiecewisePolynomial piecewisePoly2;
    private static readonly CovariancePiecewisePolynomial piecewisePoly3;
    private static readonly CovariancePolynomial poly2;
    private static readonly CovariancePolynomial poly3;
    private static readonly CovarianceSpectralMixture spectralMixture1;
    private static readonly CovarianceSpectralMixture spectralMixture3;
    private static readonly CovarianceSpectralMixture spectralMixture5;
    private static readonly CovarianceLinear linear;
    private static readonly CovarianceLinearArd linearArd;
    private static readonly CovarianceNeuralNetwork neuralNetwork;
    private static readonly CovariancePeriodic periodic;
    private static readonly CovarianceRationalQuadraticIso ratQuadraticIso;
    private static readonly CovarianceRationalQuadraticArd ratQuadraticArd;
    private static readonly CovarianceSquaredExponentialArd sqrExpArd;
    private static readonly CovarianceSquaredExponentialIso sqrExpIso;

    static GaussianProcessCovarianceOptimizationProblem() {
      // cumbersome initialization because of ConstrainedValueParameters
      maternIso1 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso1.DParameter, 1);
      maternIso3 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso3.DParameter, 3);
      maternIso5 = new CovarianceMaternIso(); SetConstrainedValueParameter(maternIso5.DParameter, 5);

      piecewisePoly0 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly0.VParameter, 0);
      piecewisePoly1 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly1.VParameter, 1);
      piecewisePoly2 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly2.VParameter, 2);
      piecewisePoly3 = new CovariancePiecewisePolynomial(); SetConstrainedValueParameter(piecewisePoly3.VParameter, 3);

      poly2 = new CovariancePolynomial(); poly2.DegreeParameter.Value.Value = 2;
      poly3 = new CovariancePolynomial(); poly3.DegreeParameter.Value.Value = 3;

      spectralMixture1 = new CovarianceSpectralMixture(); spectralMixture1.QParameter.Value.Value = 1;
      spectralMixture3 = new CovarianceSpectralMixture(); spectralMixture3.QParameter.Value.Value = 3;
      spectralMixture5 = new CovarianceSpectralMixture(); spectralMixture5.QParameter.Value.Value = 5;

      linear = new CovarianceLinear();
      linearArd = new CovarianceLinearArd();
      neuralNetwork = new CovarianceNeuralNetwork();
      periodic = new CovariancePeriodic();
      ratQuadraticArd = new CovarianceRationalQuadraticArd();
      ratQuadraticIso = new CovarianceRationalQuadraticIso();
      sqrExpArd = new CovarianceSquaredExponentialArd();
      sqrExpIso = new CovarianceSquaredExponentialIso();
    }

    /// <summary>
    /// Selects the entry of <paramref name="param"/>'s valid values whose value equals <paramref name="val"/>.
    /// Throws (via <c>Single</c>) unless exactly one valid value matches.
    /// </summary>
    private static void SetConstrainedValueParameter(IConstrainedValueParameter<IntValue> param, int val) {
      param.Value = param.ValidValues.Single(v => v.Value == val);
    }

    #endregion

    #region parameter names

    private const string ProblemDataParameterName = "ProblemData";
    private const string ConstantOptIterationsParameterName = "Constant optimization steps";
    private const string RestartsParameterName = "Restarts";

    #endregion

    #region Parameter Properties
    IParameter IDataAnalysisProblem.ProblemDataParameter { get { return ProblemDataParameter; } }

    public IValueParameter<IRegressionProblemData> ProblemDataParameter {
      get { return (IValueParameter<IRegressionProblemData>)Parameters[ProblemDataParameterName]; }
    }
    public IFixedValueParameter<IntValue> ConstantOptIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptIterationsParameterName]; }
    }
    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }
    #endregion

    #region Properties

    public IRegressionProblemData ProblemData {
      get { return ProblemDataParameter.Value; }
      set { ProblemDataParameter.Value = value; }
    }
    IDataAnalysisProblemData IDataAnalysisProblem.ProblemData { get { return ProblemData; } }

    /// <summary>Maximum number of optimizer iterations per restart.</summary>
    public int ConstantOptIterations {
      get { return ConstantOptIterationsParameter.Value.Value; }
      set { ConstantOptIterationsParameter.Value.Value = value; }
    }

    /// <summary>Number of randomly initialized hyperparameter optimization runs per evaluation.</summary>
    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }
    #endregion

    public override bool Maximization {
      get { return true; } // Evaluate returns the log likelihood (instead of the negative log likelihood as in GPR)
    }

    public GaussianProcessCovarianceOptimizationProblem()
      : base() {
      Parameters.Add(new ValueParameter<IRegressionProblemData>(ProblemDataParameterName, "The data for the regression problem", new RegressionProblemData()));
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptIterationsParameterName, "Number of optimization steps for hyperparameter values", new IntValue(50)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of random restarts for constant optimization.", new IntValue(10)));
      Parameters[RestartsParameterName].Hidden = true;

      // grammar: binary Sum/Product functions over the terminal covariance functions
      var g = new SimpleSymbolicExpressionGrammar();
      g.AddSymbols(new string[] { "Sum", "Product" }, 2, 2);
      g.AddTerminalSymbols(new string[]
      {
        "Linear",
        "LinearArd",
        "MaternIso1",
        "MaternIso3",
        "MaternIso5",
        "NeuralNetwork",
        "Periodic",
        "PiecewisePolynomial0",
        "PiecewisePolynomial1",
        "PiecewisePolynomial2",
        "PiecewisePolynomial3",
        "Polynomial2",
        "Polynomial3",
        "RationalQuadraticArd",
        "RationalQuadraticIso",
        "SpectralMixture1",
        "SpectralMixture3",
        "SpectralMixture5",
        "SquaredExponentialArd",
        "SquaredExponentialIso"
      });
      base.Encoding = new SymbolicExpressionTreeEncoding(g, 10, 5);
    }

    /// <summary>
    /// Evaluates a covariance-function tree by optimizing the hyperparameters of a Gaussian process
    /// model (constant mean + the tree's covariance function + noise) on the training rows.
    /// </summary>
    /// <param name="tree">Tree built from the Sum/Product symbols and the covariance terminals of the grammar.</param>
    /// <param name="random">Source for the random initial covariance hyperparameters of each restart.</param>
    /// <returns>
    /// The best log likelihood observed during <see cref="Restarts"/> conjugate-gradient runs, or
    /// <see cref="double.MinValue"/> if no run produced a value.
    /// </returns>
    public override double Evaluate(ISymbolicExpressionTree tree, IRandom random) {
      var meanFunction = new MeanConst();
      var problemData = ProblemData;
      var ds = problemData.Dataset;
      var targetVariable = problemData.TargetVariable;
      var allowedInputVariables = problemData.AllowedInputVariables.ToArray();
      var nVars = allowedInputVariables.Length;
      var trainingRows = problemData.TrainingIndices.ToArray();

      // use the same covariance function for each restart
      var covarianceFunction = TreeToCovarianceFunction(tree);
      var nCovarianceParameters = covarianceFunction.GetNumberOfParameters(nVars);

      // hyperparameter layout: [mean const, covariance parameters..., noise]
      var hyperParameters = new double[meanFunction.GetNumberOfParameters(nVars) + nCovarianceParameters + 1];

      // best log likelihood seen so far; ObjectiveFunction updates it through the data tuple
      var bestObjValue = new double[1] { double.MinValue };

      // data that is necessary for the objective function
      var data = Tuple.Create(ds, targetVariable, allowedInputVariables, trainingRows, (IMeanFunction)meanFunction, covarianceFunction, bestObjValue);

      // Initial mean is identical for every restart. Only training rows are used so that no
      // information from the test partition enters the fit (the likelihood also uses training rows only).
      var initialMean = trainingRows.Length > 0
        ? ds.GetDoubleValues(targetVariable, trainingRows).Average()
        : 0.0;

      for (int t = 0; t < Restarts; t++) {
        var prevBest = bestObjValue[0];

        // initialize hyperparameters
        hyperParameters[0] = initialMean; // mean const
        for (int i = 0; i < nCovarianceParameters; i++) {
          hyperParameters[1 + i] = random.NextDouble() * 2.0 - 1.0;
        }
        hyperParameters[hyperParameters.Length - 1] = 1.0; // s² = exp(2), TODO: other inits better?

        // use alglib's nonlinear conjugate gradient optimizer (mincg) for hyperparameter optimization
        double epsg = 0;
        double epsf = 0.00001;
        double epsx = 0;
        double stpmax = 1;
        int maxits = ConstantOptIterations;
        alglib.mincgstate state;
        alglib.mincgreport rep;
        double[] optimizedHyperParameters; // required by mincgresults; the objective value is tracked in bestObjValue

        alglib.mincgcreate(hyperParameters, out state);
        alglib.mincgsetcond(state, epsg, epsf, epsx, maxits);
        alglib.mincgsetstpmax(state, stpmax);
        alglib.mincgoptimize(state, ObjectiveFunction, null, data);
        alglib.mincgresults(state, out optimizedHyperParameters, out rep);

        if (rep.terminationtype < 0) {
          // error -> discard values observed during this restart and restore the previous best quality
          bestObjValue[0] = prevBest;
        }
      }

      return bestObjValue[0];
    }

    /// <summary>
    /// Objective callback for <c>alglib.mincgoptimize</c>: computes the negative log likelihood of the
    /// Gaussian process model for the hyperparameters <paramref name="x"/> and its gradient.
    /// </summary>
    /// <param name="x">Hyperparameter vector (mean const, covariance parameters, noise).</param>
    /// <param name="func">
    /// Receives the negative log likelihood, or 1.0E+300 if the model cannot be built or its
    /// likelihood is NaN.
    /// </param>
    /// <param name="grad">Receives the hyperparameter gradients (cleared to zero on failure).</param>
    /// <param name="obj">
    /// The data tuple created in <see cref="Evaluate"/>; its last item holds the best log likelihood
    /// seen so far and is updated here.
    /// </param>
    public void ObjectiveFunction(double[] x, ref double func, double[] grad, object obj) {
      // we want to optimize the model likelihood by changing the hyperparameters and also return the gradient for each hyperparameter
      var data = (Tuple<IDataset, string, string[], int[], IMeanFunction, ICovarianceFunction, double[]>)obj;
      var ds = data.Item1;
      var targetVariable = data.Item2;
      var allowedInputVariables = data.Item3;
      var trainingRows = data.Item4;
      var meanFunction = data.Item5;
      var covarianceFunction = data.Item6;
      var bestObjValue = data.Item7;
      var hyperParameters = x; // the decision variable vector

      try {
        var model = new GaussianProcessModel(ds, targetVariable, allowedInputVariables, trainingRows, hyperParameters, meanFunction, covarianceFunction);
        var negativeLogLikelihood = model.NegativeLogLikelihood;

        // Math.Max propagates NaN, so a NaN likelihood would permanently poison bestObjValue;
        // such models are handled like models that could not be built (see below).
        if (!double.IsNaN(negativeLogLikelihood)) {
          func = negativeLogLikelihood; // mincg minimizes, so we return the negative log likelihood
          bestObjValue[0] = Math.Max(bestObjValue[0], -func); // problem itself is a maximization problem
          var gradients = model.HyperparameterGradients;
          Array.Copy(gradients, grad, gradients.Length);
          return;
        }
      } catch (Exception) {
        // building the GaussianProcessModel might fail; deliberately handled below
      }

      // model could not be built or its likelihood is NaN -> return the worst possible objective value
      func = 1.0E+300;
      Array.Clear(grad, 0, grad.Length);
    }

    /// <summary>Converts a whole tree, skipping the program-root and start-symbol nodes.</summary>
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTree tree) {
      return TreeToCovarianceFunction(tree.Root.GetSubtree(0).GetSubtree(0)); // skip programroot and startsymbol
    }

    /// <summary>
    /// Recursively converts a tree node into a covariance function: "Sum" and "Product" nodes become
    /// new <see cref="CovarianceSum"/>/<see cref="CovarianceProduct"/> instances over their two
    /// subtrees; terminal symbols map to the shared static instances.
    /// </summary>
    /// <exception cref="InvalidProgramException">The node's symbol is not part of the grammar.</exception>
    private ICovarianceFunction TreeToCovarianceFunction(ISymbolicExpressionTreeNode node) {
      switch (node.Symbol.Name) {
        case "Sum": {
            var sum = new CovarianceSum();
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            sum.Terms.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return sum;
          }
        case "Product": {
            var prod = new CovarianceProduct();
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(0)));
            prod.Factors.Add(TreeToCovarianceFunction(node.GetSubtree(1)));
            return prod;
          }
        // covFunction is cloned by the model so we can reuse instances of terminal covariance functions
        case "Linear": return linear;
        case "LinearArd": return linearArd;
        case "MaternIso1": return maternIso1;
        case "MaternIso3": return maternIso3;
        case "MaternIso5": return maternIso5;
        case "NeuralNetwork": return neuralNetwork;
        case "Periodic": return periodic;
        case "PiecewisePolynomial0": return piecewisePoly0;
        case "PiecewisePolynomial1": return piecewisePoly1;
        case "PiecewisePolynomial2": return piecewisePoly2;
        case "PiecewisePolynomial3": return piecewisePoly3;
        case "Polynomial2": return poly2;
        case "Polynomial3": return poly3;
        case "RationalQuadraticArd": return ratQuadraticArd;
        case "RationalQuadraticIso": return ratQuadraticIso;
        case "SpectralMixture1": return spectralMixture1;
        case "SpectralMixture3": return spectralMixture3;
        case "SpectralMixture5": return spectralMixture5;
        case "SquaredExponentialArd": return sqrExpArd;
        case "SquaredExponentialIso": return sqrExpIso;
        default: throw new InvalidProgramException(string.Format("Found invalid symbol {0}", node.Symbol.Name));
      }
    }


    // persistence
    [StorableConstructor]
    private GaussianProcessCovarianceOptimizationProblem(bool deserializing) : base(deserializing) { }
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
    }

    // cloning
    private GaussianProcessCovarianceOptimizationProblem(GaussianProcessCovarianceOptimizationProblem original, Cloner cloner)
      : base(original, cloner) {
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new GaussianProcessCovarianceOptimizationProblem(this, cloner);
    }

    /// <summary>Replaces <see cref="ProblemData"/> with <paramref name="data"/> and raises <see cref="ProblemDataChanged"/>.</summary>
    public void Load(IRegressionProblemData data) {
      this.ProblemData = data;
      OnProblemDataChanged();
    }

    /// <summary>Returns the current <see cref="ProblemData"/> instance (not a copy).</summary>
    public IRegressionProblemData Export() {
      return ProblemData;
    }

    #region events
    /// <summary>Raised by <see cref="Load"/> after the problem data has been replaced.</summary>
    public event EventHandler ProblemDataChanged;


    private void OnProblemDataChanged() {
      var handler = ProblemDataChanged;
      if (handler != null)
        handler(this, EventArgs.Empty);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.