Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GaussianProcessTuning/HeuristicLab.Problems.GaussianProcessTuning/Grammar.cs @ 9387

Last change on this file since 9387 was 9387, checked in by gkronber, 11 years ago

#1967: added CovNN symbol and tree node

File size: 6.5 KB
Line 
1using System.Collections.Generic;
2using System.Linq;
3using HeuristicLab.Common;
4using HeuristicLab.Core;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7
8namespace HeuristicLab.Problems.GaussianProcessTuning {
9  [StorableClass]
10  [Item("Gaussian Process Configuration Grammar", "The grammar for the Gaussian process configration problem.")]
11  public sealed class Grammar : SymbolicExpressionGrammar {
12    [Storable]
13    private int dimension;
14    public int Dimension {
15      get { return dimension; }
16      set {
17        if (dimension != value) {
18          dimension = value;
19          UpdateSymbolDimension();
20        }
21      }
22    }
23
24    [StorableConstructor]
25    private Grammar(bool deserializing) : base(deserializing) { }
26    private Grammar(Grammar original, Cloner cloner)
27      : base(original, cloner) {
28    }
29
30    public Grammar()
31      : base("Gaussian Process Configuration Grammar", "The grammar for the Gaussian process configuration problem.") {
32      Dimension = 10;
33      Initialize();
34    }
35
36    [StorableHook(HookType.AfterDeserialization)]
37    private void AfterDeserialization() {
38    }
39    public override IDeepCloneable Clone(Cloner cloner) {
40      return new Grammar(this, cloner);
41    }
42    private void Initialize() {
43      // create all symbols
44      //var meanOne = new MeanOne();
45      var meanConst = new MeanConst();
46      var meanLinear = new MeanLinear(Dimension);
47      //var meanScale = new MeanScale();
48      var meanSum = new MeanSum();
49      var meanProd = new MeanProd();
50      //var meanPow2 = new MeanPow(2) { Name = "MeanPow2" };
51      //var meanPow3 = new MeanPow(3) { Name = "MeanPow3" };
52      //var meanMask = new MeanMask(Dimension);
53      var covConst = new CovConst();
54      var covLin = new CovLin();
55      var covLinArd = new CovLinArd(Dimension);
56      var covSeArd = new CovSeArd(Dimension);
57      var covSeIso = new CovSeIso();
58      var covRQard = new CovRQArd(Dimension);
59      var covRQiso = new CovRQIso();
60      var covNN = new CovNn();
61      var covMatern1 = new CovMatern(1);
62      covMatern1.Name = "CovMatern1";
63      var covMatern3 = new CovMatern(3);
64      covMatern3.Name = "CovMatern3";
65      var covMatern5 = new CovMatern(5);
66      covMatern5.Name = "CovMatern5";
67      var covPeriodic = new CovPeriodic();
68      var covPeriodic1 = new CovPeriodic(1.0);
69      covPeriodic1.Name = "CovPeriodic(1.0)";
70      var covPeriodicCO = new CovPeriodic(0.021817864467425927);
71      covPeriodicCO.Name = "CovPeriodicCO";
72      var covNoise = new CovNoise();
73      var covScale = new CovScale();
74      var covSum = new CovSum();
75      var covProd = new CovProd();
76      var covMask = new CovMask(Dimension);
77      var likGauss = new LikGauss();
78
79      var configurationStartSymbol = new ConfigurationStartSymbol();
80
81      var meanSymbols = new List<ISymbol>()
82                         {
83                            meanConst,
84                            meanLinear,
85                            meanSum, 
86                            meanProd,
87                            //meanPow2,
88                            //meanPow3,
89                            //meanMask,
90                         };
91      var covSymbols = new List<ISymbol>()
92                         {
93                            covConst,
94                            covLin,
95                            covLinArd,
96                            covSeArd,
97                            covSeIso,
98                            covRQiso,
99                            covRQard,
100                            covMatern1,
101                            covMatern3,
102                            covMatern5,
103                            covNN,
104                            covPeriodic,
105                            covPeriodic1,
106                            covPeriodicCO,
107                            covNoise,
108                            covScale,
109                            covSum,
110                            covProd, 
111                            covMask,
112                         };
113      var likSymbols = new List<ISymbol>()
114                         {
115                           likGauss
116                         };
117      var allSymbols = covSymbols.Concat(meanSymbols).Concat(likSymbols);
118
119      AddSymbol(configurationStartSymbol);
120      // add all symbols to the grammar
121      foreach (var s in allSymbols)
122        AddSymbol(s);
123
124      AddAllowedChildSymbol(StartSymbol, configurationStartSymbol, 0);
125
126      foreach (var meanSymbol in meanSymbols) {
127        AddAllowedChildSymbol(configurationStartSymbol, meanSymbol, 0);
128      }
129      foreach (var covSymbol in covSymbols) {
130        AddAllowedChildSymbol(configurationStartSymbol, covSymbol, 1);
131      }
132      foreach (var likSymbol in likSymbols) {
133        AddAllowedChildSymbol(configurationStartSymbol, likSymbol, 2);
134      }
135      foreach (var meanFunctionSymbol in new ISymbol[] { meanSum, meanProd }) {
136        foreach (var meanSymbol in meanSymbols) {
137          for (int i = 0; i < meanFunctionSymbol.MaximumArity; i++) {
138            AddAllowedChildSymbol(meanFunctionSymbol, meanSymbol, i);
139          }
140        }
141      }
142
143      foreach (var covFunctionSymbol in new ISymbol[] { covSum, covProd, covScale }) {
144        foreach (var covSymbol in covSymbols) {
145          for (int i = 0; i < covFunctionSymbol.MaximumArity; i++) {
146            AddAllowedChildSymbol(covFunctionSymbol, covSymbol, i);
147          }
148        }
149      }
150
151      // mask
152      foreach (var covSymbol in new List<ISymbol>()
153                         {
154                            covConst,
155                            covLin,
156                            covLinArd,
157                            covSeArd,
158                            covSeIso,
159                            covRQiso,
160                            covRQard,
161                            covMatern1,
162                            covMatern3,
163                            covMatern5,
164                            covNN,
165                            covPeriodic,
166                            covPeriodic1,
167                            covPeriodicCO,
168                         }) {
169        for (int i = 0; i < covMask.MaximumArity; i++) {
170          AddAllowedChildSymbol(covMask, covSymbol, i);
171        }
172      }
173    }
174
175    private void UpdateSymbolDimension() {
176      foreach (var s in Symbols.OfType<IDimensionSymbol>()) {
177        s.Dimension = dimension;
178      }
179    }
180  }
181}
Note: See TracBrowser for help on using the repository browser.