Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GaussianProcessTuning/HeuristicLab.Problems.GaussianProcessTuning/Interpreter.cs @ 10757

Last change on this file since 10757 was 10757, checked in by gkronber, 10 years ago

#1967 removed obsolete classes and cleaned up the implementation. Added the ProblemInstanceProvider interface to allow loading of CSV files.

File size: 8.7 KB
Line 
1using System;
2using System.Linq;
3using System.Text;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11using HeuristicLab.Problems.DataAnalysis;
12
namespace HeuristicLab.Problems.GaussianProcessTuning {
  /// <summary>
  /// Translates a symbolic expression tree that encodes a Gaussian process configuration
  /// (one mean function and one covariance function) into HeuristicLab GP components and
  /// evaluates that configuration by running <see cref="GaussianProcessRegression"/>.
  /// </summary>
  [StorableClass]
  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
  public class Interpreter : Item {
    [StorableConstructor]
    protected Interpreter(bool deserializing) : base(deserializing) { }
    protected Interpreter(Interpreter original, Cloner cloner)
      : base(original, cloner) {
    }
    public Interpreter()
      : base() { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Interpreter(this, cloner);
    }

    /// <summary>
    /// Builds the mean and covariance functions encoded in <paramref name="tree"/>, runs a
    /// <see cref="GaussianProcessRegression"/> on <paramref name="problemData"/> and blocks the
    /// calling thread until the algorithm raises <c>Stopped</c> or <c>ExceptionOccurred</c>.
    /// </summary>
    /// <param name="tree">
    /// Configuration tree; the node at Root / subtree 0 / subtree 0 holds the mean function
    /// as its child 0 and the covariance function as its child 1.
    /// </param>
    /// <param name="problemData">Regression problem data the GP is fitted to.</param>
    /// <param name="iterations">Value assigned to the algorithm's <c>MinimizationIterations</c>.</param>
    /// <param name="negLogLikelihood">
    /// Value of the "NegativeLogLikelihood" result, or <see cref="double.MaxValue"/> when the
    /// run produced no such result.
    /// </param>
    /// <param name="solution">Value of the "Solution" result, or <c>null</c> when the run produced none.</param>
    /// <exception cref="ArgumentException">The tree contains a symbol this interpreter does not handle.</exception>
    /// <remarks>
    /// An exception reported by the algorithm (or raised while reading its results) is rethrown
    /// on the calling thread with its original type and stack trace.
    /// </remarks>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, int iterations, out double negLogLikelihood, out IGaussianProcessSolution solution) {
      var meanFunction = GetMeanFunction(tree);
      var covFunction = GetCovFunction(tree);

      var gprAlg = new GaussianProcessRegression();
      gprAlg.Problem.ProblemDataParameter.Value = problemData;
      gprAlg.CovarianceFunction = covFunction;
      gprAlg.MeanFunction = meanFunction;
      gprAlg.GaussianProcessModelCreatorParameter.Value =
        gprAlg.GaussianProcessModelCreatorParameter.ValidValues.First(
          v => v is GaussianProcessRegressionModelCreator);
      gprAlg.MinimizationIterations = iterations;

      // The run is awaited through the algorithm's events: each handler below calls Set() once
      // it has recorded its outcome. The event is intentionally not disposed, because a late
      // Stopped notification after ExceptionOccurred must still be able to call Set() safely.
      var signal = new AutoResetEvent(false);
      double result = double.MaxValue;
      IGaussianProcessSolution regSolution = null;
      Exception ex = null;

      gprAlg.Stopped += (sender, args) => {
        try {
          // A missing result leaves result at double.MaxValue, the same "bad configuration"
          // value the exception path reports.
          if (gprAlg.Results.ContainsKey("NegativeLogLikelihood"))
            result = ((DoubleValue)gprAlg.Results["NegativeLogLikelihood"].Value).Value;
          if (gprAlg.Results.ContainsKey("Solution"))
            regSolution = (IGaussianProcessSolution)gprAlg.Results["Solution"].Value;
        } catch (Exception e) {
          // Must not escape: that would skip Set() and block the caller forever.
          // Hand it to the waiting thread instead (unless an earlier error is already recorded).
          if (ex == null) ex = e;
        } finally {
          signal.Set();
        }
      };
      gprAlg.ExceptionOccurred += (sender, args) => {
        result = double.MaxValue;
        regSolution = null;
        ex = args.Value;
        signal.Set();
      };

      gprAlg.Prepare();
      gprAlg.Start();

      signal.WaitOne();
      if (ex != null) {
        // Preserve the original stack trace; "throw ex;" would reset it to this line.
        System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(ex).Throw();
      }

      // Prepare the algorithm again and detach the problem so the algorithm instance no
      // longer references the problem data (presumably to release it; confirm if relied upon).
      gprAlg.Prepare();
      gprAlg.Problem = null;
      solution = regSolution;
      negLogLikelihood = result;
    }

    // The mean function is child 0 of the node at Root / subtree 0 / subtree 0.
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTree tree) {
      return GetMeanFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0));
    }

    // The covariance function is child 1 of the node at Root / subtree 0 / subtree 0.
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTree tree) {
      return GetCovFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(1));
    }

    /// <summary>
    /// Recursively maps a covariance-symbol subtree to the corresponding
    /// <see cref="ICovarianceFunction"/> from HeuristicLab.Algorithms.DataAnalysis.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known covariance symbol.</exception>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is CovConst) {
        return new CovarianceConst();
      } else if (node.Symbol is CovScale) {
        var cov = new CovarianceScale();
        cov.CovarianceFunctionParameter.Value = GetCovFunction(node.GetSubtree(0));
        return cov;
      } else if (node.Symbol is CovMask) {
        var maskNode = node as CovMaskTreeNode;
        if (maskNode == null)
          throw new ArgumentException("CovMask symbol requires a CovMaskTreeNode but found " + node.GetType().Name);
        var covSymbol = (CovMask)node.Symbol;
        var cov = new CovarianceMask();
        // Select the indexes of all dimensions whose mask bit is set.
        cov.SelectedDimensionsParameter.Value = new IntArray(Enumerable.Range(0, covSymbol.Dimension)
                                                               .Where(i => maskNode.Mask[i])
                                                               .ToArray());
        cov.CovarianceFunctionParameter.Value = GetCovFunction(node.GetSubtree(0));
        return cov;
      } else if (node.Symbol is CovLin) {
        return new CovarianceLinear();
      } else if (node.Symbol is CovLinArd) {
        return new CovarianceLinearArd();
      } else if (node.Symbol is CovMatern) {
        var covSymbol = (CovMatern)node.Symbol;
        var cov = new CovarianceMaternIso();
        cov.DParameter.Value = cov.DParameter.ValidValues.Single(x => x.Value == covSymbol.D);
        return cov;
      } else if (node.Symbol is CovSeArd) {
        return new CovarianceSquaredExponentialArd();
      } else if (node.Symbol is CovSeIso) {
        return new CovarianceSquaredExponentialIso();
      } else if (node.Symbol is CovRQIso) {
        return new CovarianceRationalQuadraticIso();
      } else if (node.Symbol is CovRQArd) {
        return new CovarianceRationalQuadraticArd();
      } else if (node.Symbol is CovNn) {
        return new CovarianceNeuralNetwork();
      } else if (node.Symbol is CovPoly) {
        return new CovariancePolynomial();
      } else if (node.Symbol is CovPiecewisePoly) {
        return new CovariancePiecewisePolynomial();
      } else if (node.Symbol is CovPeriodic) {
        return new CovariancePeriodic();
      } else if (node.Symbol is CovNoise) {
        return new CovarianceNoise();
      } else if (node.Symbol is CovSum) {
        // Every child subtree becomes one term of the sum.
        var covSum = new Algorithms.DataAnalysis.CovarianceSum();
        foreach (var subTree in node.Subtrees) {
          covSum.Terms.Add(GetCovFunction(subTree));
        }
        return covSum;
      } else if (node.Symbol is CovProd) {
        // Every child subtree becomes one factor of the product.
        var covProd = new Algorithms.DataAnalysis.CovarianceProduct();
        foreach (var subTree in node.Subtrees) {
          covProd.Factors.Add(GetCovFunction(subTree));
        }
        return covProd;
      } else {
        throw new ArgumentException("unknown symbol " + node.Symbol);
      }
    }

    /// <summary>
    /// Recursively maps a mean-symbol subtree to the corresponding
    /// <see cref="IMeanFunction"/> from HeuristicLab.Algorithms.DataAnalysis.
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known mean symbol.</exception>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is MeanConst) {
        return new Algorithms.DataAnalysis.MeanConst();
      } else if (node.Symbol is MeanLinear) {
        return new Algorithms.DataAnalysis.MeanLinear();
      } else if (node.Symbol is MeanProd) {
        // Every child subtree becomes one factor of the product.
        var meanProd = new Algorithms.DataAnalysis.MeanProduct();
        foreach (var subTree in node.Subtrees) {
          meanProd.Factors.Add(GetMeanFunction(subTree));
        }
        return meanProd;
      } else if (node.Symbol is MeanSum) {
        // Every child subtree becomes one term of the sum.
        var meanSum = new Algorithms.DataAnalysis.MeanSum();
        foreach (var subTree in node.Subtrees) {
          meanSum.Terms.Add(GetMeanFunction(subTree));
        }
        return meanSum;
      } else if (node.Symbol is MeanZero) {
        return new Algorithms.DataAnalysis.MeanZero();
      } else {
        throw new ArgumentException("Unknown mean function " + node.Symbol);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.