Free cookie consent management tool by TermsFeed Policy Generator

source: branches/GaussianProcessEvolution/HeuristicLab.Problems.GaussianProcessTuning/Interpreter.cs @ 8753

Last change on this file since 8753 was 8753, checked in by gkronber, 12 years ago

#1967 initial import of Gaussian process evolution plugin

File size: 13.6 KB
Line 
using System;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Problems.GaussianProcessTuning {
  /// <summary>
  /// Interprets a symbolic expression tree that encodes a Gaussian process configuration.
  /// The mean function and covariance function subtrees are translated into HeuristicLab GP
  /// components and evaluated by running <see cref="GaussianProcessRegression"/>.
  /// </summary>
  [StorableClass]
  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
  public class Interpreter : Item {
    // Child indices of the configuration parts below tree.Root.GetSubtree(0).GetSubtree(0).
    private const int MeanFunctionIndex = 0;
    private const int CovarianceFunctionIndex = 1;
    private const int LikelihoodFunctionIndex = 2;

    [StorableConstructor]
    protected Interpreter(bool deserializing) : base(deserializing) { }
    protected Interpreter(Interpreter original, Cloner cloner)
      : base(original, cloner) {
    }
    public Interpreter()
      : base() { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Interpreter(this, cloner);
    }

    /// <summary>
    /// Builds the mean and covariance functions encoded in <paramref name="tree"/>, runs a
    /// <see cref="GaussianProcessRegression"/> on <paramref name="problemData"/> and blocks
    /// until the run has either stopped or reported an exception.
    /// </summary>
    /// <param name="tree">Tree encoding the Gaussian process configuration.</param>
    /// <param name="problemData">Regression problem data the Gaussian process is trained on.</param>
    /// <param name="negLogLikelihood">
    /// The "NegativeLogLikelihood" result of the run, or <see cref="double.MaxValue"/> if the run
    /// did not produce that result.
    /// </param>
    /// <param name="solution">The "Solution" result of the run, or null if none was produced.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> or <paramref name="problemData"/> is null.</exception>
    /// <remarks>
    /// An exception reported through the algorithm's ExceptionOccurred event is rethrown on the
    /// calling thread. So is an exception raised while reading the results.
    /// </remarks>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, out double negLogLikelihood, out IGaussianProcessSolution solution) {
      if (tree == null) throw new ArgumentNullException("tree");
      if (problemData == null) throw new ArgumentNullException("problemData");

      var meanFunction = GetMeanFunction(tree);
      var covFunction = GetCovFunction(tree);

      var gprAlg = new GaussianProcessRegression();
      gprAlg.Problem.ProblemDataParameter.Value = problemData;
      gprAlg.CovarianceFunction = covFunction;
      gprAlg.MeanFunction = meanFunction;

      // The outcome is published under a lock together with a 'finished' flag. This avoids a
      // wait handle, which would need disposing while an event raised later might still touch it.
      // The first event to arrive decides the outcome; any later event is ignored.
      var sync = new object();
      bool finished = false;
      double result = double.MaxValue;
      IGaussianProcessSolution regSolution = null;
      Exception failure = null;

      gprAlg.Stopped += (sender, args) => {
        lock (sync) {
          if (finished) return;
          try {
            // A missing likelihood result leaves the "worst" value double.MaxValue in place
            // instead of throwing inside the handler.
            if (gprAlg.Results.ContainsKey("NegativeLogLikelihood"))
              result = ((DoubleValue)gprAlg.Results["NegativeLogLikelihood"].Value).Value;
            if (gprAlg.Results.ContainsKey("Solution"))
              regSolution = (IGaussianProcessSolution)gprAlg.Results["Solution"].Value;
          } catch (Exception e) {
            // Forward the exception to the waiting caller. Previously such an exception escaped
            // on the event thread and the caller was never released.
            result = double.MaxValue;
            regSolution = null;
            failure = e;
          } finally {
            finished = true;
            Monitor.PulseAll(sync);
          }
        }
      };
      gprAlg.ExceptionOccurred += (sender, args) => {
        lock (sync) {
          if (finished) return;
          result = double.MaxValue;
          regSolution = null;
          failure = args.Value;
          finished = true;
          Monitor.PulseAll(sync);
        }
      };

      gprAlg.Prepare();
      gprAlg.Start();

      lock (sync) {
        while (!finished) Monitor.Wait(sync);
      }
      // Rethrowing here keeps the exception type callers see. Note that 'throw failure;'
      // replaces the stack trace that was recorded on the other thread.
      if (failure != null) throw failure;

      // Reset the algorithm and detach the problem before returning. The intent is presumably to
      // drop references held by the finished run (TODO confirm).
      gprAlg.Prepare();
      gprAlg.Problem = null;
      solution = regSolution;
      negLogLikelihood = result;
    }

    // Legacy MATLAB-based evaluation, kept for reference only. It relies on an 'MLApp' object and
    // an 'ExecuteMatlab' helper that are not defined in this file. It also calls string-producing
    // GetMeanFunction/GetCovFunction overloads that no longer exist. The MATLAB helper methods
    // below (GetInferenceMethodString, GetLikelihoodFunction, GetHyperParameterString,
    // ToVectorString) are only referenced from here.
    /*
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData trainingData, Dataset testData, IEnumerable<int> testRows,
      out double[] means, out double[] variances) {
      string meanExpression, meanHyperParameter;
      string covExpression, covHyperParameter;
      string likFunction, likHyperParameter;
      GetMeanFunction(tree, trainingData.AllowedInputVariables.Count(), out meanExpression, out meanHyperParameter);
      GetCovFunction(tree, trainingData.AllowedInputVariables.Count(), out covExpression, out covHyperParameter);
      GetLikelihoodFunction(tree, out likFunction, out likHyperParameter);

      double[,] y = new double[trainingData.TrainingIndices.Count(), 1];
      double[,] yImg = new double[trainingData.TrainingIndices.Count(), 1];

      int r, c;
      r = 0;
      foreach (var e in trainingData.Dataset.GetDoubleValues(trainingData.TargetVariable, trainingData.TrainingIndices)) {
        y[r++, 0] = e;
      }
      double[,] x = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      double[,] xImg = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in trainingData.Dataset.GetDoubleValues(allowedInput, trainingData.TrainingIndices)) {
          x[r++, c] = e;
        }
        c++;
      }

      double[,] xTest = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      double[,] xTestImg = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in testData.GetDoubleValues(allowedInput, testRows)) {
          xTest[r++, c] = e;
        }
        c++;
      }

      object oldX = null;
      try {
        oldX = MLApp.GetVariable("x", "base");
      }
      catch {
      }
      if (oldX == null || oldX is Missing || ((double[,])oldX).Length != x.Length) {
        MLApp.PutFullMatrix("y", "base", y, yImg);
        MLApp.PutFullMatrix("x", "base", x, xImg);
      }
      MLApp.PutFullMatrix("xTest", "base", xTest, xTestImg);
      ExecuteMatlab("hyp0 = " + GetHyperParameterString(meanHyperParameter, covHyperParameter, likHyperParameter) + ";");
      ExecuteMatlab("infFun =  " + GetInferenceMethodString(tree) + ";");
      ExecuteMatlab("meanExpr = " + meanExpression + ";");
      ExecuteMatlab("covExpr = " + covExpression + ";");
      ExecuteMatlab("likExp = " + likFunction + ";");
      ExecuteMatlab(
        "try " +
        "  hyp = minimize(hyp0,'gp', -50, infFun, meanExpr, covExpr, likExp, x, y);" +
        "  [ymu, ys2, fmu, fs2] = gp(hyp, infFun, meanExpr, covExpr, likExp, x, y, xTest); " +
        "catch " +
        "  ymu = zeros(size(xTest, 1), 1); " +
        "  ys2 = ones(size(xTest, 1), 1); " +
        "end");
      var meansMat = (double[,])MLApp.GetVariable("ymu", "base");
      var variancesMat = (double[,])MLApp.GetVariable("ys2", "base");
      means = new double[meansMat.GetLength(0)];
      for (int i = 0; i < means.Length; i++)
        means[i] = meansMat[i, 0];
      variances = new double[variancesMat.GetLength(0)];
      for (int i = 0; i < variances.Length; i++)
        variances[i] = variancesMat[i, 0];
    }
    */

    /// <summary>
    /// Returns the MATLAB expression naming the inference method. This is always the literal
    /// 'infEP'; <paramref name="tree"/> is not inspected.
    /// </summary>
    private string GetInferenceMethodString(ISymbolicExpressionTree tree) {
      return "'infEP'";
    }

    /// <summary>
    /// Returns the child at <paramref name="index"/> of the node two levels below the root. That
    /// node is presumably the configuration node holding the mean, covariance and likelihood
    /// subtrees; confirm against the grammar.
    /// </summary>
    private static ISymbolicExpressionTreeNode GetConfigurationPart(ISymbolicExpressionTree tree, int index) {
      return tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(index);
    }

    /// <summary>Builds the mean function encoded in the mean-function subtree of <paramref name="tree"/>.</summary>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTree tree) {
      return GetMeanFunction(GetConfigurationPart(tree, MeanFunctionIndex));
    }

    /// <summary>Builds the covariance function encoded in the covariance subtree of <paramref name="tree"/>.</summary>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTree tree) {
      return GetCovFunction(GetConfigurationPart(tree, CovarianceFunctionIndex));
    }

    /// <summary>
    /// Produces the MATLAB likelihood expression and its hyper-parameter vector (wrapped in
    /// "[" ... "]") from the likelihood subtree of <paramref name="tree"/>.
    /// </summary>
    private void GetLikelihoodFunction(ISymbolicExpressionTree tree, out string expression, out string hyperParameter) {
      var expressionBuilder = new StringBuilder();
      var hyperParameterBuilder = new StringBuilder();
      hyperParameterBuilder.Append("[");
      GetLikelihoodFunction(GetConfigurationPart(tree, LikelihoodFunctionIndex), expressionBuilder, hyperParameterBuilder);
      hyperParameterBuilder.Append("]");
      expression = expressionBuilder.ToString();
      hyperParameter = hyperParameterBuilder.ToString();
    }

    /// <summary>
    /// Appends the MATLAB likelihood name and hyper-parameters for <paramref name="node"/>.
    /// Only LikGauss is supported.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a supported likelihood.</exception>
    private void GetLikelihoodFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionBuilder, StringBuilder hyperParameterBuilder) {
      if (node.Symbol is LikGauss) {
        var likNode = (LikGaussTreeNode)node;
        expressionBuilder.Append("'likGauss'");
        // Invariant culture: MATLAB requires '.' as the decimal separator regardless of locale.
        hyperParameterBuilder.Append(Convert.ToString(likNode.Sigma, CultureInfo.InvariantCulture));
      } else {
        throw new ArgumentException("unknown likelihood function " + node.Symbol);
      }
    }

    /// <summary>Builds a MATLAB struct expression holding the mean, cov and lik hyper-parameters.</summary>
    private string GetHyperParameterString(string meanHyp, string covHyp, string likHyp) {
      return "struct('mean', " + meanHyp +
             ", 'cov', " + covHyp +
             ", 'lik', " + likHyp +
             ")";
    }

    /// <summary>
    /// Recursively translates a covariance-function subtree into the corresponding HeuristicLab
    /// covariance function. CovSum and CovProd nodes combine the functions of all their children.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is unknown, or a Matern d value is unsupported.</exception>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is CovConst) {
        return new CovarianceConst();
      } else if (node.Symbol is CovScale) {
        return new CovarianceScale();
      } else if (node.Symbol is CovMask) {
        var maskNode = (CovMaskTreeNode)node;
        var covSymbol = (CovMask)node.Symbol;
        var cov = new CovarianceMask();
        // NOTE(review): here a true Mask entry SELECTS dimension i. The legacy MATLAB helper
        // ToVectorString(bool[]) writes true entries as 0, which is the opposite convention.
        // Confirm which one CovMaskTreeNode.Mask follows.
        cov.SelectedDimensionsParameter.Value = new IntArray(Enumerable.Range(0, covSymbol.Dimension)
                                                                       .Where(i => maskNode.Mask[i])
                                                                       .ToArray());
        return cov;
      } else if (node.Symbol is CovLin) {
        return new CovarianceLinear();
      } else if (node.Symbol is CovLinArd) {
        return new CovarianceLinearArd();
      } else if (node.Symbol is CovMatern) {
        var covSymbol = (CovMatern)node.Symbol;
        var cov = new CovarianceMaternIso();
        var d = cov.DParameter.ValidValues.SingleOrDefault(x => x.Value == covSymbol.D);
        if (d == null) throw new ArgumentException("unsupported value d = " + covSymbol.D + " for the Matern covariance function");
        cov.DParameter.Value = d;
        return cov;
      } else if (node.Symbol is CovSeArd) {
        return new CovarianceSquaredExponentialArd();
      } else if (node.Symbol is CovSeIso) {
        return new CovarianceSquaredExponentialIso();
      } else if (node.Symbol is CovRQIso) {
        return new CovarianceRationalQuadraticIso();
      } else if (node.Symbol is CovRQArd) {
        return new CovarianceRationalQuadraticArd();
      } else if (node.Symbol is CovPeriodic) {
        return new CovariancePeriodic();
      } else if (node.Symbol is CovNoise) {
        return new CovarianceNoise();
      } else if (node.Symbol is CovSum) {
        var covSum = new Algorithms.DataAnalysis.CovarianceSum();
        foreach (var subTree in node.Subtrees)
          covSum.Terms.Add(GetCovFunction(subTree));
        return covSum;
      } else if (node.Symbol is CovProd) {
        var covProd = new Algorithms.DataAnalysis.CovarianceProduct();
        foreach (var subTree in node.Subtrees)
          covProd.Factors.Add(GetCovFunction(subTree));
        return covProd;
      } else {
        throw new ArgumentException("unknown symbol " + node.Symbol);
      }
    }

    /// <summary>
    /// Recursively translates a mean-function subtree into the corresponding HeuristicLab mean
    /// function. MeanProd and MeanSum nodes combine the functions of all their children.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known mean function.</exception>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTreeNode node) {
      // The unqualified names below are the tree symbols of this namespace; the fully qualified
      // Algorithms.DataAnalysis types are the mean-function implementations.
      if (node.Symbol is MeanConst) {
        return new Algorithms.DataAnalysis.MeanConst();
      } else if (node.Symbol is MeanLinear) {
        return new Algorithms.DataAnalysis.MeanLinear();
      } else if (node.Symbol is MeanProd) {
        var meanProd = new Algorithms.DataAnalysis.MeanProduct();
        foreach (var subTree in node.Subtrees)
          meanProd.Factors.Add(GetMeanFunction(subTree));
        return meanProd;
      } else if (node.Symbol is MeanSum) {
        var meanSum = new Algorithms.DataAnalysis.MeanSum();
        foreach (var subTree in node.Subtrees)
          meanSum.Terms.Add(GetMeanFunction(subTree));
        return meanSum;
      } else if (node.Symbol is MeanZero) {
        return new Algorithms.DataAnalysis.MeanZero();
      } else {
        throw new ArgumentException("Unknown mean function " + node.Symbol);
      }
    }

    //private bool[] CalculateMask(ISymbolicExpressionTreeNode node) {
    //  var maskNode = node as MeanMaskTreeNode;
    //  if (maskNode != null) {
    //    bool[] newMask = CombineMasksProd(maskNode.Mask, CalculateMask(node.GetSubtree(0)));
    //    return newMask;
    //  } else if (node.Symbol is MeanProd) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksProd(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.Symbol is MeanSum) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksSum(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.SubtreeCount == 1) {
    //    return CalculateMask(node.GetSubtree(0));
    //  } else if (node is SymbolicExpressionTreeTerminalNode) {
    //    return null;
    //  } else {
    //    throw new NotImplementedException();
    //  }
    //}

    //private bool[] CombineMasksProd(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] | n[i];
    //  return res;
    //}


    //private bool[] CombineMasksSum(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] & n[i];
    //  return res;
    //}

    /// <summary>
    /// Formats a mask as a MATLAB row vector, writing true entries as 0 and false entries as 1.
    /// A single-element mask becomes "[1]" when false and "[]" when true. The original author
    /// notes this as a workaround for a GPML bug.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="b"/> is null.</exception>
    private string ToVectorString(bool[] b) {
      if (b == null) throw new ArgumentNullException("b");
      var strBuilder = new StringBuilder();
      strBuilder.Append("[");
      if (b.Length == 1) {
        // workaround for bug in GPML
        if (!b[0]) strBuilder.Append("1");
      } else {
        strBuilder.Append(string.Join(", ", b.Select(flag => flag ? "0" : "1")));
      }
      strBuilder.Append("]");
      return strBuilder.ToString();
    }

    /// <summary>
    /// Formats the entries of <paramref name="xs"/> whose mask flag is false as a MATLAB column
    /// vector ("[a; b; ...]"), using invariant-culture number formatting.
    /// </summary>
    /// <exception cref="ArgumentNullException">An argument is null.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    private string ToVectorString(double[] xs, bool[] mask) {
      if (xs == null) throw new ArgumentNullException("xs");
      if (mask == null) throw new ArgumentNullException("mask");
      if (xs.Length != mask.Length) throw new ArgumentException("xs and mask must have the same length.");
      // Separators go only between emitted entries, so a masked-out first entry does not yield
      // "[; x]". Invariant culture keeps '.' as the decimal separator MATLAB expects.
      var entries = xs.Where((x, i) => !mask[i])
                      .Select(x => x.ToString(CultureInfo.InvariantCulture));
      return "[" + string.Join("; ", entries) + "]";
    }
  }
}
Note: See TracBrowser for help on using the repository browser.