using System;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Threading;
using HeuristicLab.Algorithms.DataAnalysis;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using HeuristicLab.Problems.DataAnalysis;

namespace HeuristicLab.Problems.GaussianProcessTuning {
  /// <summary>
  /// Interprets a symbolic expression tree that encodes a Gaussian process configuration
  /// and evaluates it with HeuristicLab's <see cref="GaussianProcessRegression"/> algorithm.
  /// </summary>
  /// <remarks>
  /// The configuration is read from the node <c>tree.Root.GetSubtree(0).GetSubtree(0)</c>:
  /// its subtree 0 encodes the mean function, subtree 1 the covariance function and
  /// subtree 2 the likelihood function.
  /// </remarks>
  [StorableClass]
  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
  public class Interpreter : Item {
    /// <summary>
    /// Number of hyperparameter-minimization iterations used by the overload of
    /// <see cref="EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree, IRegressionProblemData, out double, out IGaussianProcessSolution)"/>
    /// that does not take an explicit iteration count.
    /// </summary>
    private const int DefaultMinimizationIterations = 50;

    [StorableConstructor]
    protected Interpreter(bool deserializing) : base(deserializing) { }
    protected Interpreter(Interpreter original, Cloner cloner)
      : base(original, cloner) {
    }
    public Interpreter()
      : base() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new Interpreter(this, cloner);
    }

    /// <summary>
    /// Fits a Gaussian process whose mean and covariance functions are encoded in <paramref name="tree"/>,
    /// using <see cref="DefaultMinimizationIterations"/> (50) minimization iterations.
    /// </summary>
    /// <param name="tree">Tree encoding the GP configuration (see class remarks).</param>
    /// <param name="problemData">Regression data the GP is fitted to.</param>
    /// <param name="negLogLikelihood">The "NegativeLogLikelihood" result of the run, or <see cref="double.MaxValue"/> if the run produced none.</param>
    /// <param name="solution">The "Solution" result of the run, or <c>null</c> if the run produced none.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> or <paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentException">The tree contains an unknown mean or covariance symbol.</exception>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, out double negLogLikelihood, out IGaussianProcessSolution solution) {
      EvaluateGaussianProcessConfiguration(tree, problemData, DefaultMinimizationIterations, out negLogLikelihood, out solution);
    }

    /// <summary>
    /// Fits a Gaussian process whose mean and covariance functions are encoded in <paramref name="tree"/>
    /// by running a <see cref="GaussianProcessRegression"/> synchronously (blocks until the run stops or fails).
    /// </summary>
    /// <param name="tree">Tree encoding the GP configuration (see class remarks).</param>
    /// <param name="problemData">Regression data the GP is fitted to.</param>
    /// <param name="minimizationIterations">Value assigned to <c>GaussianProcessRegression.MinimizationIterations</c>; must not be negative.</param>
    /// <param name="negLogLikelihood">The "NegativeLogLikelihood" result of the run, or <see cref="double.MaxValue"/> if the run produced none.</param>
    /// <param name="solution">The "Solution" result of the run, or <c>null</c> if the run produced none.</param>
    /// <exception cref="ArgumentNullException"><paramref name="tree"/> or <paramref name="problemData"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="minimizationIterations"/> is negative.</exception>
    /// <exception cref="ArgumentException">The tree contains an unknown mean or covariance symbol.</exception>
    /// <remarks>Any exception reported by the algorithm run (or raised while reading its results) is rethrown to the caller.</remarks>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, int minimizationIterations,
      out double negLogLikelihood, out IGaussianProcessSolution solution) {
      if (tree == null) throw new ArgumentNullException("tree");
      if (problemData == null) throw new ArgumentNullException("problemData");
      if (minimizationIterations < 0)
        throw new ArgumentOutOfRangeException("minimizationIterations", "The number of minimization iterations must not be negative.");

      var meanFunction = GetMeanFunction(tree);
      var covFunction = GetCovFunction(tree);

      var gprAlg = new GaussianProcessRegression();
      gprAlg.Problem.ProblemDataParameter.Value = problemData;
      gprAlg.CovarianceFunction = covFunction;
      gprAlg.MeanFunction = meanFunction;
      // Select the TunedGaussianProcessRegressionModelCreator from the algorithm's valid model creators.
      gprAlg.GaussianProcessModelCreatorParameter.Value =
        gprAlg.GaussianProcessModelCreatorParameter.ValidValues.First(
          v => v is TunedGaussianProcessRegressionModelCreator);
      gprAlg.MinimizationIterations = minimizationIterations;

      // Values written by the event handlers (which may run on another thread) and read after the wait;
      // the Set/WaitOne pair orders those accesses.
      double result = double.MaxValue;
      IGaussianProcessSolution regSolution = null;
      Exception ex = null;

      using (var signal = new AutoResetEvent(false)) {
        EventHandler onStopped = (sender, args) => {
          try {
            // Guarded lookups: an exception here must not prevent signal.Set(), otherwise the
            // caller below would block in WaitOne forever. A missing likelihood keeps double.MaxValue.
            if (gprAlg.Results.ContainsKey("NegativeLogLikelihood"))
              result = ((DoubleValue)gprAlg.Results["NegativeLogLikelihood"].Value).Value;
            if (gprAlg.Results.ContainsKey("Solution"))
              regSolution = (IGaussianProcessSolution)gprAlg.Results["Solution"].Value;
          } catch (Exception e) {
            // Hand the failure to the waiting caller instead of losing it on the event thread.
            ex = e;
          } finally {
            signal.Set();
          }
        };
        EventHandler<EventArgs<Exception>> onException = (sender, args) => {
          result = double.MaxValue;
          regSolution = null;
          ex = args.Value;
          signal.Set();
        };

        gprAlg.Stopped += onStopped;
        gprAlg.ExceptionOccurred += onException;
        try {
          gprAlg.Prepare();
          gprAlg.Start();
          signal.WaitOne();
        } finally {
          // Detach before the wait handle is disposed so later events cannot touch it.
          gprAlg.Stopped -= onStopped;
          gprAlg.ExceptionOccurred -= onException;
        }
      }

      // NOTE(review): rethrowing the captured object resets its stack trace; on .NET 4.5+
      // ExceptionDispatchInfo.Capture(ex).Throw() would preserve the original trace.
      if (ex != null) throw ex;

      // Reset the algorithm and detach the problem data; presumably so the run state and the
      // data can be released once this method returns — confirm.
      gprAlg.Prepare();
      gprAlg.Problem = null;
      solution = regSolution;
      negLogLikelihood = result;
    }

    // Legacy evaluation path that fitted the GP through MATLAB/GPML (MLApp COM automation).
    // Kept for reference; the private string-building helpers below (GetInferenceMethodString,
    // GetLikelihoodFunction, GetHyperParameterString, ToVectorString) are only used by it.
    /*
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData trainingData, Dataset testData, IEnumerable<int> testRows,
      out double[] means, out double[] variances) {
      string meanExpression, meanHyperParameter;
      string covExpression, covHyperParameter;
      string likFunction, likHyperParameter;
      GetMeanFunction(tree, trainingData.AllowedInputVariables.Count(), out meanExpression, out meanHyperParameter);
      GetCovFunction(tree, trainingData.AllowedInputVariables.Count(), out covExpression, out covHyperParameter);
      GetLikelihoodFunction(tree, out likFunction, out likHyperParameter);

      double[,] y = new double[trainingData.TrainingIndices.Count(), 1];
      double[,] yImg = new double[trainingData.TrainingIndices.Count(), 1];

      int r, c;
      r = 0;
      foreach (var e in trainingData.Dataset.GetDoubleValues(trainingData.TargetVariable, trainingData.TrainingIndices)) {
        y[r++, 0] = e;
      }
      double[,] x = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      double[,] xImg = new double[y.Length, trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in trainingData.Dataset.GetDoubleValues(allowedInput, trainingData.TrainingIndices)) {
          x[r++, c] = e;
        }
        c++;
      }

      double[,] xTest = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      double[,] xTestImg = new double[testRows.Count(), trainingData.AllowedInputVariables.Count()];
      c = 0;
      foreach (var allowedInput in trainingData.AllowedInputVariables) {
        r = 0;
        foreach (var e in testData.GetDoubleValues(allowedInput, testRows)) {
          xTest[r++, c] = e;
        }
        c++;
      }

      object oldX = null;
      try {
        oldX = MLApp.GetVariable("x", "base");
      }
      catch {
      }
      if (oldX == null || oldX is Missing || ((double[,])oldX).Length != x.Length) {
        MLApp.PutFullMatrix("y", "base", y, yImg);
        MLApp.PutFullMatrix("x", "base", x, xImg);
      }
      MLApp.PutFullMatrix("xTest", "base", xTest, xTestImg);
      ExecuteMatlab("hyp0 = " + GetHyperParameterString(meanHyperParameter, covHyperParameter, likHyperParameter) + ";");
      ExecuteMatlab("infFun = " + GetInferenceMethodString(tree) + ";");
      ExecuteMatlab("meanExpr = " + meanExpression + ";");
      ExecuteMatlab("covExpr = " + covExpression + ";");
      ExecuteMatlab("likExp = " + likFunction + ";");
      ExecuteMatlab(
        "try " +
        " hyp = minimize(hyp0,'gp', -50, infFun, meanExpr, covExpr, likExp, x, y);" +
        " [ymu, ys2, fmu, fs2] = gp(hyp, infFun, meanExpr, covExpr, likExp, x, y, xTest); " +
        "catch " +
        " ymu = zeros(size(xTest, 1), 1); " +
        " ys2 = ones(size(xTest, 1), 1); " +
        "end");
      var meansMat = (double[,])MLApp.GetVariable("ymu", "base");
      var variancesMat = (double[,])MLApp.GetVariable("ys2", "base");
      means = new double[meansMat.GetLength(0)];
      for (int i = 0; i < means.Length; i++)
        means[i] = meansMat[i, 0];
      variances = new double[variancesMat.GetLength(0)];
      for (int i = 0; i < variances.Length; i++)
        variances[i] = variancesMat[i, 0];
    }
    */

    /// <summary>
    /// Returns the GPML inference method as a MATLAB string literal. Currently always <c>'infEP'</c>;
    /// <paramref name="tree"/> is not inspected.
    /// </summary>
    private string GetInferenceMethodString(ISymbolicExpressionTree tree) {
      return "'infEP'";
    }

    /// <summary>Builds the mean function encoded in subtree 0 of the configuration node.</summary>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTree tree) {
      return GetMeanFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0));
    }

    /// <summary>Builds the covariance function encoded in subtree 1 of the configuration node.</summary>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTree tree) {
      return GetCovFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(1));
    }

    /// <summary>
    /// Renders the likelihood function encoded in subtree 2 of the configuration node as a
    /// MATLAB/GPML expression and a bracketed hyperparameter vector string.
    /// </summary>
    /// <param name="tree">Tree encoding the GP configuration.</param>
    /// <param name="expression">The GPML likelihood function name as a MATLAB string literal.</param>
    /// <param name="hyperParameter">The likelihood hyperparameters, e.g. <c>[0.1]</c>.</param>
    private void GetLikelihoodFunction(ISymbolicExpressionTree tree, out string expression, out string hyperParameter) {
      var expressionBuilder = new StringBuilder();
      var hyperParameterBuilder = new StringBuilder();
      hyperParameterBuilder.Append("[");
      GetLikelihoodFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(2), expressionBuilder,
                            hyperParameterBuilder);
      hyperParameterBuilder.Append("]");
      expression = expressionBuilder.ToString();
      hyperParameter = hyperParameterBuilder.ToString();
    }

    /// <summary>
    /// Appends the GPML expression and hyperparameters of a likelihood node to the given builders.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a supported likelihood function.</exception>
    private void GetLikelihoodFunction(ISymbolicExpressionTreeNode node, StringBuilder expressionBuilder, StringBuilder hyperParameterBuilder) {
      if (node.Symbol is LikGauss) {
        var likNode = (LikGaussTreeNode)node;
        expressionBuilder.Append("'likGauss'");
        // Invariant culture: the text is parsed by MATLAB, which requires '.' as decimal separator.
        hyperParameterBuilder.Append(Convert.ToString(likNode.Sigma, CultureInfo.InvariantCulture));
      } else {
        throw new ArgumentException("unknown likelihood function " + node.Symbol);
      }
    }

    /// <summary>
    /// Combines mean, covariance and likelihood hyperparameter strings into a MATLAB
    /// <c>struct('mean', ..., 'cov', ..., 'lik', ...)</c> expression.
    /// </summary>
    private string GetHyperParameterString(string meanHyp, string covHyp, string likHyp) {
      return "struct('mean', " + meanHyp +
             ", 'cov', " + covHyp +
             ", 'lik', " + likHyp +
             ")";
    }

    /// <summary>
    /// Maps a covariance-function node (recursively, for sums and products) to the corresponding
    /// HeuristicLab <see cref="ICovarianceFunction"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known covariance function.</exception>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is CovConst) {
        return new CovarianceConst();
      } else if (node.Symbol is CovScale) {
        return new CovarianceScale();
      } else if (node.Symbol is CovMask) {
        var maskNode = (CovMaskTreeNode)node;
        var maskSymbol = (CovMask)node.Symbol;
        var cov = new CovarianceMask();
        // Keep only the input dimensions whose mask entry is true.
        var selectedDimensions = Enumerable.Range(0, maskSymbol.Dimension)
                                           .Where(i => maskNode.Mask[i])
                                           .ToArray();
        cov.SelectedDimensionsParameter.Value = new IntArray(selectedDimensions);
        return cov;
      } else if (node.Symbol is CovLin) {
        return new CovarianceLinear();
      } else if (node.Symbol is CovLinArd) {
        return new CovarianceLinearArd();
      } else if (node.Symbol is CovMatern) {
        var maternSymbol = (CovMatern)node.Symbol;
        var cov = new CovarianceMaternIso();
        // Single() throws if the symbol's D is not exactly one of the parameter's valid values.
        cov.DParameter.Value = cov.DParameter.ValidValues.Single(x => x.Value == maternSymbol.D);
        return cov;
      } else if (node.Symbol is CovSeArd) {
        return new CovarianceSquaredExponentialArd();
      } else if (node.Symbol is CovSeIso) {
        return new CovarianceSquaredExponentialIso();
      } else if (node.Symbol is CovRQIso) {
        return new CovarianceRationalQuadraticIso();
      } else if (node.Symbol is CovRQArd) {
        return new CovarianceRationalQuadraticArd();
      } else if (node.Symbol is CovPeriodic) {
        return new CovariancePeriodic();
      } else if (node.Symbol is CovNoise) {
        return new CovarianceNoise();
      } else if (node.Symbol is CovSum) {
        var covSum = new Algorithms.DataAnalysis.CovarianceSum();
        // GetSubtree(0) first so a sum without arguments fails immediately.
        covSum.Terms.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covSum.Terms.Add(GetCovFunction(subTree));
        }
        return covSum;
      } else if (node.Symbol is CovProd) {
        var covProd = new Algorithms.DataAnalysis.CovarianceProduct();
        // GetSubtree(0) first so a product without arguments fails immediately.
        covProd.Factors.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covProd.Factors.Add(GetCovFunction(subTree));
        }
        return covProd;
      } else {
        throw new ArgumentException("unknown symbol " + node.Symbol);
      }
    }

    /// <summary>
    /// Maps a mean-function node (recursively, for sums and products) to the corresponding
    /// HeuristicLab <see cref="IMeanFunction"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known mean function.</exception>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is MeanConst) {
        return new Algorithms.DataAnalysis.MeanConst();
      } else if (node.Symbol is MeanLinear) {
        return new Algorithms.DataAnalysis.MeanLinear();
      } else if (node.Symbol is MeanProd) {
        var meanProd = new Algorithms.DataAnalysis.MeanProduct();
        // GetSubtree(0) first so a product without arguments fails immediately.
        meanProd.Factors.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanProd.Factors.Add(GetMeanFunction(subTree));
        }
        return meanProd;
      } else if (node.Symbol is MeanSum) {
        var meanSum = new Algorithms.DataAnalysis.MeanSum();
        // GetSubtree(0) first so a sum without arguments fails immediately.
        meanSum.Terms.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanSum.Terms.Add(GetMeanFunction(subTree));
        }
        return meanSum;
      } else if (node.Symbol is MeanZero) {
        return new Algorithms.DataAnalysis.MeanZero();
      } else {
        throw new ArgumentException("Unknown mean function " + node.Symbol);
      }
    }

    //private bool[] CalculateMask(ISymbolicExpressionTreeNode node) {
    //  var maskNode = node as MeanMaskTreeNode;
    //  if (maskNode != null) {
    //    bool[] newMask = CombineMasksProd(maskNode.Mask, CalculateMask(node.GetSubtree(0)));
    //    return newMask;
    //  } else if (node.Symbol is MeanProd) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksProd(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.Symbol is MeanSum) {
    //    bool[] newMask = CalculateMask(node.GetSubtree(0));
    //    foreach (var subTree in node.Subtrees.Skip(1)) {
    //      newMask = CombineMasksSum(newMask, CalculateMask(subTree));
    //    }
    //    return newMask;
    //  } else if (node.SubtreeCount == 1) {
    //    return CalculateMask(node.GetSubtree(0));
    //  } else if (node is SymbolicExpressionTreeTerminalNode) {
    //    return null;
    //  } else {
    //    throw new NotImplementedException();
    //  }
    //}

    //private bool[] CombineMasksProd(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] | n[i];
    //  return res;
    //}

    //private bool[] CombineMasksSum(bool[] m, bool[] n) {
    //  if (m == null) return n;
    //  if (n == null) return m;
    //  if (m.Length != n.Length) throw new ArgumentException();
    //  bool[] res = new bool[m.Length];
    //  for (int i = 0; i < res.Length; i++)
    //    res[i] = m[i] & n[i];
    //  return res;
    //}

    /// <summary>
    /// Renders a boolean mask as a MATLAB row vector where <c>false</c> becomes 1 and <c>true</c> becomes 0.
    /// A single-element mask is rendered as <c>[1]</c> (false) or <c>[]</c> (true) to work around a GPML bug.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="b"/> is null.</exception>
    private string ToVectorString(bool[] b) {
      if (b == null) throw new ArgumentNullException("b");
      var strBuilder = new StringBuilder();
      strBuilder.Append("[");
      if (b.Length == 1) {
        // workaround for bug in GPML
        if (!b[0]) strBuilder.Append("1");
      } else {
        for (int i = 0; i < b.Length; i++) {
          if (i > 0) strBuilder.Append(", ");
          strBuilder.Append(b[i] ? "0" : "1");
        }
      }
      strBuilder.Append("]");
      return strBuilder.ToString();
    }

    /// <summary>
    /// Renders the elements of <paramref name="xs"/> whose <paramref name="mask"/> entry is <c>false</c>
    /// as a MATLAB column vector, e.g. <c>[1.5; 2]</c>. Numbers use the invariant culture so MATLAB can parse them.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="xs"/> or <paramref name="mask"/> is null.</exception>
    /// <exception cref="ArgumentException">The arrays differ in length.</exception>
    private string ToVectorString(double[] xs, bool[] mask) {
      if (xs == null) throw new ArgumentNullException("xs");
      if (mask == null) throw new ArgumentNullException("mask");
      if (xs.Length != mask.Length) throw new ArgumentException("xs and mask must have the same length.");
      var strBuilder = new StringBuilder();
      strBuilder.Append("[");
      bool first = true;
      for (int i = 0; i < xs.Length; i++) {
        if (mask[i]) continue;
        // Separate relative to the previously emitted element, not the index, so that
        // a masked-out first element does not produce a leading "; ".
        if (!first) strBuilder.Append("; ");
        strBuilder.Append(xs[i].ToString(CultureInfo.InvariantCulture));
        first = false;
      }
      strBuilder.Append("]");
      return strBuilder.ToString();
    }
  }
}