source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs @ 10388

Last change on this file since 10388 was 10388, checked in by gkronber, 7 years ago

#2026 worked on code generator for brute force solver

File size: 15.7 KB
Line 
1using System.Collections.Generic;
2using System.Diagnostics;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Grammars;
6using Attribute = HeuristicLab.Grammars.Attribute;
7
8namespace CodeGenerator {
9  // code generator for problem class
10  public class ProblemCodeGen {
11    private const string usings = @"
12using System.Collections.Generic;
13using System.Linq;
14using System;
15using System.Text.RegularExpressions;
16";
17
18    private const string problemTemplate = @"
19namespace ?PROBLEMNAME? {
20  public sealed class ?IDENT?Problem {
21   
22   public ?IDENT?Problem() {
23      Initialize();
24    }   
25
26    private void Initialize() {
27      // the following is the source code from the INIT section of the problem definition
28#region INIT section
29?INITSOURCE?
30#endregion
31    }
32
33    private Tree _t;
34    public double Evaluate(Tree _t) {
35      this._t = _t;
36#region objective function (MINIMIZE / MAXIMIZE section)
37?FITNESSFUNCTION?
38#endregion
39    }
40    public bool IsBetter(double a, double b) {
41      return ?MAXIMIZATION? ? a > b : a < b;
42    }
43
44// additional code from the problem definition (CODE section)
45#region additional code
46?ADDITIONALCODE?
47#endregion
48
49#region generated source for interpretation
50?INTERPRETERSOURCE?
51#endregion
52
53#region generated code for the constraints for terminals
54?CONSTRAINTSSOURCE?
55#endregion
56  }
57
58#region class definitions for tree
59  public class Tree {
60    public int altIdx;
61    public Tree[] subtrees;
62    protected Tree() {
63      // leave subtrees uninitialized
64    }
65    public Tree(int altIdx, Tree[] subtrees) {
66      this.altIdx = altIdx;
67      this.subtrees = subtrees;
68    }
69  }
70
71  ?TERMINALNODECLASSDEFINITIONS?
72#endregion
73
74#region helper class for the grammar representation
75  public class Grammar {
76    public static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
77?TRANSITIONTABLE?
78    };
79    public static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
80       { -1, 0 }, // terminals
81?SUBTREECOUNTTABLE?
82    };
83    public static readonly string[] symb = new string[] { ?SYMBOLNAMES? };
84   
85  } 
86#endregion
87}";
88
89
90    /// <summary>
91    /// Generates the source code for a brute force searcher that can be compiled with a C# compiler
92    /// </summary>
93    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
94    public void Generate(GPDefNode ast) {
95      var problemSourceCode = new SourceBuilder();
96      problemSourceCode.AppendLine(usings);
97
98      GenerateProblemSource(ast, problemSourceCode);
99      GenerateSolvers(ast, problemSourceCode);
100
101      problemSourceCode
102        .Replace("?PROBLEMNAME?", ast.Name)
103        .Replace("?IDENT?", ast.Name);
104
105      // write the source file to disk
106      using (var stream = new StreamWriter(ast.Name + ".cs")) {
107        stream.WriteLine(problemSourceCode.ToString());
108      }
109    }
110
111    private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
112      var grammar = CreateGrammarFromAst(ast);
113      problemSourceCode
114        .AppendLine(problemTemplate)
115        .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
116        .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
117        .Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode)
118        .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
119        .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
120        .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
121        .Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()))
122        .Replace("?SYMBOLNAMES?", grammar.Symbols.Select(s => s.Name).Aggregate(string.Empty, (str, symb) => str + "\"" + symb + "\", "))
123        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
124        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
125       ;
126    }
127
128    private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
129      var grammar = CreateGrammarFromAst(ast);
130      // var randomSearchCodeGen = new RandomSearchCodeGen();
131      // randomSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
132      var bruteForceSearchCodeGen = new BruteForceCodeGen();
133      bruteForceSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
134    }
135
136    #region create grammar instance from AST
137    // should be refactored so that we can directly query the AST
138    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {
139
140      var nonTerminals = ast.NonTerminals
141        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
142        .ToArray();
143      var terminals = ast.Terminals
144        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
145        .ToArray();
146      string startSymbolName = ast.Rules.First().NtSymbol;
147
148      // create startSymbol
149      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
150      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);
151
152      // add all production rules
153      foreach (var rule in ast.Rules) {
154        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
155        foreach (var alt in GetAlternatives(rule.Alternatives, nonTerminals.Concat(terminals))) {
156          g.AddProductionRule(ntSymbol, alt);
157        }
158        // local initialization code
159        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
160      }
161      return g;
162    }
163
164    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
165      return (from fieldDef in Util.ExtractParameters(formalParameters)
166              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut)))
167              .ToList();
168    }
169
170    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
171      foreach (var alt in altNode.Alternatives) {
172        yield return GetSequence(alt.Sequence, allSymbols);
173      }
174    }
175
176    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
177      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
178      var l = new List<ISymbol>();
179      foreach (var node in sequence) {
180        var callSymbolNode = node as CallSymbolNode;
181        var actionNode = node as RuleActionNode;
182        if (callSymbolNode != null) {
183          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
184          // create a new symbol with actual parameters
185          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
186        } else if (actionNode != null) {
187          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
188        }
189      }
190      return new Sequence(l);
191    }
192    #endregion
193
194    #region helper methods for terminal symbols
195    // produces helper methods for the attributes of all terminal nodes
196    private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
197      var sb = new SourceBuilder();
198      var terminals = symbols.OfType<TerminalNode>();
199      foreach (var t in terminals) {
200        GenerateConstraintMethods(t, sb);
201      }
202      return sb.ToString();
203    }
204
205
206    // generates helper methods for the attributes of a given terminal node
207    private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
208      foreach (var c in t.Constraints) {
209        var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
210        if (c.Type == ConstraintNodeType.Range) {
211          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
212          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
213        } else if (c.Type == ConstraintNodeType.Set) {
214          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
215        }
216      }
217    }
218    #endregion
219
220    private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
221      var sb = new SourceBuilder();
222
223      foreach (var terminal in terminals) {
224        GenerateTerminalNodeClassDefinitions(terminal, sb);
225      }
226      return sb.ToString();
227    }
228
229    private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
230      sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
231      foreach (var att in terminal.FieldDefinitions) {
232        sb.AppendFormat("public {0} {1};", att.Type, att.Identifier).AppendLine();
233      }
234      sb.AppendFormat(" public {0}Tree() : base() {{ }}", terminal.Ident).AppendLine();
235      sb.AppendLine("}");
236    }
237
238    private string GenerateInterpreterSource(AttributedGrammar grammar) {
239      var sb = new SourceBuilder();
240      GenerateInterpreterStart(grammar, sb);
241
242      // generate methods for all nonterminals and terminals using the grammar instance
243      foreach (var s in grammar.NonTerminalSymbols) {
244        GenerateInterpreterMethod(grammar, s, sb);
245      }
246      foreach (var s in grammar.TerminalSymbols) {
247        GenerateTerminalInterpreterMethod(s, sb);
248      }
249      return sb.ToString();
250    }
251
252    private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
253      var s = grammar.StartSymbol;
254      // create the method which can be called from the fitness function
255      if (!s.Attributes.Any())
256        sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
257      else
258        sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();
259
260      // get formal parameters of start symbol
261      var attr = s.Attributes;
262
263      // actual parameter are the same as formalparameter only without type identifier
264      string actualParameter;
265      if (attr.Any())
266        actualParameter = attr.Skip(1).Aggregate(attr.First().AttributeType + " " + attr.First().Name, (str, a) => str + ", " + a.AttributeType + " " + a.Name);
267      else
268        actualParameter = string.Empty;
269      sb.AppendFormat("{0}(_t, {1});", s.Name, actualParameter).AppendLine();
270      sb.AppendLine("}").EndBlock();
271    }
272
273    private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
274      if (!s.Attributes.Any())
275        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
276      else
277        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();
278
279      // generate local definitions
280      sb.AppendLine(g.GetLocalDefinitions(s));
281
282      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();
283
284      if (altsWithSemActions.Length > 1) {
285        GenerateSwitchStatement(altsWithSemActions, sb);
286      } else {
287        int i = 0;
288        foreach (var altSymb in altsWithSemActions.Single()) {
289          GenerateSourceForAction(i, altSymb, sb);
290          if (!(altSymb is SemanticSymbol)) i++;
291        }
292      }
293      sb.Append("}").EndBlock();
294    }
295
296    private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb)
297    {
298      sb.Append("switch(_t.altIdx) {").BeginBlock();
299      // generate a case for each alternative
300      int altIdx = 0;
301      foreach (var alt in alts) {
302        sb.AppendFormat("case {0}: {{ ", altIdx).BeginBlock();
303
304        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
305        // a way to handle this is through grammar transformation (the examplary grammars all have the correct from)
306        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
307        foreach (var altSymb in alt) {
308          GenerateSourceForAction(0, altSymb, sb); // index is always 0 because of the assertion above
309        }
310        altIdx++;
311        sb.AppendLine("break;").Append("}").EndBlock();
312      }
313      sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
314    }
315
316    // helper for generating calls to other symbol methods
317    private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
318      var action = s as SemanticSymbol;
319      if (action != null)
320        sb.Append(action.Code + ";");
321      else if (!s.Attributes.Any())
322        sb.AppendFormat("{1}(_t.subtrees[{0}]);", idx, s.Name);
323      else sb.AppendFormat("{1}(_t.subtrees[{0}], {2}); ", idx, s.Name, s.GetAttributeString());
324      sb.AppendLine();
325    }
326
327    private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
328      // if the terminal symbol has attributes then we must samples values for these attributes
329      if (!s.Attributes.Any())
330        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
331      else
332        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();
333
334      // each field must match a formal parameter, assign a value for each parameter
335      int i = 0;
336      foreach (var element in s.Attributes) {
337        sb.AppendFormat("{0} = (_t as {1}Tree).{0};", element.Name, s.Name).AppendLine();
338      }
339      sb.Append("}").EndBlock();
340    }
341
342   
343
344    private string GenerateTransitionTable(IGrammar grammar) {
345      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
346      var sb = new SourceBuilder();
347
348      // state idx = idx of the corresponding symbol in the grammar
349      var allSymbols = grammar.Symbols.ToList();
350      foreach (var s in grammar.Symbols) {
351        var targetStates = new List<int>();
352        if (grammar.IsTerminal(s)) {
353        } else {
354          if (grammar.NumberOfAlternatives(s) > 1) {
355            foreach (var alt in grammar.GetAlternatives(s)) {
356              // only single-symbol alternatives are supported
357              Debug.Assert(alt.Count() == 1);
358              targetStates.Add(allSymbols.IndexOf(alt.Single()));
359            }
360          } else {
361            // rule is a sequence of symbols
362            var seq = grammar.GetAlternatives(s).Single();
363            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
364          }
365        }
366
367        var targetStateString = targetStates.Aggregate(string.Empty, (str, state) => str + state + ", ");
368
369        var idxOfSourceState = allSymbols.IndexOf(s);
370        sb.AppendFormat("// {0}", s).AppendLine();
371        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", idxOfSourceState, targetStateString).AppendLine();
372      }
373      return sb.ToString();
374    }
375    private string GenerateSubtreeCountTable(IGrammar grammar) {
376      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
377      var sb = new SourceBuilder();
378
379      // state idx = idx of the corresponding symbol in the grammar
380      var allSymbols = grammar.Symbols.ToList();
381      foreach (var s in grammar.Symbols) {
382        int subtreeCount = 0;
383        if (grammar.IsTerminal(s)) {
384        } else {
385          if (grammar.NumberOfAlternatives(s) > 1) {
386            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
387            subtreeCount = 1;
388          } else {
389            subtreeCount = grammar.GetAlternative(s, 0).Count();
390          }
391        }
392
393        sb.AppendFormat("// {0}", s).AppendLine();
394        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
395      }
396
397      return sb.ToString();
398    }
399
400  }
401}
Note: See TracBrowser for help on using the repository browser.