source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs @ 15974

Last change on this file since 15974 was 15974, checked in by bburlacu, 3 years ago

#2886: implement LRU cache for storing search nodes, introduce SortedSet for handling priorities, fix serialization and cloning

File size: 12.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
9using HeuristicLab.Problems.DataAnalysis;
10using HeuristicLab.Problems.DataAnalysis.Symbolic;
11
12namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
13  public enum GrammarRule {
14    MultipleTerms,
15    MultipleFactors,
16    InverseTerm,
17    Logarithm,
18    Exponentiation,
19    Sine
20  }
21
22  [StorableClass(StorableClassType.AllFieldsAndAllProperties)]
23  public class Grammar : DeepCloneable {
24    public Symbol StartSymbol { get; private set; }
25
26    public Hasher<int> Hasher { get; }
27
28    #region Symbols
29
30    public IReadOnlyDictionary<Symbol, IReadOnlyList<Production>> Productions { get; private set; }
31
32    public NonterminalSymbol Var;
33    public IReadOnlyList<VariableTerminalSymbol> VarTerminals;
34
35    public NonterminalSymbol Expr;
36    public NonterminalSymbol Term;
37    public NonterminalSymbol Factor;
38    public NonterminalSymbol LogFactor;
39    public NonterminalSymbol ExpFactor;
40    public NonterminalSymbol SinFactor;
41
42    public NonterminalSymbol SimpleExpr;
43    public NonterminalSymbol SimpleTerm;
44
45    public NonterminalSymbol InvExpr;
46    public NonterminalSymbol InvTerm;
47
48    public TerminalSymbol Addition;
49    public TerminalSymbol Multiplication;
50    public TerminalSymbol Log;
51    public TerminalSymbol Exp;
52    public TerminalSymbol Sin;
53    public TerminalSymbol Inv;
54
55    public TerminalSymbol Const;
56
57    #endregion
58
59    #region HL Symbols for Parsing ExpressionTrees
60
61    private ISymbol constSy;
62    private ISymbol varSy;
63
64    private ISymbol addSy;
65    private ISymbol mulSy;
66    private ISymbol logSy;
67    private ISymbol expSy;
68    private ISymbol divSy;
69    private ISymbol sinSy;
70
71    private ISymbol rootSy;
72    private ISymbol startSy;
73
74    private InfixExpressionFormatter infixExpressionFormatter;
75    #endregion
76
77    public Grammar(string[] variables) : this(variables, Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>()) { }
78
79    protected Grammar(Grammar original, Cloner cloner) : base(original, cloner) {
80      infixExpressionFormatter = cloner.Clone(original.infixExpressionFormatter);
81
82      Productions = original.Productions.ToDictionary(x => cloner.Clone(x.Key), x => (IReadOnlyList<Production>)x.Value.Select(cloner.Clone).ToList());
83      VarTerminals = original.VarTerminals.Select(cloner.Clone).ToList();
84
85      Var = cloner.Clone(original.Var);
86      Expr = cloner.Clone(original.Expr);
87      Term = cloner.Clone(original.Term);
88      Factor = cloner.Clone(original.Factor);
89      LogFactor = cloner.Clone(original.LogFactor);
90      ExpFactor = cloner.Clone(original.ExpFactor);
91      SinFactor = cloner.Clone(original.SinFactor);
92      SimpleExpr = cloner.Clone(original.SimpleExpr);
93      SimpleTerm = cloner.Clone(original.SimpleTerm);
94      InvExpr = cloner.Clone(original.InvExpr);
95      InvTerm = cloner.Clone(original.InvTerm);
96
97      Addition = cloner.Clone(original.Addition);
98      Multiplication = cloner.Clone(original.Multiplication);
99      Log = cloner.Clone(original.Log);
100      Exp = cloner.Clone(original.Exp);
101      Sin = cloner.Clone(original.Sin);
102      Inv = cloner.Clone(original.Inv);
103      Const = cloner.Clone(original.Const);
104
105      StartSymbol = Expr;
106      Hasher = cloner.Clone(original.Hasher);
107
108      InitTreeParser(); // easier this way (and less typing)
109    }
110
111    private void InitProductions(string[] variables, IEnumerable<GrammarRule> includedRules) {
112      #region Define Symbols
113      Var = new NonterminalSymbol("Var");
114
115      Expr = new NonterminalSymbol("Expr");
116      Term = new NonterminalSymbol("Term");
117      Factor = new NonterminalSymbol("Factor");
118      LogFactor = new NonterminalSymbol("LogFactor");
119      ExpFactor = new NonterminalSymbol("ExpFactor");
120      SinFactor = new NonterminalSymbol("SinFactor");
121
122      SimpleExpr = new NonterminalSymbol("SimpleExpr");
123      SimpleTerm = new NonterminalSymbol("SimpleTerm");
124
125      InvExpr = new NonterminalSymbol("InvExpr");
126      InvTerm = new NonterminalSymbol("InvTerm");
127
128      Addition = new TerminalSymbol("+");
129      Multiplication = new TerminalSymbol("*");
130      Log = new TerminalSymbol("log");
131      Exp = new TerminalSymbol("exp");
132      Sin = new TerminalSymbol("sin");
133      Inv = new TerminalSymbol("inv");
134
135      Const = new TerminalSymbol("c");
136      #endregion
137
138      #region Production rules
139      StartSymbol = Expr;
140
141      Dictionary<Symbol, IReadOnlyList<Production>> productions = new Dictionary<Symbol, IReadOnlyList<Production>>();
142
143      // Map each variable to a separate production rule of the "Var" nonterminal symbol.
144      VarTerminals = variables.Select(v => new VariableTerminalSymbol(v)).ToArray();
145      productions[Var] = VarTerminals.Select(v => new Production(v)).ToArray();
146
147      // Expression Grammar Rules
148      var exprProductions = new List<Production>();
149      if (includedRules.Contains(GrammarRule.MultipleTerms))
150        exprProductions.Add(new Production(Const, Term, Multiplication, Expr, Addition));
151
152      exprProductions.Add(new Production(Const, Term, Multiplication, Const, Addition));
153      productions[Expr] = exprProductions.ToArray();
154
155      // Term Grammar Rules
156      var termProductions = new List<Production>();
157      if (includedRules.Contains(GrammarRule.MultipleFactors))
158        termProductions.Add(new Production(Factor, Term, Multiplication));
159      if (includedRules.Contains(GrammarRule.InverseTerm))
160        termProductions.Add(new Production(InvExpr, Inv));
161      termProductions.Add(new Production(Factor));
162      productions[Term] = termProductions.ToArray();
163
164      // Factor Grammar Rules
165      var factorProductions = new List<Production>();
166      factorProductions.Add(new Production(Var));
167      if (includedRules.Contains(GrammarRule.Logarithm))
168        factorProductions.Add(new Production(LogFactor));
169      if (includedRules.Contains(GrammarRule.Exponentiation))
170        factorProductions.Add(new Production(ExpFactor));
171      if (includedRules.Contains(GrammarRule.Sine))
172        factorProductions.Add(new Production(SinFactor));
173      productions[Factor] = factorProductions.ToArray();
174
175      productions[LogFactor] = new[] { new Production(SimpleExpr, Log) };
176      productions[ExpFactor] = new[] { new Production(Const, SimpleTerm, Multiplication, Exp) };
177      productions[SinFactor] = new[] { new Production(SimpleExpr, Sin) };
178
179      productions[SimpleExpr] = new[] {
180        new Production(Const, SimpleTerm, Multiplication, SimpleExpr, Addition),
181        new Production(Const, SimpleTerm, Multiplication, Const, Addition)
182      };
183
184      productions[SimpleTerm] = new[] {
185        new Production(Var, SimpleTerm, Multiplication),
186        new Production(Var)
187      };
188
189      productions[InvExpr] = new[] {
190        new Production(Const, InvTerm, Multiplication, InvExpr, Addition),
191        new Production(Const, InvTerm, Multiplication, Const, Addition)
192      };
193
194      productions[InvTerm] = new[] {
195        new Production(Factor, InvTerm, Multiplication),
196        new Production(Factor)
197      };
198
199      Productions = productions;
200      #endregion
201    }
202
203    private void InitTreeParser() {
204      #region Parsing to SymbolicExpressionTree
205      var symbolicExpressionGrammar = new TypeCoherentExpressionGrammar();
206      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();
207
208      constSy = symbolicExpressionGrammar.Symbols.OfType<Constant>().First();
209      varSy = symbolicExpressionGrammar.Symbols.OfType<Variable>().First();
210      addSy = symbolicExpressionGrammar.Symbols.OfType<Addition>().First();
211      mulSy = symbolicExpressionGrammar.Symbols.OfType<Multiplication>().First();
212      logSy = symbolicExpressionGrammar.Symbols.OfType<Logarithm>().First();
213      expSy = symbolicExpressionGrammar.Symbols.OfType<Exponential>().First();
214      divSy = symbolicExpressionGrammar.Symbols.OfType<Division>().First();
215      sinSy = symbolicExpressionGrammar.Symbols.OfType<Sine>().First();
216
217      rootSy = symbolicExpressionGrammar.Symbols.OfType<ProgramRootSymbol>().First();
218      startSy = symbolicExpressionGrammar.Symbols.OfType<StartSymbol>().First();
219
220      infixExpressionFormatter = new InfixExpressionFormatter();
221      #endregion
222    }
223
224    public Grammar(string[] variables, IEnumerable<GrammarRule> includedRules) {
225      InitProductions(variables, includedRules);
226      InitTreeParser();
227
228      Hasher = new IntHasher(this);
229    }
230
231    public int GetComplexity(SymbolString s) {
232      int c = 0;
233      int length = s.Count();
234      for (int i = 0; i < length; i++) {
235        if (s[i] is NonterminalSymbol || s[i] is VariableTerminalSymbol) c++;
236      }
237      return c;
238    }
239
240    public double EvaluatePhrase(SymbolString s, IRegressionProblemData problemData, bool optimizeConstants) {
241      SymbolicExpressionTree tree = ParseSymbolicExpressionTree(s);
242      return RSquaredEvaluator.Evaluate(problemData, tree, optimizeConstants);
243    }
244
245    #region Parse to SymbolicExpressionTree
246
247    public string ToInfixString(SymbolString sentence) {
248      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
249      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");
250
251      return infixExpressionFormatter.Format(ParseSymbolicExpressionTree(sentence));
252    }
253
254    public SymbolicExpressionTree ParseSymbolicExpressionTree(SymbolString sentence) {
255      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
256
257      var rootNode = rootSy.CreateTreeNode();
258      var startNode = startSy.CreateTreeNode();
259      rootNode.AddSubtree(startNode);
260
261      Stack<Symbol> parseStack = new Stack<Symbol>(sentence);
262      startNode.AddSubtree(ParseSymbolicExpressionTree(parseStack));
263
264      return new SymbolicExpressionTree(rootNode);
265    }
266
267    public ISymbolicExpressionTreeNode ParseSymbolicExpressionTree(Stack<Symbol> parseStack) {
268      Symbol currentSymbol = parseStack.Pop();
269
270      ISymbolicExpressionTreeNode parsedSubTree = null;
271
272      if (currentSymbol == Addition) {
273        parsedSubTree = addSy.CreateTreeNode();
274        ISymbolicExpressionTreeNode rightSubtree = ParseSymbolicExpressionTree(parseStack);
275        if (rightSubtree is ConstantTreeNode) {
276          ((ConstantTreeNode)rightSubtree).Value = 0.0;
277        }
278        parsedSubTree.AddSubtree(rightSubtree); // left part
279
280        ISymbolicExpressionTreeNode leftSubtree = ParseSymbolicExpressionTree(parseStack);
281        if (leftSubtree is ConstantTreeNode) {
282          ((ConstantTreeNode)leftSubtree).Value = 0.0;
283        }
284        parsedSubTree.AddSubtree(leftSubtree); // right part
285
286      } else if (currentSymbol == Multiplication) {
287        parsedSubTree = mulSy.CreateTreeNode();
288        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // left part
289        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // right part
290
291      } else if (currentSymbol == Log) {
292        parsedSubTree = logSy.CreateTreeNode();
293        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));
294
295      } else if (currentSymbol == Exp) {
296        parsedSubTree = expSy.CreateTreeNode();
297        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));
298
299      } else if (currentSymbol == Sin) {
300        parsedSubTree = sinSy.CreateTreeNode();
301        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));
302
303      } else if (currentSymbol == Inv) {
304        parsedSubTree = divSy.CreateTreeNode();
305        ConstantTreeNode dividend = (ConstantTreeNode)constSy.CreateTreeNode();
306        dividend.Value = 1.0;
307        parsedSubTree.AddSubtree(dividend);
308        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));
309
310      } else if (currentSymbol == Const) {
311        ConstantTreeNode constNode = (ConstantTreeNode)constSy.CreateTreeNode();
312        constNode.Value = 1.0;
313        parsedSubTree = constNode;
314
315      } else if (currentSymbol is VariableTerminalSymbol) {
316        VariableTreeNode varNode = (VariableTreeNode)varSy.CreateTreeNode();
317        varNode.Weight = 1.0;
318        varNode.VariableName = currentSymbol.StringRepresentation;
319        parsedSubTree = varNode;
320
321      } else if (currentSymbol is NonterminalSymbol) {
322        ConstantTreeNode constNode = (ConstantTreeNode)constSy.CreateTreeNode();
323        constNode.Value = 0.0;
324        parsedSubTree = constNode;
325      }
326
327      Debug.Assert(parsedSubTree != null);
328      return parsedSubTree;
329    }
330    #endregion
331
332    #region abstract DeepCloneable methods
333    public override IDeepCloneable Clone(Cloner cloner) {
334      return new Grammar(this, cloner);
335    }
336    #endregion
337  }
338}
Note: See TracBrowser for help on using the repository browser.