Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GaussianProcessEvolution/HeuristicLab.Problems.GaussianProcessTuning/Interpreter MATLAB.cs @ 8753

Last change on this file since 8753 was 8753, checked in by gkronber, 12 years ago

#1967 initial import of Gaussian process evolution plugin

File size: 20.6 KB
Line 
1using System;
2using System.Linq;
3using System.Reflection;
4using System.Text;
5using HeuristicLab.Common;
6using HeuristicLab.Core;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
9using HeuristicLab.Problems.DataAnalysis;
10using MLApp;
11using System.Collections.Generic;
12
13namespace HeuristicLab.Problems.GaussianProcessTuning {
14  [StorableClass]
15  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
16  public class Interpreter : Item {
// Handle to the MATLAB automation object used by this interpreter.
// [ThreadStatic] is only honoured on static fields (on an instance field it is ignored),
// so the field has to be static for every thread to see its own handle.
// It must not get an initializer: a thread-static initializer would only run for the first thread.
[ThreadStatic]
private static MLApp.MLApp ml;

/// <summary>
/// Returns the MATLAB automation object of the calling thread, creating it on first access.
/// </summary>
private MLApp.MLApp MLApp {
  get {
    if (ml == null) {
      ml = new MLApp.MLApp();
    }
    return ml;
  }
}
28
/// <summary>
/// Constructor marked for the persistence framework; only forwards the flag to the base class.
/// </summary>
[StorableConstructor]
protected Interpreter(bool deserializing)
  : base(deserializing) {
}

/// <summary>
/// Cloning constructor; the interpreter has no cloned state of its own, so everything is delegated to the base class.
/// </summary>
protected Interpreter(Interpreter original, Cloner cloner)
  : base(original, cloner) {
}

/// <summary>
/// Creates a new interpreter.
/// </summary>
public Interpreter() {
}

/// <summary>
/// Creates a copy of this interpreter through the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new Interpreter(this, cloner);
  return copy;
}
39
/// <summary>
/// Fits the Gaussian process configuration encoded by <paramref name="tree"/> to the training
/// rows of <paramref name="problemData"/> in MATLAB and returns the value MATLAB reports in
/// <c>nlZ</c> after hyper-parameter optimization.
/// </summary>
/// <param name="tree">Tree whose mean, covariance and likelihood sub-trees define the model.</param>
/// <param name="problemData">Supplies the target variable, the allowed inputs and the training rows.</param>
/// <returns>
/// The value of <c>nlZ</c> (presumably GPML's negative log marginal likelihood - confirm);
/// <see cref="double.MaxValue"/> when MATLAB computed NaN; <see cref="double.PositiveInfinity"/>
/// when <c>nlZ</c> is missing or anything fails during optimization.
/// </returns>
/// <exception cref="ArgumentException">The tree contains an unknown symbol or MATLAB rejects the model definition.</exception>
public double EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
  // Translate the tree first so that an invalid tree fails before anything is sent to MATLAB.
  var modelCommands = BuildModelCommands(tree, problemData.AllowedInputVariables.Count());

  double[,] x, y;
  ReadTrainingData(problemData, out x, out y);
  UploadTrainingData(x, y);

  foreach (var command in modelCommands) {
    ExecuteMatlab(command);
  }

  try {
    ExecuteMatlab("hyp = minimize(hyp0,'gp', -50, infFun, meanExpr, covExpr, likExp, x, y);");
    ExecuteMatlab("[nlZ, dnlZ] = gp(hyp, infFun, meanExpr, covExpr, likExp, x, y);");
    // MATLAB is case-sensitive: the replacement has to be written to nlZ (not nlz), and the
    // literal must use '.' as decimal separator regardless of the current culture.
    string maxValue = double.MaxValue.ToString("R", System.Globalization.CultureInfo.InvariantCulture);
    ExecuteMatlab("if isnan(nlZ), nlZ = " + maxValue + "; end;");
    var d = MLApp.GetVariable("nlZ", "base");
    if (d is Missing) return double.PositiveInfinity;
    return (double)d;
  }
  catch {
    // Deliberate best effort: a configuration that cannot be fitted gets the worst possible value.
    return double.PositiveInfinity;
  }
}

/// <summary>
/// Fits the Gaussian process configuration encoded by <paramref name="tree"/> to the training
/// rows of <paramref name="trainingData"/> and evaluates it on <paramref name="testRows"/> of
/// <paramref name="testData"/>.
/// </summary>
/// <param name="tree">Tree whose mean, covariance and likelihood sub-trees define the model.</param>
/// <param name="trainingData">Supplies the target variable, the allowed inputs and the training rows.</param>
/// <param name="testData">Dataset from which the test inputs are read (same input variables as the training data).</param>
/// <param name="testRows">Rows of <paramref name="testData"/> to predict.</param>
/// <param name="means">First column of the MATLAB variable <c>ymu</c>, one entry per test row.</param>
/// <param name="variances">First column of the MATLAB variable <c>ys2</c>, one entry per test row.</param>
/// <exception cref="ArgumentException">The tree contains an unknown symbol or MATLAB rejects a command.</exception>
public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData trainingData, Dataset testData, IEnumerable<int> testRows,
  out double[] means, out double[] variances) {
  // Translate the tree first so that an invalid tree fails before anything is sent to MATLAB.
  var modelCommands = BuildModelCommands(tree, trainingData.AllowedInputVariables.Count());

  double[,] x, y;
  ReadTrainingData(trainingData, out x, out y);

  // Materialize once: the rows are needed for sizing the matrix and for every input column.
  var testRowList = testRows.ToList();
  var xTest = new double[testRowList.Count, x.GetLength(1)];
  int c = 0;
  foreach (var allowedInput in trainingData.AllowedInputVariables) {
    int r = 0;
    foreach (var value in testData.GetDoubleValues(allowedInput, testRowList)) {
      xTest[r++, c] = value;
    }
    c++;
  }

  UploadTrainingData(x, y);
  // The second array is the (all-zero) imaginary part expected by PutFullMatrix.
  MLApp.PutFullMatrix("xTest", "base", xTest, new double[xTest.GetLength(0), xTest.GetLength(1)]);

  foreach (var command in modelCommands) {
    ExecuteMatlab(command);
  }

  // Commas after try/catch follow MATLAB's single-line form, so the first statement after
  // "catch" cannot be taken for an exception identifier. On failure MATLAB falls back to
  // zero means and unit variances.
  ExecuteMatlab(
    "try, " +
    "  hyp = minimize(hyp0,'gp', -50, infFun, meanExpr, covExpr, likExp, x, y);" +
    "  [ymu, ys2, fmu, fs2] = gp(hyp, infFun, meanExpr, covExpr, likExp, x, y, xTest); " +
    "catch, " +
    "  ymu = zeros(size(xTest, 1), 1); " +
    "  ys2 = ones(size(xTest, 1), 1); " +
    "end");

  means = ExtractFirstColumn(MLApp.GetVariable("ymu", "base"));
  variances = ExtractFirstColumn(MLApp.GetVariable("ys2", "base"));
}

// Builds the MATLAB assignments (hyp0, infFun, meanExpr, covExpr, likExp) that define the model encoded by the tree.
private string[] BuildModelCommands(ISymbolicExpressionTree tree, int dimension) {
  string meanExpression, meanHyperParameter;
  string covExpression, covHyperParameter;
  string likFunction, likHyperParameter;
  GetMeanFunction(tree, dimension, out meanExpression, out meanHyperParameter);
  GetCovFunction(tree, dimension, out covExpression, out covHyperParameter);
  GetLikelihoodFunction(tree, out likFunction, out likHyperParameter);

  return new[] {
    "hyp0 = " + GetHyperParameterString(meanHyperParameter, covHyperParameter, likHyperParameter) + ";",
    "infFun =  " + GetInferenceMethodString(tree) + ";",
    "meanExpr = " + meanExpression + ";",
    "covExpr = " + covExpression + ";",
    "likExp = " + likFunction + ";"
  };
}

// Copies the training partition into an input matrix x (rows x allowed inputs) and a target column y (rows x 1).
private static void ReadTrainingData(IRegressionProblemData problemData, out double[,] x, out double[,] y) {
  // Materialize once instead of re-enumerating the training indices for every Count() and column.
  var rows = problemData.TrainingIndizes.ToList();

  y = new double[rows.Count, 1];
  int r = 0;
  foreach (var value in problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows)) {
    y[r++, 0] = value;
  }

  x = new double[rows.Count, problemData.AllowedInputVariables.Count()];
  int c = 0;
  foreach (var allowedInput in problemData.AllowedInputVariables) {
    r = 0;
    foreach (var value in problemData.Dataset.GetDoubleValues(allowedInput, rows)) {
      x[r++, c] = value;
    }
    c++;
  }
}

// Sends x and y to the MATLAB base workspace unless a matrix "x" with the same number of elements is already there.
// NOTE(review): the cache check compares only the element count, so different data of equal size is not re-sent - confirm this is intended.
private void UploadTrainingData(double[,] x, double[,] y) {
  object oldX = null;
  try {
    oldX = MLApp.GetVariable("x", "base");
  }
  catch {
    // Best effort: treat any failure to read "x" as "nothing cached yet".
  }

  // "as" also covers null, Missing and a 1x1 value that comes back as a scalar; all of them trigger an upload.
  var cached = oldX as double[,];
  if (cached == null || cached.Length != x.Length) {
    // The second array is the (all-zero) imaginary part expected by PutFullMatrix.
    MLApp.PutFullMatrix("y", "base", y, new double[y.GetLength(0), y.GetLength(1)]);
    MLApp.PutFullMatrix("x", "base", x, new double[x.GetLength(0), x.GetLength(1)]);
  }
}

// Returns the first column of a value read from MATLAB; a 1x1 result arrives as a plain double instead of a matrix.
private static double[] ExtractFirstColumn(object value) {
  if (value is double) return new[] { (double)value };

  var matrix = (double[,])value;
  var column = new double[matrix.GetLength(0)];
  for (int i = 0; i < column.Length; i++) {
    column[i] = matrix[i, 0];
  }
  return column;
}
166
/// <summary>
/// Returns the MATLAB string literal naming the inference method; the tree is currently not consulted.
/// </summary>
private string GetInferenceMethodString(ISymbolicExpressionTree tree) {
  const string inferenceMethod = "'infEP'";
  return inferenceMethod;
}
170
/// <summary>
/// Translates the mean-function sub-tree (root / child 0 / child 0 / child 0) into a MATLAB
/// expression and a bracketed hyper-parameter vector.
/// </summary>
/// <param name="tree">Tree containing the mean function.</param>
/// <param name="dimension">Number of input variables; sizes the initial all-false mask.</param>
/// <param name="expression">Receives the MATLAB expression of the mean function.</param>
/// <param name="hyperParameter">Receives the hyper-parameters enclosed in "[" and "]".</param>
private void GetMeanFunction(ISymbolicExpressionTree tree, int dimension, out string expression, out string hyperParameter) {
  var meanExpression = new StringBuilder();
  var meanHyperParameters = new StringBuilder("[");
  var initialMask = new bool[dimension];

  var meanNode = tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0);
  GetMeanFunction(meanNode, meanExpression, meanHyperParameters, initialMask);
  meanHyperParameters.Append("]");

  expression = meanExpression.ToString();
  hyperParameter = meanHyperParameters.ToString();
}
186
/// <summary>
/// Translates the covariance-function sub-tree (root / child 0 / child 0 / child 1) into a MATLAB
/// expression and a bracketed hyper-parameter vector.
/// </summary>
/// <param name="tree">Tree containing the covariance function.</param>
/// <param name="dimension">Number of input variables; sizes the initial all-false mask.</param>
/// <param name="expression">Receives the MATLAB expression of the covariance function.</param>
/// <param name="hyperParameter">Receives the hyper-parameters enclosed in "[" and "]".</param>
private void GetCovFunction(ISymbolicExpressionTree tree, int dimension, out string expression, out string hyperParameter) {
  var covExpression = new StringBuilder();
  var covHyperParameters = new StringBuilder("[");
  var initialMask = new bool[dimension];

  var covNode = tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(1);
  GetCovFunction(covNode, covExpression, covHyperParameters, initialMask);
  covHyperParameters.Append("]");

  expression = covExpression.ToString();
  hyperParameter = covHyperParameters.ToString();
}
202
203
/// <summary>
/// Translates the likelihood sub-tree (root / child 0 / child 0 / child 2) into a MATLAB
/// expression and a bracketed hyper-parameter vector.
/// </summary>
/// <param name="tree">Tree containing the likelihood function.</param>
/// <param name="expression">Receives the MATLAB expression of the likelihood function.</param>
/// <param name="hyperParameter">Receives the hyper-parameters enclosed in "[" and "]".</param>
private void GetLikelihoodFunction(ISymbolicExpressionTree tree, out string expression, out string hyperParameter) {
  var likExpression = new StringBuilder();
  var likHyperParameters = new StringBuilder("[");

  var likNode = tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(2);
  GetLikelihoodFunction(likNode, likExpression, likHyperParameters);
  likHyperParameters.Append("]");

  expression = likExpression.ToString();
  hyperParameter = likHyperParameters.ToString();
}
214
/// <summary>
/// Appends the MATLAB representation of a likelihood node; only Gaussian likelihood is supported.
/// </summary>
/// <param name="node">Likelihood node to translate.</param>
/// <param name="expressionBuilder">Receives the MATLAB function name literal.</param>
/// <param name="hyperParameterBuilder">Receives the node's sigma.</param>
/// <exception cref="ArgumentException">The node's symbol is not a known likelihood function.</exception>
private void GetLikelihoodFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionBuilder, StringBuilder hyperParameterBuilder) {
  if (!(node.Symbol is LikGauss)) {
    throw new ArgumentException("unknown likelihood function " + node.Symbol);
  }

  // Direct cast: a node of the wrong type fails with a telling InvalidCastException instead of a NullReferenceException.
  var likNode = (LikGaussTreeNode)node;
  expressionBuilder.Append("'likGauss'");
  // The text is parsed by MATLAB, so the number must use '.' as decimal separator in every culture.
  hyperParameterBuilder.AppendFormat(System.Globalization.CultureInfo.InvariantCulture, "{0}", likNode.Sigma);
}
224
/// <summary>
/// Builds the MATLAB <c>struct(...)</c> expression that bundles the hyper-parameters of the
/// mean, covariance and likelihood functions under the fields 'mean', 'cov' and 'lik'.
/// </summary>
/// <param name="meanHyp">MATLAB expression for the mean hyper-parameters.</param>
/// <param name="covHyp">MATLAB expression for the covariance hyper-parameters.</param>
/// <param name="likHyp">MATLAB expression for the likelihood hyper-parameters.</param>
/// <returns>The struct expression, without a trailing semicolon.</returns>
private string GetHyperParameterString(string meanHyp, string covHyp, string likHyp) {
  return "struct('mean', " + meanHyp +
         ", 'cov', " + covHyp +
         ", 'lik', " + likHyp +
         ")";
}
232
233
/// <summary>
/// Recursively appends the MATLAB representation of a covariance-function node: the function
/// expression goes to <paramref name="expressionStringBuilder"/>, the matching hyper-parameters
/// go to <paramref name="hyperParameterStringBuilder"/>.
/// </summary>
/// <param name="node">Covariance node to translate.</param>
/// <param name="expressionStringBuilder">Receives the (nested) cell-array expression.</param>
/// <param name="hyperParameterStringBuilder">Receives the hyper-parameter values.</param>
/// <param name="mask">Mask handed to ToVectorString for per-dimension parameters; replaced by the node's own mask below a mask node.</param>
/// <exception cref="ArgumentException">The node's symbol is not a known covariance function.</exception>
private void GetCovFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionStringBuilder, StringBuilder hyperParameterStringBuilder, bool[] mask) {
  // The generated text is parsed by MATLAB, so numbers must use '.' as decimal separator in every culture.
  var invariant = System.Globalization.CultureInfo.InvariantCulture;

  // Tree nodes are cast directly so that a node of the wrong type fails with an
  // InvalidCastException instead of a NullReferenceException on first use.
  if (node.Symbol is CovConst) {
    var constNode = (CovConstTreeNode)node;
    expressionStringBuilder.Append("{'covConst'}");
    hyperParameterStringBuilder.AppendFormat(invariant, "{0}; ", constNode.Sigma);
  } else if (node.Symbol is CovLin) {
    // No hyper-parameters are emitted for covLIN.
    expressionStringBuilder.Append("{'covLIN'}");
  } else if (node.Symbol is CovLinArd) {
    var covNode = (CovLinArdTreeNode)node;
    expressionStringBuilder.Append("{'covLINard'}");
    hyperParameterStringBuilder.Append(ToVectorString(covNode.Lambda, mask));
  } else if (node.Symbol is CovSeArd) {
    var covNode = (CovSeArdTreeNode)node;
    expressionStringBuilder.Append("{'covSEard'}");
    hyperParameterStringBuilder.Append("[").Append(ToVectorString(covNode.Lambda, mask));
    hyperParameterStringBuilder.AppendFormat(invariant, "; {0} ]", covNode.Sigma);
  } else if (node.Symbol is CovSeIso) {
    var covNode = (CovSeIsoTreeNode)node;
    expressionStringBuilder.Append("{'covSEiso'}");
    hyperParameterStringBuilder.AppendFormat(invariant, "[{0}", covNode.L);
    hyperParameterStringBuilder.AppendFormat(invariant, "; {0}]", covNode.Sigma);
  } else if (node.Symbol is CovSum) {
    AppendCompositeCovFunction("covSum", node, expressionStringBuilder, hyperParameterStringBuilder, mask);
  } else if (node.Symbol is CovProd) {
    AppendCompositeCovFunction("covProd", node, expressionStringBuilder, hyperParameterStringBuilder, mask);
  } else if (node.Symbol is CovScale) {
    var covNode = (CovScaleTreeNode)node;
    expressionStringBuilder.Append("{'covScale', ");
    hyperParameterStringBuilder.AppendFormat(invariant, "[{0}; ", covNode.Alpha);
    GetCovFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
    hyperParameterStringBuilder.Append("]");
    expressionStringBuilder.Append("}");
  } else if (node.Symbol is CovMask) {
    var covNode = (CovMaskTreeNode)node;
    // when nothing is masked then we can just return the child
    if (!covNode.Mask.Any(t => t == false)) {
      GetCovFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
    } else {
      expressionStringBuilder.Append("{'covMask', {");
      hyperParameterStringBuilder.Append("[");
      expressionStringBuilder.Append(ToVectorString(covNode.Mask)).Append(", ");
      int startIndex = expressionStringBuilder.Length;
      GetCovFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, covNode.Mask);
      // Every branch emits its expression wrapped in "{...}"; strip the child's outermost
      // pair of braces because the child sits directly inside the covMask cell array.
      expressionStringBuilder.Remove(startIndex, 1);
      expressionStringBuilder.Remove(expressionStringBuilder.Length - 1, 1);
      hyperParameterStringBuilder.Append("]");
      expressionStringBuilder.Append("}}");
    }
  } else {
    throw new ArgumentException("unknown symbol " + node.Symbol);
  }
}

// Emits {'<name>', {child1, child2, ...}} and the children's hyper-parameters as "[h1; h2; ...]" (used for covSum and covProd).
private void AppendCompositeCovFunction(string matlabFunctionName, ISymbolicExpressionTreeNode node, StringBuilder expressionStringBuilder, StringBuilder hyperParameterStringBuilder, bool[] mask) {
  expressionStringBuilder.Append("{'").Append(matlabFunctionName).Append("', {");
  hyperParameterStringBuilder.Append("[");
  GetCovFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
  foreach (var subTree in node.Subtrees.Skip(1)) {
    expressionStringBuilder.Append(", ");
    hyperParameterStringBuilder.Append("; ");
    GetCovFunction(subTree, expressionStringBuilder, hyperParameterStringBuilder, mask);
  }
  hyperParameterStringBuilder.Append("]");
  expressionStringBuilder.Append("}}");
}
304
305
/// <summary>
/// Recursively appends the MATLAB representation of a mean-function node: the function
/// expression goes to <paramref name="expressionStringBuilder"/>, the matching hyper-parameters
/// go to <paramref name="hyperParameterStringBuilder"/>.
/// </summary>
/// <param name="node">Mean-function node to translate.</param>
/// <param name="expressionStringBuilder">Receives the (nested) cell-array expression.</param>
/// <param name="hyperParameterStringBuilder">Receives the hyper-parameter values.</param>
/// <param name="mask">Mask handed to ToVectorString for per-dimension parameters; replaced by the node's own mask below a mask node.</param>
/// <exception cref="ArgumentException">The node's symbol is not a known mean function.</exception>
private void GetMeanFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionStringBuilder, StringBuilder hyperParameterStringBuilder, bool[] mask) {
  // The generated text is parsed by MATLAB, so numbers must use '.' as decimal separator in every culture.
  var invariant = System.Globalization.CultureInfo.InvariantCulture;

  // Tree nodes are cast directly so that a node of the wrong type fails with an
  // InvalidCastException instead of a NullReferenceException on first use.
  if (node.Symbol is MeanConst) {
    var constNode = (MeanConstTreeNode)node;
    expressionStringBuilder.Append("{'meanConst'}");
    hyperParameterStringBuilder.AppendFormat(invariant, "{0}; ", constNode.Value);
  } else if (node.Symbol is MeanLinear) {
    var meanLinNode = (MeanLinearTreeNode)node;
    expressionStringBuilder.Append("{'meanLinear'}");
    hyperParameterStringBuilder.Append(ToVectorString(meanLinNode.Alpha, mask));
  } else if (node.Symbol is MeanMask) {
    var meanMaskNode = (MeanMaskTreeNode)node;
    // when nothing is masked then we can just return the child
    if (!meanMaskNode.Mask.Any(t => t == false)) {
      GetMeanFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
    } else {
      expressionStringBuilder.Append("{'meanMask', {");
      hyperParameterStringBuilder.Append("[");
      expressionStringBuilder.Append(ToVectorString(meanMaskNode.Mask)).Append(", ");
      int startIndex = expressionStringBuilder.Length;
      GetMeanFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, meanMaskNode.Mask);
      // Every branch emits its expression wrapped in "{...}"; strip the child's outermost
      // pair of braces because the child sits directly inside the meanMask cell array.
      expressionStringBuilder.Remove(startIndex, 1);
      expressionStringBuilder.Remove(expressionStringBuilder.Length - 1, 1);
      hyperParameterStringBuilder.Append("]");
      expressionStringBuilder.Append("}}");
    }
  } else if (node.Symbol is MeanOne) {
    // No hyper-parameters are emitted for meanOne.
    expressionStringBuilder.Append("{'meanOne'}");
  } else if (node.Symbol is MeanPow) {
    var meanPowSymbol = (MeanPow)node.Symbol;
    expressionStringBuilder.Append("{'meanPow', {");
    hyperParameterStringBuilder.Append("[");
    // The exponent is part of the expression, not of the hyper-parameter vector.
    expressionStringBuilder.AppendFormat(invariant, "{0}, ", meanPowSymbol.Exponent);
    GetMeanFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
    hyperParameterStringBuilder.Append("]");
    expressionStringBuilder.Append("}}");
  } else if (node.Symbol is MeanProd) {
    AppendCompositeMeanFunction("meanProd", node, expressionStringBuilder, hyperParameterStringBuilder, mask);
  } else if (node.Symbol is MeanScale) {
    var meanScaleNode = (MeanScaleTreeNode)node;
    expressionStringBuilder.Append("{'meanScale', ");
    hyperParameterStringBuilder.AppendFormat(invariant, "[{0}; ", meanScaleNode.Alpha);
    GetMeanFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
    hyperParameterStringBuilder.Append("]");
    expressionStringBuilder.Append("}");
  } else if (node.Symbol is MeanSum) {
    AppendCompositeMeanFunction("meanSum", node, expressionStringBuilder, hyperParameterStringBuilder, mask);
  } else if (node.Symbol is MeanZero) {
    // No hyper-parameters are emitted for meanZero.
    expressionStringBuilder.Append("{'meanZero'}");
  } else {
    throw new ArgumentException("Unknown mean function", "node");
  }
}

// Emits {'<name>', {child1, child2, ...}} and the children's hyper-parameters as "[h1; h2; ...]" (used for meanSum and meanProd).
private void AppendCompositeMeanFunction(string matlabFunctionName, ISymbolicExpressionTreeNode node, StringBuilder expressionStringBuilder, StringBuilder hyperParameterStringBuilder, bool[] mask) {
  expressionStringBuilder.Append("{'").Append(matlabFunctionName).Append("', {");
  hyperParameterStringBuilder.Append("[");
  GetMeanFunction(node.GetSubtree(0), expressionStringBuilder, hyperParameterStringBuilder, mask);
  foreach (var subTree in node.Subtrees.Skip(1)) {
    expressionStringBuilder.Append(", ");
    hyperParameterStringBuilder.Append("; ");
    GetMeanFunction(subTree, expressionStringBuilder, hyperParameterStringBuilder, mask);
  }
  hyperParameterStringBuilder.Append("]");
  expressionStringBuilder.Append("}}");
}
376
377    //private bool[] CalculateMask(ISymbolicExpressionTreeNode node) {
378    //  var maskNode = node as MeanMaskTreeNode;
379    //  if (maskNode != null) {
380    //    bool[] newMask = CombineMasksProd(maskNode.Mask, CalculateMask(node.GetSubtree(0)));
381    //    return newMask;
382    //  } else if (node.Symbol is MeanProd) {
383    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
384    //    foreach (var subTree in node.Subtrees.Skip(1)) {
385    //      newMask = CombineMasksProd(newMask, CalculateMask(subTree));
386    //    }
387    //    return newMask;
388    //  } else if (node.Symbol is MeanSum) {
389    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
390    //    foreach (var subTree in node.Subtrees.Skip(1)) {
391    //      newMask = CombineMasksSum(newMask, CalculateMask(subTree));
392    //    }
393    //    return newMask;
394    //  } else if (node.SubtreeCount == 1) {
395    //    return CalculateMask(node.GetSubtree(0));
396    //  } else if (node is SymbolicExpressionTreeTerminalNode) {
397    //    return null;
398    //  } else {
399    //    throw new NotImplementedException();
400    //  }
401    //}
402
403    //private bool[] CombineMasksProd(bool[] m, bool[] n) {
404    //  if (m == null) return n;
405    //  if (n == null) return m;
406    //  if (m.Length != n.Length) throw new ArgumentException();
407    //  bool[] res = new bool[m.Length];
408    //  for (int i = 0; i < res.Length; i++)
409    //    res[i] = m[i] | n[i];
410    //  return res;
411    //}
412
413
414    //private bool[] CombineMasksSum(bool[] m, bool[] n) {
415    //  if (m == null) return n;
416    //  if (n == null) return m;
417    //  if (m.Length != n.Length) throw new ArgumentException();
418    //  bool[] res = new bool[m.Length];
419    //  for (int i = 0; i < res.Length; i++)
420    //    res[i] = m[i] & n[i];
421    //  return res;
422    //}
423
/// <summary>
/// Formats a boolean array as a MATLAB row vector in which every <c>true</c> becomes "0" and
/// every <c>false</c> becomes "1".
/// </summary>
/// <param name="b">Flags to convert.</param>
/// <returns>
/// "[e1, e2, ...]"; for a single-element array the result is "[1]" when the flag is false and
/// "[]" when it is true.
/// </returns>
private string ToVectorString(bool[] b) {
  // workaround for bug in GPML: a one-element array never produces "[0]"
  if (b.Length == 1) {
    return b[0] ? "[]" : "[1]";
  }

  var entries = b.Select(flag => flag ? "0" : "1");
  return "[" + string.Join(", ", entries) + "]";
}
/// <summary>
/// Formats the entries of <paramref name="xs"/> whose mask flag is <c>false</c> as a MATLAB
/// column vector "[v1; v2; ...]".
/// </summary>
/// <param name="xs">Values to format.</param>
/// <param name="mask">One flag per value; entries whose flag is <c>true</c> are left out.</param>
/// <returns>The bracketed vector; "[]" when every entry is left out.</returns>
/// <exception cref="ArgumentException"><paramref name="xs"/> and <paramref name="mask"/> differ in length.</exception>
private string ToVectorString(double[] xs, bool[] mask) {
  if (xs.Length != mask.Length) throw new ArgumentException("xs and mask must have the same length.");

  // The text is parsed by MATLAB, so numbers must use '.' as decimal separator in every culture.
  var invariant = System.Globalization.CultureInfo.InvariantCulture;
  var kept = new List<string>(xs.Length);
  for (int i = 0; i < xs.Length; i++) {
    if (!mask[i]) kept.Add(xs[i].ToString(invariant));
  }

  // Joining only the kept entries avoids a stray leading "; " when the first entry is left out.
  return "[" + string.Join("; ", kept) + "]";
}
451
452
/// <summary>
/// Runs a command in MATLAB and throws when the returned output contains the marker "???".
/// </summary>
/// <param name="command">MATLAB command text to execute.</param>
/// <exception cref="ArgumentException">The output contains "???"; the message holds the command followed by the output.</exception>
private void ExecuteMatlab(string command) {
  var output = MLApp.Execute(command);
  bool reportsError = output.Contains("???");
  if (!reportsError) {
    return;
  }
  throw new ArgumentException(command + " " + output, "command");
}
457  }
458}
Note: See TracBrowser for help on using the repository browser.