source: branches/2886_SymRegGrammarEnumeration/HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration/GrammarEnumeration/Grammar.cs @ 15746

Last change on this file since 15746 was 15746, checked in by lkammere, 20 months ago

#2886: Refactor grammar enumeration alg.

File size: 8.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration.GrammarEnumeration;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic;
9
10namespace HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration {
11  public class Grammar {
12
13    public Symbol StartSymbol;
14
15    #region Symbols
16
17    public VariableSymbol Var;
18
19    public NonterminalSymbol Expr;
20    public NonterminalSymbol Term;
21    public NonterminalSymbol Factor;
22
23    public TerminalSymbol Addition;
24    public TerminalSymbol Multiplication;
25
26    #endregion
27
28
29    #region HL Symbols for Parsing ExpressionTrees
30
31    private TypeCoherentExpressionGrammar symbolicExpressionGrammar;
32
33    private ISymbol constSy;
34    private ISymbol varSy;
35
36    private ISymbol addSy;
37    private ISymbol mulSy;
38    private ISymbol logSy;
39    private ISymbol expSy;
40    private ISymbol divSy;
41
42    private ISymbol rootSy;
43    private ISymbol startSy;
44
45    #endregion
46
47    public Grammar(string[] variables) {
48      #region Define Symbols
49      Var = new VariableSymbol("var", variables);
50
51      Expr = new NonterminalSymbol("Expr");
52      Term = new NonterminalSymbol("Term");
53      Factor = new NonterminalSymbol("Factor");
54
55      Addition = new TerminalSymbol("+");
56      Multiplication = new TerminalSymbol("*");
57      #endregion
58
59
60      #region Production rules
61      // order of production is important, since they are accessed via index
62      // in memoization.
63      StartSymbol = Expr;
64
65      Expr.AddProduction(Term, Expr, Addition);
66      Expr.AddProduction(Term);
67
68      Term.AddProduction(Factor, Term, Multiplication);
69      Term.AddProduction(Factor);
70
71      Factor.AddProduction(Var);
72      #endregion
73
74      #region Parsing to SymbolicExpressionTree
75      symbolicExpressionGrammar = new TypeCoherentExpressionGrammar();
76      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();
77
78      constSy = symbolicExpressionGrammar.Symbols.OfType<Constant>().First();
79      varSy = symbolicExpressionGrammar.Symbols.OfType<Variable>().First();
80      addSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Addition>().First();
81      mulSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Multiplication>().First();
82      logSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Logarithm>().First();
83      expSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Exponential>().First();
84      divSy = symbolicExpressionGrammar.AllowedSymbols.OfType<Division>().First();
85
86      rootSy = symbolicExpressionGrammar.AllowedSymbols.OfType<ProgramRootSymbol>().First();
87      startSy = symbolicExpressionGrammar.AllowedSymbols.OfType<StartSymbol>().First();
88
89      #endregion
90    }
91
92    #region Hashing
93    public int CalcHashCode(SymbolString sentence) {
94      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
95      // Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");
96
97      Stack<Symbol> parseStack = new Stack<Symbol>(sentence);
98
99      Symbol peek = parseStack.Peek();
100      int[] subtreeHashes = GetSubtreeHashes(parseStack);
101      return AggregateHashes(peek, subtreeHashes);
102    }
103
104    private int[] GetSubtreeHashes(Stack<Symbol> parseStack) {
105      Symbol currentSymbol = parseStack.Pop();
106
107      // VARIABLE
108      // if (Var.VariableTerminalSymbols.Contains(currentSymbol)) {
109      //   return currentSymbol.StringRepresentation.GetHashCode().ToEnumerable().ToArray();
110      // }
111
112      // MULTIPLICATION
113      if (ReferenceEquals(currentSymbol, Multiplication)) {
114        List<int> childHashes = new List<int>();
115
116        // First subtree
117        if (ReferenceEquals(parseStack.Peek(), Multiplication)) {
118          childHashes.AddRange(GetSubtreeHashes(parseStack));
119        } else {
120          childHashes.Add(AggregateHashes(parseStack.Peek(), GetSubtreeHashes(parseStack)));
121        }
122        // Second subtree
123        if (ReferenceEquals(parseStack.Peek(), Multiplication)) {
124          childHashes.AddRange(GetSubtreeHashes(parseStack));
125        } else {
126          childHashes.Add(AggregateHashes(parseStack.Peek(), GetSubtreeHashes(parseStack)));
127        }
128
129        // Sort due to commutativity
130        childHashes.Sort();
131        return childHashes.ToArray();
132      }
133
134      // ADDITION
135      if (ReferenceEquals(currentSymbol, Addition)) {
136        HashSet<int> uniqueChildHashes = new HashSet<int>();
137
138        // First subtree
139        if (ReferenceEquals(parseStack.Peek(), Addition)) {
140          uniqueChildHashes.UnionWith(GetSubtreeHashes(parseStack));
141        } else {
142          var peek = parseStack.Peek();
143          uniqueChildHashes.Add(AggregateHashes(peek, GetSubtreeHashes(parseStack)));
144        }
145        // Second subtree
146        if (ReferenceEquals(parseStack.Peek(), Addition)) {
147          uniqueChildHashes.UnionWith(GetSubtreeHashes(parseStack));
148        } else {
149          var peek = parseStack.Peek();
150          uniqueChildHashes.Add(AggregateHashes(peek, GetSubtreeHashes(parseStack)));
151        }
152
153        var result = uniqueChildHashes.ToList();
154        result.Sort();
155        return result.ToArray();
156      }
157
158      // var or nonterminal symbol
159      return currentSymbol.StringRepresentation.GetHashCode().ToEnumerable().ToArray();
160    }
161
162    private int AggregateHashes(Symbol operatorSym, IEnumerable<int> hashes) {
163      // If multiple subtrees are "merged" (e.g. added, multiplied, etc.), consider the executed operation
164      var hashesArray = hashes.ToArray();
165      int start = hashesArray.Length > 1 ? operatorSym.StringRepresentation.GetHashCode() : 0;
166      return hashesArray.Aggregate(start, (result, ti) => ((result << 5) + result) ^ ti.GetHashCode());
167    }
168    #endregion
169
170    #region Parse to SymbolicExpressionTree
171    public SymbolicExpressionTree ParseSymbolicExpressionTree(SymbolString sentence) {
172      Debug.Assert(sentence.Any(), "Trying to evaluate empty sentence!");
173      Debug.Assert(sentence.All(s => s is TerminalSymbol), "Trying to evaluate symbol sequence with nonterminalsymbols!");
174
175      symbolicExpressionGrammar.ConfigureAsDefaultRegressionGrammar();
176
177      var rootNode = rootSy.CreateTreeNode();
178      var startNode = startSy.CreateTreeNode();
179      rootNode.AddSubtree(startNode);
180
181      Stack<TerminalSymbol> parseStack = new Stack<TerminalSymbol>(sentence.OfType<TerminalSymbol>());
182      startNode.AddSubtree(ParseSymbolicExpressionTree(parseStack));
183
184      return new SymbolicExpressionTree(rootNode);
185    }
186
187    public ISymbolicExpressionTreeNode ParseSymbolicExpressionTree(Stack<TerminalSymbol> parseStack) {
188      TerminalSymbol currentSymbol = parseStack.Pop();
189
190      ISymbolicExpressionTreeNode parsedSubTree = null;
191
192      if (ReferenceEquals(currentSymbol, Addition)) {
193        parsedSubTree = addSy.CreateTreeNode();
194        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // left part
195        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // right part
196
197      } else if (ReferenceEquals(currentSymbol, Multiplication)) {
198        parsedSubTree = mulSy.CreateTreeNode();
199        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // left part
200        parsedSubTree.AddSubtree(ParseSymbolicExpressionTree(parseStack)); // right part
201
202      } else if (Var.VariableTerminalSymbols.Contains(currentSymbol)) {
203        VariableTreeNode varNode = (VariableTreeNode)varSy.CreateTreeNode();
204        varNode.Weight = 1.0;
205        varNode.VariableName = currentSymbol.StringRepresentation;
206        parsedSubTree = varNode;
207      }
208
209      Debug.Assert(parsedSubTree != null);
210      return parsedSubTree;
211    }
212    #endregion
213
214    #region Parse to Infix string
215
216    public SymbolString PostfixToInfixParser(SymbolString phrase) {
217      Stack<Symbol> parseStack = new Stack<Symbol>(phrase);
218
219      return PostfixToInfixSubtreeParser(parseStack);
220    }
221
222    private SymbolString PostfixToInfixSubtreeParser(Stack<Symbol> parseStack) {
223      Symbol head = parseStack.Pop();
224
225      SymbolString result = new SymbolString();
226
227      if (ReferenceEquals(head, Addition) || ReferenceEquals(head, Multiplication)) {
228        // right part
229        SymbolString rightPart = PostfixToInfixSubtreeParser(parseStack);
230        SymbolString leftPart = PostfixToInfixSubtreeParser(parseStack);
231
232        result.AddRange(leftPart);
233        result.Add(head);
234        result.AddRange(rightPart);
235      } else {
236        result.Add(head);
237      }
238      return result;
239    }
240
241    #endregion
242  }
243}
Note: See TracBrowser for help on using the repository browser.