Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs @ 15725

Last change on this file since 15725 was 15725, checked in by lkammere, 6 years ago

#2886: Refactor tree hash function.

File size: 10.7 KB
RevLine 
[15725]1using System;
[15714]2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Common;
[15722]7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
[15714]9
namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
  /// <summary>
  /// Grammar used to enumerate symbolic-regression sentences. Productions are written in
  /// postfix order:
  ///   Expr   -> Term Expr "+" | Term
  ///   Term   -> Factor Term "*" | Factor
  ///   Factor -> var
  /// Besides the production rules, the class hashes fully-derived (terminal-only) sentences,
  /// converts them to HeuristicLab symbolic expression trees and converts postfix to infix.
  /// </summary>
  public class Grammar {

    public Symbol StartSymbol;

    #region Symbols

    public VariableSymbol Var;

    public NonterminalSymbol Expr;
    public NonterminalSymbol Term;
    public NonterminalSymbol Factor;

    public TerminalSymbol Addition;
    public TerminalSymbol Multiplication;

    #endregion


    #region HL Symbols for Parsing ExpressionTrees

    // Configured once in the constructor and never modified afterwards.
    private readonly TypeCoherentExpressionGrammar symbolicExpressionGrammar;

    private readonly ISymbol constSy;
    private readonly ISymbol varSy;

    private readonly ISymbol addSy;
    private readonly ISymbol mulSy;
    private readonly ISymbol logSy;
    private readonly ISymbol expSy;
    private readonly ISymbol divSy;

    private readonly ISymbol rootSy;
    private readonly ISymbol startSy;

    #endregion

    /// <summary>
    /// Creates the grammar symbols and production rules. It also looks up the HeuristicLab
    /// symbols used when converting sentences to symbolic expression trees.
    /// </summary>
    /// <param name="variables">Variable names, passed to the <see cref="VariableSymbol"/> "var".</param>
    public Grammar(string[] variables) {
      #region Define Symbols
      Var = new VariableSymbol("var", variables);

      Expr = new NonterminalSymbol("Expr");
      Term = new NonterminalSymbol("Term");
      Factor = new NonterminalSymbol("Factor");

      Addition = new TerminalSymbol("+");
      Multiplication = new TerminalSymbol("*");
      #endregion


      #region Production rules
      // The order of productions matters: alternatives may be accessed by index
      // (e.g. for memoization), so do not reorder these calls.
      StartSymbol = Expr;

      Expr.AddProduction(Term, Expr, Addition);
      Expr.AddProduction(Term);

      Term.AddProduction(Factor, Term, Multiplication);
      Term.AddProduction(Factor);

      Factor.AddProduction(Var);
      #endregion

      #region Parsing to SymbolicExpressionTree
      symbolicExpressionGrammar = new TypeCoherentExpressionGrammar();
      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();

      constSy = symbolicExpressionGrammar.Symbols.OfType<Constant>().First();
      varSy = symbolicExpressionGrammar.Symbols.OfType<Variable>().First();
      addSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Addition>().First();
      mulSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Multiplication>().First();
      logSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Logarithm>().First();
      expSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Exponential>().First();
      divSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Division>().First();

      rootSy = symbolicExpressionGrammar.AllowedSymbols.OfType<ProgramRootSymbol>().First();
      startSy = symbolicExpressionGrammar.AllowedSymbols.OfType<StartSymbol>().First();

      #endregion
    }

    #region Hashing
    /// <summary>
    /// Computes a structural hash of a fully-derived sentence in postfix notation.
    /// Nested "*" and "+" operands are flattened and sorted, so commutative and associative
    /// rearrangements hash equally. Duplicate summands are collapsed, so "x + x" hashes like "x".
    /// NOTE: variable hashes come from string.GetHashCode, which is only stable within one
    /// process; do not persist these values.
    /// </summary>
    /// <param name="sentence">Non-empty sentence consisting of terminal symbols only.</param>
    /// <returns>The structural hash of <paramref name="sentence"/>.</returns>
    /// <exception cref="ArgumentException">The sentence contains an unrecognized terminal.</exception>
    public int CalcHashCode(SymbolString sentence) {
      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");

      // Stack(IEnumerable) pushes in sequence order, so the postfix root (last symbol) is on top.
      Stack<TerminalSymbol> parseStack = new Stack<TerminalSymbol>(sentence.OfType<TerminalSymbol>());

      TerminalSymbol root = parseStack.Peek();
      int[] subtreeHashes = GetSubtreeHashes(parseStack);
      return AggregateHashes(root, subtreeHashes);
    }

    // Pops the subtree rooted at the top of parseStack and returns its operand hashes:
    // a single hash for a variable, or the flattened and sorted operand hashes for "*" / "+".
    private int[] GetSubtreeHashes(Stack<TerminalSymbol> parseStack) {
      TerminalSymbol currentSymbol = parseStack.Pop();

      // VARIABLE
      if (Var.VariableTerminalSymbols.Contains(currentSymbol)) {
        return new[] { currentSymbol.StringRepresentation.GetHashCode() };
      }

      // MULTIPLICATION: repeated factors are kept (x * x must differ from x).
      if (ReferenceEquals(currentSymbol, Multiplication)) {
        List<int> childHashes = new List<int>();
        childHashes.AddRange(GetOperandHashes(parseStack, Multiplication)); // first subtree
        childHashes.AddRange(GetOperandHashes(parseStack, Multiplication)); // second subtree

        // Sort due to commutativity
        childHashes.Sort();
        return childHashes.ToArray();
      }

      // ADDITION: repeated summands are collapsed, so x + x hashes like x.
      if (ReferenceEquals(currentSymbol, Addition)) {
        HashSet<int> uniqueChildHashes = new HashSet<int>();
        uniqueChildHashes.UnionWith(GetOperandHashes(parseStack, Addition)); // first subtree
        uniqueChildHashes.UnionWith(GetOperandHashes(parseStack, Addition)); // second subtree

        // Sort due to commutativity
        List<int> result = uniqueChildHashes.ToList();
        result.Sort();
        return result.ToArray();
      }

      throw new ArgumentException("Trying to hash malformed sentence!");
    }

    // Hashes one operand of `operation`. An operand that is the same operation is flattened
    // into the parent (associativity); any other operand is folded into a single hash.
    private int[] GetOperandHashes(Stack<TerminalSymbol> parseStack, TerminalSymbol operation) {
      TerminalSymbol operandRoot = parseStack.Peek();
      if (ReferenceEquals(operandRoot, operation)) {
        return GetSubtreeHashes(parseStack);
      }
      return new[] { AggregateHashes(operandRoot, GetSubtreeHashes(parseStack)) };
    }

    // Folds operand hashes into one value. A single hash passes through unchanged. When
    // several subtrees are merged, the operator's hash seeds the fold so that e.g. "+" and "*"
    // over the same operands hash differently.
    private int AggregateHashes(TerminalSymbol rule, IEnumerable<int> hashes) {
      int[] hashesArray = hashes.ToArray();
      int start = hashesArray.Length > 1 ? rule.StringRepresentation.GetHashCode() : 0;
      // The multiply-by-33-and-xor fold is expected to wrap around; keep it unchecked even
      // in builds with overflow checking enabled.
      return hashesArray.Aggregate(start, (result, ti) => unchecked(((result << 5) + result) ^ ti));
    }
    #endregion

    #region Parse to SymbolicExpressionTree
    /// <summary>
    /// Converts a fully-derived postfix sentence into a HeuristicLab symbolic expression tree
    /// (ProgramRoot -> Start -> expression). Variables get weight 1.0.
    /// </summary>
    /// <param name="sentence">Non-empty sentence consisting of terminal symbols only.</param>
    /// <returns>The parsed tree.</returns>
    /// <exception cref="ArgumentException">The sentence contains an unrecognized terminal.</exception>
    public SymbolicExpressionTree ParseSymbolicExpressionTree(SymbolString sentence) {
      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");

      // The HL grammar is configured once in the constructor and never changed afterwards,
      // so there is no need to re-configure it per parse.
      var rootNode = rootSy.CreateTreeNode();
      var startNode = startSy.CreateTreeNode();
      rootNode.AddSubtree(startNode);

      // Stack(IEnumerable) pushes in sequence order, so the postfix root (last symbol) is on top.
      Stack<TerminalSymbol> parseStack = new Stack<TerminalSymbol>(sentence.OfType<TerminalSymbol>());
      startNode.AddSubtree(ParseSymbolicExpressionTree(parseStack));

      return new SymbolicExpressionTree(rootNode);
    }

    /// <summary>
    /// Pops the postfix subtree on top of <paramref name="parseStack"/> and converts it into a
    /// symbolic expression tree node.
    /// </summary>
    /// <param name="parseStack">Terminal symbols with the subtree root on top.</param>
    /// <returns>The parsed subtree.</returns>
    /// <exception cref="ArgumentException">An unrecognized terminal is encountered.</exception>
    public ISymbolicExpressionTreeNode ParseSymbolicExpressionTree(Stack<TerminalSymbol> parseStack) {
      TerminalSymbol currentSymbol = parseStack.Pop();

      if (ReferenceEquals(currentSymbol, Addition)) {
        return ParseBinaryNode(addSy, parseStack);
      }

      if (ReferenceEquals(currentSymbol, Multiplication)) {
        return ParseBinaryNode(mulSy, parseStack);
      }

      if (Var.VariableTerminalSymbols.Contains(currentSymbol)) {
        VariableTreeNode varNode = (VariableTreeNode)varSy.CreateTreeNode();
        varNode.Weight = 1.0;
        varNode.VariableName = currentSymbol.StringRepresentation;
        return varNode;
      }

      // Previously this returned null in release builds, which was silently added as a subtree.
      throw new ArgumentException("Trying to parse malformed sentence!");
    }

    // Builds a binary operator node. In postfix notation the right operand sits on top of the
    // stack, so it is parsed first; children are then added left-to-right to mirror the infix form.
    private ISymbolicExpressionTreeNode ParseBinaryNode(ISymbol operatorSymbol, Stack<TerminalSymbol> parseStack) {
      ISymbolicExpressionTreeNode node = operatorSymbol.CreateTreeNode();
      ISymbolicExpressionTreeNode rightPart = ParseSymbolicExpressionTree(parseStack);
      ISymbolicExpressionTreeNode leftPart = ParseSymbolicExpressionTree(parseStack);
      node.AddSubtree(leftPart);
      node.AddSubtree(rightPart);
      return node;
    }
    #endregion

    #region Parse to Infix string

    /// <summary>
    /// Converts a postfix phrase to infix order, e.g. "a b +" becomes "a + b".
    /// </summary>
    /// <param name="phrase">Phrase in postfix notation.</param>
    /// <returns>The same symbols in infix order.</returns>
    public SymbolString PostfixToInfixParser(SymbolString phrase) {
      // Stack(IEnumerable) pushes in sequence order, so the postfix root (last symbol) is on top.
      Stack<Symbol> parseStack = new Stack<Symbol>(phrase);

      return PostfixToInfixSubtreeParser(parseStack);
    }

    // Pops one postfix subtree and returns it in infix order. Symbols other than "+" and "*"
    // (including nonterminals) are treated as leaves.
    private SymbolString PostfixToInfixSubtreeParser(Stack<Symbol> parseStack) {
      Symbol head = parseStack.Pop();

      SymbolString result = new SymbolString();

      if (ReferenceEquals(head, Addition) || ReferenceEquals(head, Multiplication)) {
        // The right operand is on top of the stack in postfix notation, so parse it first.
        SymbolString rightPart = PostfixToInfixSubtreeParser(parseStack);
        SymbolString leftPart = PostfixToInfixSubtreeParser(parseStack);

        result.AddRange(leftPart);
        result.Add(head);
        result.AddRange(rightPart);
      } else {
        result.Add(head);
      }
      return result;
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.