Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GaussianProcessTuning/HeuristicLab.Problems.GaussianProcessTuning/Interpreter.cs @ 9155

Last change on this file since 9155 was 9112, checked in by gkronber, 11 years ago

#1967: worked on tuned GP model and benchmark instances

File size: 13.9 KB
Line 
using System;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;
12
namespace HeuristicLab.Problems.GaussianProcessTuning {
  /// <summary>
  /// Translates a symbolic expression tree that encodes a Gaussian process configuration
  /// into mean and covariance function objects, and evaluates such a configuration by
  /// running a <see cref="GaussianProcessRegression"/> algorithm on given problem data.
  /// </summary>
  [StorableClass]
  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
  public class Interpreter : Item {
    // Names of the entries read from the results of the regression algorithm.
    private const string NegativeLogLikelihoodResultName = "NegativeLogLikelihood";
    private const string SolutionResultName = "Solution";

    // Number of minimization iterations configured on the regression algorithm for each evaluation.
    private const int DefaultMinimizationIterations = 50;

    [StorableConstructor]
    protected Interpreter(bool deserializing) : base(deserializing) { }
    protected Interpreter(Interpreter original, Cloner cloner)
      : base(original, cloner) {
    }
    public Interpreter()
      : base() { }

    /// <summary>Creates a copy of this interpreter via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Interpreter(this, cloner);
    }

    /// <summary>
    /// Builds the mean and covariance functions encoded in <paramref name="tree"/>, runs a
    /// <see cref="GaussianProcessRegression"/> algorithm configured with them on
    /// <paramref name="problemData"/>, and blocks until the algorithm has stopped or failed.
    /// </summary>
    /// <param name="tree">Tree that encodes the Gaussian process configuration.</param>
    /// <param name="problemData">Regression problem data handed to the algorithm.</param>
    /// <param name="negLogLikelihood">
    /// Value of the "NegativeLogLikelihood" result; <see cref="double.MaxValue"/> if the
    /// algorithm did not produce that result.
    /// </param>
    /// <param name="solution">Value of the "Solution" result, or null if there is none.</param>
    /// <exception cref="ArgumentException">The tree contains an unknown mean or covariance symbol.</exception>
    /// <remarks>Any exception reported by the algorithm is rethrown on the calling thread.</remarks>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, out double negLogLikelihood, out IGaussianProcessSolution solution) {
      var meanFunction = GetMeanFunction(tree);
      var covFunction = GetCovFunction(tree);

      var gprAlg = new GaussianProcessRegression();
      gprAlg.Problem.ProblemDataParameter.Value = problemData;
      gprAlg.CovarianceFunction = covFunction;
      gprAlg.MeanFunction = meanFunction;
      gprAlg.GaussianProcessModelCreatorParameter.Value =
        gprAlg.GaussianProcessModelCreatorParameter.ValidValues.First(
          v => v is TunedGaussianProcessRegressionModelCreator);
      gprAlg.MinimizationIterations = DefaultMinimizationIterations;

      double result = double.MaxValue;
      IGaussianProcessSolution regSolution = null;
      Exception ex = null;

      using (var signal = new AutoResetEvent(false)) {
        EventHandler onStopped = (sender, args) => {
          try {
            // Both results are optional here: a missing entry leaves the failure defaults in place.
            if (gprAlg.Results.ContainsKey(NegativeLogLikelihoodResultName))
              result = ((DoubleValue)gprAlg.Results[NegativeLogLikelihoodResultName].Value).Value;
            if (gprAlg.Results.ContainsKey(SolutionResultName))
              regSolution = (IGaussianProcessSolution)gprAlg.Results[SolutionResultName].Value;
          } finally {
            // Always release the waiting caller, even if reading the results throws.
            signal.Set();
          }
        };
        EventHandler<EventArgs<Exception>> onExceptionOccurred = (sender, args) => {
          result = double.MaxValue;
          regSolution = null;
          ex = args.Value;
          signal.Set();
        };

        gprAlg.Stopped += onStopped;
        gprAlg.ExceptionOccurred += onExceptionOccurred;
        try {
          gprAlg.Prepare();
          gprAlg.Start();

          signal.WaitOne();
        } finally {
          // Detach before the wait handle is disposed so that later events cannot touch it.
          gprAlg.Stopped -= onStopped;
          gprAlg.ExceptionOccurred -= onExceptionOccurred;
        }
      }

      // Surface the algorithm's failure to the caller. Note that 'throw ex' resets the stack trace.
      if (ex != null) throw ex;

      gprAlg.Prepare();
      gprAlg.Problem = null;
      solution = regSolution;
      negLogLikelihood = result;
    }

    /*
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData trainingData, Dataset testData, IEnumerable<int> testRows,
      out double[] means, out double[] variances) {
      string meanExpression, meanHyperParameter;
      string covExpression, covHyperParameter;
      string likFunction, likHyperParameter;
      GetMeanFunction(tree, trainingData.AllowedInputVariables.Count(), out meanExpression, out meanHyperParameter);
      GetCovFunction(tree, trainingData.AllowedInputVariables.Count(), out covExpression, out covHyperParameter);
      GetLikelihoodFunction(tree, out likFunction, out likHyperParameter);

      double[,] y = new double[trainingData.TrainingIndices.Count(), 1];
      double[,] yImg = new double[trainingData.TrainingIndices.Count(), 1];

      int r, c;
      r = 0;
      foreach (var e in trainingData.Dataset.GetDoubleValues(trainingData.TargetVariable, trainingData.TrainingIndices)) {
        y[r++, 0] = e;
      }
      double[,] x = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      double[,] xImg = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in trainingData.Dataset.GetDoubleValues(allowedInput, trainingData.TrainingIndices)) {
          x[r++, c] = e;
        }
        c++;
      }

      double[,] xTest = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      double[,] xTestImg = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in testData.GetDoubleValues(allowedInput, testRows)) {
          xTest[r++, c] = e;
        }
        c++;
      }

      object oldX = null;
      try {
        oldX = MLApp.GetVariable("x", "base");
      }
      catch {
      }
      if (oldX == null || oldX is Missing || ((double[,])oldX).Length != x.Length) {
        MLApp.PutFullMatrix("y", "base", y, yImg);
        MLApp.PutFullMatrix("x", "base", x, xImg);
      }
      MLApp.PutFullMatrix("xTest", "base", xTest, xTestImg);
      ExecuteMatlab("hyp0 = " + GetHyperParameterString(meanHyperParameter, covHyperParameter, likHyperParameter) + ";");
      ExecuteMatlab("infFun =  " + GetInferenceMethodString(tree) + ";");
      ExecuteMatlab("meanExpr = " + meanExpression + ";");
      ExecuteMatlab("covExpr = " + covExpression + ";");
      ExecuteMatlab("likExp = " + likFunction + ";");
      ExecuteMatlab(
        "try " +
        "  hyp = minimize(hyp0,'gp', -50, infFun, meanExpr, covExpr, likExp, x, y);" +
        "  [ymu, ys2, fmu, fs2] = gp(hyp, infFun, meanExpr, covExpr, likExp, x, y, xTest); " +
        "catch " +
        "  ymu = zeros(size(xTest, 1), 1); " +
        "  ys2 = ones(size(xTest, 1), 1); " +
        "end");
      var meansMat = (double[,])MLApp.GetVariable("ymu", "base");
      var variancesMat = (double[,])MLApp.GetVariable("ys2", "base");
      means = new double[meansMat.GetLength(0)];
      for (int i = 0; i < means.Length; i++)
        means[i] = meansMat[i, 0];
      variances = new double[variancesMat.GetLength(0)];
      for (int i = 0; i < variances.Length; i++)
        variances[i] = variancesMat[i, 0];
    }
    */

    /// <summary>
    /// Returns the quoted inference method name. The tree is currently ignored and
    /// the result is always <c>'infEP'</c>.
    /// </summary>
    private string GetInferenceMethodString(ISymbolicExpressionTree tree) {
      return "'infEP'";
    }

    /// <summary>
    /// Builds the mean function from subtree 0 of <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>.
    /// </summary>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTree tree) {
      return GetMeanFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0));
    }

    /// <summary>
    /// Builds the covariance function from subtree 1 of <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>.
    /// </summary>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTree tree) {
      return GetCovFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(1));
    }


    /// <summary>
    /// Renders the likelihood function found in subtree 2 of
    /// <c>tree.Root.GetSubtree(0).GetSubtree(0)</c> as two strings.
    /// </summary>
    /// <param name="tree">Tree that encodes the Gaussian process configuration.</param>
    /// <param name="expression">Receives the likelihood expression.</param>
    /// <param name="hyperParameter">Receives the hyper-parameters wrapped in square brackets.</param>
    private void GetLikelihoodFunction(ISymbolicExpressionTree tree, out string expression, out string hyperParameter) {
      var expressionBuilder = new StringBuilder();
      var hyperParameterBuilder = new StringBuilder();
      hyperParameterBuilder.Append("[");
      GetLikelihoodFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(2), expressionBuilder,
                                   hyperParameterBuilder);
      hyperParameterBuilder.Append("]");
      expression = expressionBuilder.ToString();
      hyperParameter = hyperParameterBuilder.ToString();
    }

    /// <summary>
    /// Appends the expression and hyper-parameter of the likelihood node to the given builders.
    /// Only the <c>LikGauss</c> symbol is supported.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known likelihood function.</exception>
    private void GetLikelihoodFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionBuilder, StringBuilder hyperParameterBuilder) {
      if (node.Symbol is LikGauss) {
        // Direct cast: a mismatching node type fails here instead of as a null dereference below.
        var likNode = (LikGaussTreeNode)node;
        expressionBuilder.Append("'likGauss'");
        // Culture-invariant formatting so that the output does not depend on the current locale.
        hyperParameterBuilder.Append(Convert.ToString(likNode.Sigma, CultureInfo.InvariantCulture));
      } else {
        throw new ArgumentException("unknown likelihood function " + node.Symbol);
      }
    }

    /// <summary>
    /// Combines the three hyper-parameter strings into a <c>struct('mean', ..., 'cov', ..., 'lik', ...)</c> expression.
    /// </summary>
    private string GetHyperParameterString(string meanHyp, string covHyp, string likHyp) {
      return "struct('mean', " + meanHyp +
             ", 'cov', " + covHyp +
             ", 'lik', " + likHyp +
             ")";
    }


    /// <summary>
    /// Recursively maps a covariance symbol node to the corresponding covariance function object.
    /// Sum and product nodes are built from the functions of all their subtrees.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known covariance symbol.</exception>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is CovConst) {
        return new CovarianceConst();
      } else if (node.Symbol is CovScale) {
        return new CovarianceScale();
      } else if (node.Symbol is CovMask) {
        var maskNode = (CovMaskTreeNode)node;
        var covSymbol = (CovMask)node.Symbol;
        var cov = new CovarianceMask();
        // Selected dimensions are the indices whose mask entry is true.
        cov.SelectedDimensionsParameter.Value = new IntArray((from i in Enumerable.Range(0, covSymbol.Dimension)
                                                              where maskNode.Mask[i]
                                                              select i).ToArray());
        return cov;
      } else if (node.Symbol is CovLin) {
        return new CovarianceLinear();
      } else if (node.Symbol is CovLinArd) {
        return new CovarianceLinearArd();
      } else if (node.Symbol is CovMatern) {
        var covSymbol = (CovMatern)node.Symbol;
        var cov = new CovarianceMaternIso();
        // Pick the one valid parameter value that matches the D configured on the symbol.
        cov.DParameter.Value = cov.DParameter.ValidValues.Single(x => x.Value == covSymbol.D);
        return cov;
      } else if (node.Symbol is CovSeArd) {
        return new CovarianceSquaredExponentialArd();
      } else if (node.Symbol is CovSeIso) {
        return new CovarianceSquaredExponentialIso();
      } else if (node.Symbol is CovRQIso) {
        return new CovarianceRationalQuadraticIso();
      } else if (node.Symbol is CovRQArd) {
        return new CovarianceRationalQuadraticArd();
      } else if (node.Symbol is CovPeriodic) {
        return new CovariancePeriodic();
      } else if (node.Symbol is CovNoise) {
        return new CovarianceNoise();
      } else if (node.Symbol is CovSum) {
        var covSum = new Algorithms.DataAnalysis.CovarianceSum();
        // GetSubtree(0) is used for the first term, so a node without subtrees fails here.
        covSum.Terms.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covSum.Terms.Add(GetCovFunction(subTree));
        }
        return covSum;
      } else if (node.Symbol is CovProd) {
        var covProd = new Algorithms.DataAnalysis.CovarianceProduct();
        covProd.Factors.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covProd.Factors.Add(GetCovFunction(subTree));
        }
        return covProd;
      } else {
        throw new ArgumentException("unknown symbol " + node.Symbol);
      }
    }


    /// <summary>
    /// Recursively maps a mean symbol node to the corresponding mean function object.
    /// Sum and product nodes are built from the functions of all their subtrees.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known mean symbol.</exception>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is MeanConst) {
        return new Algorithms.DataAnalysis.MeanConst();
      } else if (node.Symbol is MeanLinear) {
        return new Algorithms.DataAnalysis.MeanLinear();
      } else if (node.Symbol is MeanProd) {
        var meanProd = new Algorithms.DataAnalysis.MeanProduct();
        meanProd.Factors.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanProd.Factors.Add(GetMeanFunction(subTree));
        }
        return meanProd;
      } else if (node.Symbol is MeanSum) {
        var meanSum = new Algorithms.DataAnalysis.MeanSum();
        meanSum.Terms.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanSum.Terms.Add(GetMeanFunction(subTree));
        }
        return meanSum;
      } else if (node.Symbol is MeanZero) {
        return new Algorithms.DataAnalysis.MeanZero();
      } else {
        throw new ArgumentException("Unknown mean function " + node.Symbol);
      }
    }

    //private bool[] CalculateMask(ISymbolicExpressionTreeNode node) {
    //  var maskNode = node as MeanMaskTreeNode;
    //  if (maskNode != null) {
    //    bool[] newMask = CombineMasksProd(maskNode.Mask, CalculateMask(node.GetSubtree(0)));
    //    return newMask;
    //  } else if (node.Symbol is MeanProd) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksProd(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.Symbol is MeanSum) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksSum(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.SubtreeCount == 1) {
    //    return CalculateMask(node.GetSubtree(0));
    //  } else if (node is SymbolicExpressionTreeTerminalNode) {
    //    return null;
    //  } else {
    //    throw new NotImplementedException();
    //  }
    //}

    //private bool[] CombineMasksProd(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] | n[i];
    //  return res;
    //}


    //private bool[] CombineMasksSum(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] & n[i];
    //  return res;
    //}

    /// <summary>
    /// Renders a mask as a bracketed, comma-separated vector in which every true entry
    /// becomes "0" and every false entry becomes "1".
    /// </summary>
    /// <remarks>
    /// A single-element mask is special-cased: true yields "[]" and false yields "[1]".
    /// </remarks>
    private string ToVectorString(bool[] b) {
      var strBuilder = new StringBuilder();
      strBuilder.Append("[");
      if (b.Length == 1) // workaround for bug in GPML
      {
        if (!b[0]) strBuilder.Append("1");
      } else {
        for (int i = 0; i < b.Length; i++) {
          if (i > 0) strBuilder.Append(", ");
          strBuilder.Append(b[i] ? "0" : "1");
        }
      }
      strBuilder.Append("]");
      return strBuilder.ToString();
    }

    /// <summary>
    /// Renders the elements of <paramref name="xs"/> whose mask entry is false as a bracketed
    /// vector separated by "; ". Masked (true) elements are skipped.
    /// </summary>
    /// <exception cref="ArgumentException">The two arrays differ in length.</exception>
    private string ToVectorString(double[] xs, bool[] mask) {
      if (xs.Length != mask.Length) throw new ArgumentException("xs and mask must have the same length.");
      var strBuilder = new StringBuilder();
      strBuilder.Append("[");
      // Tracks emitted elements: keying the separator on the index would produce a
      // leading "; " whenever the first element is masked.
      bool isFirst = true;
      for (int i = 0; i < xs.Length; i++) {
        if (mask[i]) continue;
        if (!isFirst) strBuilder.Append("; ");
        // Culture-invariant formatting so that the decimal separator is always '.'.
        strBuilder.Append(xs[i].ToString(CultureInfo.InvariantCulture));
        isFirst = false;
      }
      strBuilder.Append("]");
      return strBuilder.ToString();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.