Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GaussianProcessTuning/HeuristicLab.Problems.GaussianProcessTuning/Interpreter.cs @ 9387

Last change on this file since 9387 was 9387, checked in by gkronber, 11 years ago

#1967: added CovNN symbol and tree node

File size: 8.5 KB
Line 
1using System;
2using System.Linq;
3using System.Text;
4using System.Threading;
5using HeuristicLab.Algorithms.DataAnalysis;
6using HeuristicLab.Common;
7using HeuristicLab.Core;
8using HeuristicLab.Data;
9using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
10using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
11using HeuristicLab.Problems.DataAnalysis;
12
namespace HeuristicLab.Problems.GaussianProcessTuning {
  /// <summary>
  /// Interprets a symbolic expression tree that encodes a Gaussian process configuration
  /// (a mean function and a covariance function). It converts the tree into the
  /// corresponding HeuristicLab mean/covariance function objects and evaluates the
  /// configuration by running <see cref="GaussianProcessRegression"/>.
  /// </summary>
  [StorableClass]
  [Item("Interpreter", "An interpreter for Gaussian process configurations represented as trees.")]
  public class Interpreter : Item {
    [StorableConstructor]
    protected Interpreter(bool deserializing) : base(deserializing) { }
    protected Interpreter(Interpreter original, Cloner cloner)
      : base(original, cloner) {
    }
    public Interpreter()
      : base() { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new Interpreter(this, cloner);
    }

    /// <summary>
    /// Builds the mean and covariance functions encoded in <paramref name="tree"/> and fits a
    /// Gaussian process regression (50 minimization iterations) to <paramref name="problemData"/>.
    /// The calling thread is blocked until the algorithm run has finished.
    /// </summary>
    /// <param name="tree">The configuration tree. The mean function is read from child 0 and the
    /// covariance function from child 1 of the node at Root/0/0.</param>
    /// <param name="problemData">The regression problem data the Gaussian process is fitted to.</param>
    /// <param name="negLogLikelihood">The value of the run's "NegativeLogLikelihood" result.</param>
    /// <param name="solution">The run's "Solution" result, or <c>null</c> if the run produced none.</param>
    /// <exception cref="Exception">Rethrows the exception reported through the algorithm's
    /// ExceptionOccurred event, or the exception raised while reading the results after the run stopped.</exception>
    public void EvaluateGaussianProcessConfiguration(ISymbolicExpressionTree tree, IRegressionProblemData problemData, out double negLogLikelihood, out IGaussianProcessSolution solution) {
      var meanFunction = GetMeanFunction(tree);
      var covFunction = GetCovFunction(tree);

      var gprAlg = new GaussianProcessRegression();
      gprAlg.Problem.ProblemDataParameter.Value = problemData;
      gprAlg.CovarianceFunction = covFunction;
      gprAlg.MeanFunction = meanFunction;
      gprAlg.GaussianProcessModelCreatorParameter.Value =
        gprAlg.GaussianProcessModelCreatorParameter.ValidValues.First(
          v => v is GaussianProcessRegressionModelCreator);
      gprAlg.MinimizationIterations = 50;

      // The run's outcome is delivered through the Stopped / ExceptionOccurred events.
      // Each handler records the outcome and then releases the WaitOne call below.
      var signal = new AutoResetEvent(false);
      double result = double.MaxValue;
      IGaussianProcessSolution regSolution = null;
      Exception ex = null;
      gprAlg.Stopped += (sender, args) => {
        try {
          result = ((DoubleValue)gprAlg.Results["NegativeLogLikelihood"].Value).Value;
          if (gprAlg.Results.ContainsKey("Solution"))
            regSolution = (IGaussianProcessSolution)gprAlg.Results["Solution"].Value;
        } catch (Exception e) {
          // Reading the results failed (missing key or unexpected result type).
          // Pass the exception to the waiting caller. Otherwise it would escape inside
          // the event handler and the wait below would never be released.
          result = double.MaxValue;
          regSolution = null;
          ex = e;
        } finally {
          signal.Set();
        }
      };
      gprAlg.ExceptionOccurred += (sender, args) => {
        result = double.MaxValue;
        regSolution = null;
        ex = args.Value;
        signal.Set();
      };

      gprAlg.Prepare();
      gprAlg.Start();

      signal.WaitOne();
      // NOTE(review): `throw ex` resets the stack trace of an exception captured in an event
      // handler; consider ExceptionDispatchInfo if the target framework supports it.
      if (ex != null) throw ex;

      // Prepare the algorithm again and detach the problem before returning
      // (presumably so the local algorithm does not keep the run's state alive).
      gprAlg.Prepare();
      gprAlg.Problem = null;
      solution = regSolution;
      negLogLikelihood = result;
    }

    /// <summary>
    /// Returns the mean function encoded by child 0 of the node at Root/0/0 of <paramref name="tree"/>.
    /// </summary>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTree tree) {
      return GetMeanFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0));
    }

    /// <summary>
    /// Returns the covariance function encoded by child 1 of the node at Root/0/0 of <paramref name="tree"/>.
    /// </summary>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTree tree) {
      return GetCovFunction(tree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(1));
    }

    /// <summary>
    /// Recursively maps a covariance-function tree node to a HeuristicLab covariance function.
    /// Scale and mask nodes wrap the function built from their first child. Sum and product
    /// nodes combine the functions built from all of their children (at least one is required).
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known covariance symbol.</exception>
    private ICovarianceFunction GetCovFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is CovConst) {
        return new CovarianceConst();
      } else if (node.Symbol is CovScale) {
        var cov = new CovarianceScale();
        cov.CovarianceFunctionParameter.Value = GetCovFunction(node.GetSubtree(0));
        return cov;
      } else if (node.Symbol is CovMask) {
        // A CovMask symbol is expected on a CovMaskTreeNode. The direct cast fails with an
        // InvalidCastException rather than a later NullReferenceException.
        var maskNode = (CovMaskTreeNode)node;
        var covSymbol = (CovMask)node.Symbol;
        var cov = new CovarianceMask();
        // Select every dimension index in [0, Dimension) whose mask bit is set.
        cov.SelectedDimensionsParameter.Value = new IntArray((from i in Enumerable.Range(0, covSymbol.Dimension)
                                                              where maskNode.Mask[i]
                                                              select i).ToArray());
        cov.CovarianceFunctionParameter.Value = GetCovFunction(node.GetSubtree(0));
        return cov;
      } else if (node.Symbol is CovLin) {
        return new CovarianceLinear();
      } else if (node.Symbol is CovLinArd) {
        return new CovarianceLinearArd();
      } else if (node.Symbol is CovMatern) {
        var covSymbol = (CovMatern)node.Symbol;
        var cov = new CovarianceMaternIso();
        // The symbol's D must match exactly one of the parameter's valid values.
        cov.DParameter.Value = cov.DParameter.ValidValues.Single(x => x.Value == covSymbol.D);
        return cov;
      } else if (node.Symbol is CovSeArd) {
        return new CovarianceSquaredExponentialArd();
      } else if (node.Symbol is CovSeIso) {
        return new CovarianceSquaredExponentialIso();
      } else if (node.Symbol is CovRQIso) {
        return new CovarianceRationalQuadraticIso();
      } else if (node.Symbol is CovRQArd) {
        return new CovarianceRationalQuadraticArd();
      } else if (node.Symbol is CovNn) {
        return new CovarianceNeuralNetwork();
      } else if (node.Symbol is CovPeriodic) {
        return new CovariancePeriodic();
      } else if (node.Symbol is CovNoise) {
        return new CovarianceNoise();
      } else if (node.Symbol is CovSum) {
        var covSum = new Algorithms.DataAnalysis.CovarianceSum();
        covSum.Terms.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covSum.Terms.Add(GetCovFunction(subTree));
        }
        return covSum;
      } else if (node.Symbol is CovProd) {
        var covProd = new Algorithms.DataAnalysis.CovarianceProduct();
        covProd.Factors.Add(GetCovFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          covProd.Factors.Add(GetCovFunction(subTree));
        }
        return covProd;
      } else {
        throw new ArgumentException("unknown symbol " + node.Symbol);
      }
    }

    /// <summary>
    /// Recursively maps a mean-function tree node to a HeuristicLab mean function.
    /// Sum and product nodes combine the functions built from all of their children
    /// (at least one is required).
    /// </summary>
    /// <exception cref="ArgumentException">The node's symbol is not a known mean symbol.</exception>
    private IMeanFunction GetMeanFunction(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is MeanConst) {
        return new Algorithms.DataAnalysis.MeanConst();
      } else if (node.Symbol is MeanLinear) {
        return new Algorithms.DataAnalysis.MeanLinear();
      } else if (node.Symbol is MeanProd) {
        var meanProd = new Algorithms.DataAnalysis.MeanProduct();
        meanProd.Factors.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanProd.Factors.Add(GetMeanFunction(subTree));
        }
        return meanProd;
      } else if (node.Symbol is MeanSum) {
        var meanSum = new Algorithms.DataAnalysis.MeanSum();
        meanSum.Terms.Add(GetMeanFunction(node.GetSubtree(0)));
        foreach (var subTree in node.Subtrees.Skip(1)) {
          meanSum.Terms.Add(GetMeanFunction(subTree));
        }
        return meanSum;
      } else if (node.Symbol is MeanZero) {
        return new Algorithms.DataAnalysis.MeanZero();
      } else {
        throw new ArgumentException("Unknown mean function " + node.Symbol);
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.