source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs @ 15806

Last change on this file since 15806 was 15806, checked in by gkronber, 20 months ago

#2886 made a few comments

File size: 13.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
9
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Grammar used for enumerating symbolic regression expressions.
  /// Productions are registered with the operator as the last symbol, i.e. sentences
  /// are stored in postfix notation. Besides the production rules the class offers
  /// semantic hashing of sentences, conversion of sentences into HeuristicLab
  /// symbolic expression trees and conversion of postfix phrases into infix phrases.
  /// </summary>
  public class Grammar {

    /// <summary>The symbol every derivation starts from (set to <see cref="Expr"/> in the constructor).</summary>
    public Symbol StartSymbol { get; }

    #region Symbols
    public VariableSymbol Var;

    public NonterminalSymbol Expr;
    public NonterminalSymbol Term;
    public NonterminalSymbol Factor;
    public NonterminalSymbol LogFactor;
    public NonterminalSymbol ExpFactor;
    public NonterminalSymbol SinFactor;

    public NonterminalSymbol SimpleExpr;
    public NonterminalSymbol SimpleTerm;

    public NonterminalSymbol InvExpr;
    public NonterminalSymbol InvTerm;

    public TerminalSymbol Addition;
    public TerminalSymbol Multiplication;
    public TerminalSymbol Log;
    public TerminalSymbol Exp;
    public TerminalSymbol Sin;
    public TerminalSymbol Inv;

    // For infix notation
    public TerminalSymbol OpeningBracket;
    public TerminalSymbol ClosingBracket;

    #endregion

    #region HL Symbols for Parsing ExpressionTrees
    // All of the following are assigned exactly once in the constructor.
    private readonly TypeCoherentExpressionGrammar symbolicExpressionGrammar;

    private readonly ISymbol constSy;
    private readonly ISymbol varSy;

    private readonly ISymbol addSy;
    private readonly ISymbol mulSy;
    private readonly ISymbol logSy;
    private readonly ISymbol expSy;
    private readonly ISymbol divSy;
    private readonly ISymbol sinSy;

    private readonly ISymbol rootSy;
    private readonly ISymbol startSy;
    #endregion

    /// <summary>
    /// Creates all grammar symbols, registers the production rules and looks up the
    /// HeuristicLab symbols that are needed to build symbolic expression trees.
    /// </summary>
    /// <param name="variables">Variable names handed to the <see cref="VariableSymbol"/> "var".</param>
    public Grammar(string[] variables) {
      #region Define Symbols
      Var = new VariableSymbol("var", variables);

      Expr = new NonterminalSymbol("Expr");
      Term = new NonterminalSymbol("Term");
      Factor = new NonterminalSymbol("Factor");
      LogFactor = new NonterminalSymbol("LogFactor");
      ExpFactor = new NonterminalSymbol("ExpFactor");
      SinFactor = new NonterminalSymbol("SinFactor");

      SimpleExpr = new NonterminalSymbol("SimpleExpr");
      SimpleTerm = new NonterminalSymbol("SimpleTerm");

      InvExpr = new NonterminalSymbol("InvExpr");
      InvTerm = new NonterminalSymbol("InvTerm");

      Addition = new TerminalSymbol("+");
      Multiplication = new TerminalSymbol("*");
      Log = new TerminalSymbol("log");
      Exp = new TerminalSymbol("exp");
      Sin = new TerminalSymbol("sin");
      Inv = new TerminalSymbol("inv");

      OpeningBracket = new TerminalSymbol("(");
      ClosingBracket = new TerminalSymbol(")");
      #endregion

      #region Production rules
      // Note: operators are always the LAST symbol of a production (postfix notation).
      StartSymbol = Expr;

      Expr.AddProduction(Term, Expr, Addition);
      Expr.AddProduction(Term);

      Term.AddProduction(Factor, Term, Multiplication);
      Term.AddProduction(Factor);
      Term.AddProduction(InvExpr, Inv);

      Factor.AddProduction(Var);
      Factor.AddProduction(LogFactor);
      Factor.AddProduction(ExpFactor);
      Factor.AddProduction(SinFactor);

      LogFactor.AddProduction(SimpleExpr, Log);
      ExpFactor.AddProduction(SimpleTerm, Exp);
      SinFactor.AddProduction(SimpleExpr, Sin);

      SimpleExpr.AddProduction(SimpleTerm, SimpleExpr, Addition);
      SimpleExpr.AddProduction(SimpleTerm);

      SimpleTerm.AddProduction(Var, SimpleTerm, Multiplication);
      SimpleTerm.AddProduction(Var);

      InvExpr.AddProduction(InvTerm, InvExpr, Addition);
      InvExpr.AddProduction(InvTerm);

      InvTerm.AddProduction(Factor, InvTerm, Multiplication);
      InvTerm.AddProduction(Factor);
      #endregion

      #region Parsing to SymbolicExpressionTree
      symbolicExpressionGrammar = new TypeCoherentExpressionGrammar();
      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();

      constSy = symbolicExpressionGrammar.Symbols.OfType<Constant>().First();
      varSy = symbolicExpressionGrammar.Symbols.OfType<Variable>().First();
      addSy = symbolicExpressionGrammar.Symbols.OfType<Addition>().First();
      mulSy = symbolicExpressionGrammar.Symbols.OfType<Multiplication>().First();
      logSy = symbolicExpressionGrammar.Symbols.OfType<Logarithm>().First();
      expSy = symbolicExpressionGrammar.Symbols.OfType<Exponential>().First();
      divSy = symbolicExpressionGrammar.Symbols.OfType<Division>().First();
      sinSy = symbolicExpressionGrammar.Symbols.OfType<Sine>().First();

      rootSy = symbolicExpressionGrammar.Symbols.OfType<ProgramRootSymbol>().First();
      startSy = symbolicExpressionGrammar.Symbols.OfType<StartSymbol>().First();

      #endregion
    }

    #region Hashing
    /// <summary>
    /// Calculates a hash code for a postfix sentence in which nested additions and
    /// multiplications are flattened, operands are sorted, duplicate summands are
    /// dropped and factors are cancelled against their inverse.
    /// </summary>
    /// <param name="sentence">Non-empty sentence in postfix notation.</param>
    /// <returns>The integer hash of the sentence.</returns>
    public int CalcHashCode(SymbolString sentence) {
      return CalcHashCode<int>(sentence, AggregateIntHashes);
    }

    /// <summary>
    /// Generic hashing entry point: hashes the whole sentence with the given
    /// aggregation function and returns <c>GetHashCode()</c> of the aggregated value.
    /// </summary>
    private int CalcHashCode<THashType>(SymbolString sentence, Func<Symbol, IEnumerable<THashType>, THashType> aggregateFunction) {
      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");

      // The stack is filled in sentence order, so the LAST symbol (the root operator
      // of a postfix sentence) ends up on top.
      Stack<Symbol> parseStack = new Stack<Symbol>(sentence);

      // Remember the root before GetSubtreeHashes pops it.
      Symbol root = parseStack.Peek();
      return aggregateFunction(root, GetSubtreeHashes(parseStack, aggregateFunction)).GetHashCode();
    }

    /// <summary>
    /// Pops the subtree whose root is on top of <paramref name="parseStack"/> and
    /// returns the hashes of its (flattened) operands.
    /// </summary>
    /// <returns>
    /// For addition: the sorted, distinct hashes of all summands.
    /// For multiplication: the sorted hashes of all factors that were not cancelled.
    /// For log/exp/sin/inv: a single hash for the argument.
    /// For every other symbol: a single hash for the symbol itself.
    /// </returns>
    private THashType[] GetSubtreeHashes<THashType>(Stack<Symbol> parseStack, Func<Symbol, IEnumerable<THashType>, THashType> aggregateHashes) {
      Symbol currentSymbol = parseStack.Pop();

      // ADDITION
      if (ReferenceEquals(currentSymbol, Addition)) {
        // A set is used so that identical summands are only counted once.
        var uniqueChildHashes = new HashSet<THashType>();
        uniqueChildHashes.UnionWith(GetOperandHashes(parseStack, Addition, aggregateHashes)); // first subtree
        uniqueChildHashes.UnionWith(GetOperandHashes(parseStack, Addition, aggregateHashes)); // second subtree

        // Sort due to commutativity
        var result = uniqueChildHashes.ToArray();
        Array.Sort(result);
        return result;
      }

      // MULTIPLICATION
      if (ReferenceEquals(currentSymbol, Multiplication)) {
        var childHashes = new List<THashType>();
        childHashes.AddRange(GetOperandHashes(parseStack, Multiplication, aggregateHashes)); // first subtree
        childHashes.AddRange(GetOperandHashes(parseStack, Multiplication, aggregateHashes)); // second subtree

        // Sort due to commutativity
        childHashes.Sort();

        return CancelInverseFactors(childHashes, aggregateHashes);
      }

      // LOG, EXP, SIN, INV
      if (ReferenceEquals(currentSymbol, Log) || ReferenceEquals(currentSymbol, Exp) ||
          ReferenceEquals(currentSymbol, Sin) || ReferenceEquals(currentSymbol, Inv)) {
        // Peek is evaluated before the recursive call pops the argument's root.
        return new[] { aggregateHashes(parseStack.Peek(), GetSubtreeHashes(parseStack, aggregateHashes)) };
      }

      // var or nonterminal symbol
      return new[] { aggregateHashes(currentSymbol, Enumerable.Empty<THashType>()) };
    }

    /// <summary>
    /// Pops one operand of the associative operator <paramref name="operatorSymbol"/>.
    /// If the operand is itself rooted at the same operator, its operand hashes are
    /// returned unchanged (flattening); otherwise the operand is aggregated into a
    /// single hash.
    /// </summary>
    private IEnumerable<THashType> GetOperandHashes<THashType>(Stack<Symbol> parseStack, Symbol operatorSymbol, Func<Symbol, IEnumerable<THashType>, THashType> aggregateHashes) {
      // The operand's root has to be read before GetSubtreeHashes removes it.
      Symbol operandRoot = parseStack.Peek();
      THashType[] operandHashes = GetSubtreeHashes(parseStack, aggregateHashes);

      if (ReferenceEquals(operandRoot, operatorSymbol)) {
        return operandHashes;
      }
      return new[] { aggregateHashes(operandRoot, operandHashes) };
    }

    /// <summary>
    /// Removes pairs of a factor and its inverse from the sorted factor list.
    /// At least one factor is always kept, because without constants a product
    /// cannot be cancelled out completely.
    /// </summary>
    /// <param name="sortedFactorHashes">Sorted hashes of all factors of a product.</param>
    /// <param name="aggregateHashes">Aggregation function, used to compute the hash of inv(factor).</param>
    /// <returns>The hashes of the remaining factors in their original (sorted) order.</returns>
    private THashType[] CancelInverseFactors<THashType>(List<THashType> sortedFactorHashes, Func<Symbol, IEnumerable<THashType>, THashType> aggregateHashes) {
      int factorCount = sortedFactorHashes.Count;
      bool[] isFactorRemaining = Enumerable.Repeat(true, factorCount).ToArray();

      // Number of entries in isFactorRemaining that are still true. (Counting the
      // array elements themselves would always yield factorCount.)
      int remainingCount = factorCount;
      var hashComparer = EqualityComparer<THashType>.Default;

      for (int i = 0; i < factorCount; i++) {
        if (!isFactorRemaining[i]) continue;
        if (remainingCount <= 2) break; // Until we have constants, we can't cancel out all terms.

        var currFactor = sortedFactorHashes[i];
        var invFactor = aggregateHashes(Inv, new[] { currFactor });

        // Look for an inverse that has not been cancelled yet; with duplicate
        // factors the first occurrence might already be used up.
        int indexOfInv = -1;
        for (int j = 0; j < factorCount; j++) {
          if (j != i && isFactorRemaining[j] && hashComparer.Equals(sortedFactorHashes[j], invFactor)) {
            indexOfInv = j;
            break;
          }
        }

        if (indexOfInv >= 0) {
          isFactorRemaining[i] = isFactorRemaining[indexOfInv] = false;
          remainingCount -= 2;
        }
      }

      return Enumerable
        .Range(0, factorCount)
        .Where(i => isFactorRemaining[i])
        .Select(i => sortedFactorHashes[i])
        .ToArray();
    }

    /// <summary>
    /// String based aggregation function (human readable counterpart of
    /// <see cref="AggregateIntHashes"/>).
    /// </summary>
    /// <param name="operatorSym">Root symbol of the subtree.</param>
    /// <param name="hashes">String hashes of the operands of <paramref name="operatorSym"/>.</param>
    /// <returns>
    /// The single operand for an addition/multiplication with at most one operand,
    /// the symbol's string representation for nonterminals and variables, and
    /// "[op ° h1 ° h2 ...]" otherwise.
    /// </returns>
    private string AggregateStringHashes(Symbol operatorSym, IEnumerable<string> hashes) {
      var hashesArray = hashes.ToArray();

      // An addition/multiplication that collapsed to one operand is just that operand.
      if ((ReferenceEquals(operatorSym, Addition) || ReferenceEquals(operatorSym, Multiplication)) && hashesArray.Length <= 1) {
        return hashesArray[0];
      }
      if (operatorSym is NonterminalSymbol || Var.VariableTerminalSymbols.Contains(operatorSym)) {
        return operatorSym.StringRepresentation;
      }

      var parts = new[] { operatorSym.StringRepresentation }.Concat(hashesArray);
      return $"[{string.Join(" ° ", parts)}]";
    }

    /// <summary>
    /// Integer aggregation function used by <see cref="CalcHashCode(SymbolString)"/>.
    /// </summary>
    /// <param name="operatorSym">Root symbol of the subtree.</param>
    /// <param name="hashes">Hashes of the operands of <paramref name="operatorSym"/>.</param>
    /// <returns>
    /// For nonterminals and variables the hash of the symbol's string representation;
    /// otherwise the operand hashes folded into a start value. The start value is 0
    /// for an addition/multiplication with at most one operand (so the result equals
    /// the operand's hash) and the hash of the operator's string representation
    /// in all other cases.
    /// </returns>
    private int AggregateIntHashes(Symbol operatorSym, IEnumerable<int> hashes) {
      var hashesArray = hashes.ToArray();

      int start;
      if ((ReferenceEquals(operatorSym, Addition) || ReferenceEquals(operatorSym, Multiplication)) &&
          hashesArray.Length <= 1) {
        start = 0;

      } else if (operatorSym is NonterminalSymbol || Var.VariableTerminalSymbols.Contains(operatorSym)) {
        return operatorSym.StringRepresentation.GetHashCode();

      } else {
        start = operatorSym.StringRepresentation.GetHashCode();
      }

      // Overflow is intended for hash mixing; make that explicit so that a build
      // with overflow checking enabled behaves the same.
      return hashesArray.Aggregate(start, (result, ti) => unchecked(((result << 5) + result) ^ ti.GetHashCode()));
    }
    #endregion

    #region Parse to SymbolicExpressionTree
    /// <summary>
    /// Converts a postfix sentence consisting only of terminal symbols into a
    /// symbolic expression tree (program root node -> start node -> expression).
    /// </summary>
    /// <param name="sentence">Non-empty postfix sentence without nonterminal symbols.</param>
    /// <returns>The symbolic expression tree for the sentence.</returns>
    public SymbolicExpressionTree ParseSymbolicExpressionTree(SymbolString sentence) {
      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");

      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();     // TODO: not necessary to call this for each sentence

      var rootNode = rootSy.CreateTreeNode();
      var startNode = startSy.CreateTreeNode();
      rootNode.AddSubtree(startNode);

      // Last symbol of the sentence (the root operator) ends up on top of the stack.
      Stack<TerminalSymbol> parseStack = new Stack<TerminalSymbol>(sentence.OfType<TerminalSymbol>());
      startNode.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      return new SymbolicExpressionTree(rootNode);
    }

    /// <summary>
    /// Pops the subtree whose root is on top of <paramref name="parseStack"/> and
    /// converts it into a tree node. "inv" is represented as a division 1.0 / x.
    /// </summary>
    /// <param name="parseStack">Stack of terminal symbols; the subtree's root must be on top.</param>
    /// <returns>
    /// The node for the parsed subtree, or <c>null</c> if the symbol is not handled
    /// (checked with a debug assertion only).
    /// </returns>
    public ISymbolicExpressionTreeNode ParseSymbolicExpressionTree(Stack<TerminalSymbol> parseStack) {
      TerminalSymbol currentSymbol = parseStack.Pop();

      ISymbolicExpressionTreeNode parsedSubTree = null;

      if (ReferenceEquals(currentSymbol, Addition)) {
        // The operand that is popped first is the one written directly before the
        // operator, i.e. the right-hand operand of the postfix sentence.
        parsedSubTree = addSy.CreateTreeNode();
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // operand popped first
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // operand popped second

      } else if (ReferenceEquals(currentSymbol, Multiplication)) {
        parsedSubTree = mulSy.CreateTreeNode();
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // operand popped first
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // operand popped second

      } else if (ReferenceEquals(currentSymbol, Log)) {
        parsedSubTree = logSy.CreateTreeNode();
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      } else if (ReferenceEquals(currentSymbol, Exp)) {
        parsedSubTree = expSy.CreateTreeNode();
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      } else if (ReferenceEquals(currentSymbol, Sin)) {
        parsedSubTree = sinSy.CreateTreeNode();
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      } else if (ReferenceEquals(currentSymbol, Inv)) {
        // inv(x) is built as 1.0 / x
        parsedSubTree = divSy.CreateTreeNode();
        ConstantTreeNode dividend = (ConstantTreeNode)constSy.CreateTreeNode();
        dividend.Value = 1.0;
        parsedSubTree.AddSubtree(dividend);
        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      } else if (Var.VariableTerminalSymbols.Contains(currentSymbol)) {
        VariableTreeNode varNode = (VariableTreeNode)varSy.CreateTreeNode();
        varNode.Weight = 1.0;
        varNode.VariableName = currentSymbol.StringRepresentation;
        parsedSubTree = varNode;
      }

      // Any other terminal (e.g. the bracket symbols) is not supported here.
      Debug.Assert(parsedSubTree != null);
      return parsedSubTree;
    }
    #endregion

    #region Parse to Infix string

    /// <summary>
    /// Converts a phrase from postfix into infix notation. Arguments of
    /// log/exp/sin/inv are enclosed in brackets; binary operators are placed
    /// between their operands without additional brackets.
    /// </summary>
    /// <param name="phrase">Phrase in postfix notation.</param>
    /// <returns>A new symbol string holding the phrase in infix notation.</returns>
    public SymbolString PostfixToInfixParser(SymbolString phrase) {
      Stack<Symbol> parseStack = new Stack<Symbol>(phrase);

      return PostfixToInfixSubtreeParser(parseStack);
    }

    /// <summary>
    /// Pops the subtree whose root is on top of <paramref name="parseStack"/> and
    /// returns it in infix notation.
    /// </summary>
    private SymbolString PostfixToInfixSubtreeParser(Stack<Symbol> parseStack) {
      Symbol head = parseStack.Pop();

      SymbolString result = new SymbolString();

      if (ReferenceEquals(head, Addition) || ReferenceEquals(head, Multiplication)) {
        // The stack yields the right operand first, the left operand second.
        SymbolString rightPart = PostfixToInfixSubtreeParser(parseStack);
        SymbolString leftPart = PostfixToInfixSubtreeParser(parseStack);

        result.AddRange(leftPart);
        result.Add(head);
        result.AddRange(rightPart);

      } else if (ReferenceEquals(head, Log) || ReferenceEquals(head, Exp)
              || ReferenceEquals(head, Sin) || ReferenceEquals(head, Inv)) {
        // Unary function: f ( argument )
        result.Add(head);
        result.Add(OpeningBracket);
        result.AddRange(PostfixToInfixSubtreeParser(parseStack));
        result.Add(ClosingBracket);

      } else {
        // Variable or nonterminal: emitted as is.
        result.Add(head);
      }
      return result;
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.