Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs @ 10393

Last change on this file since 10393 was 10393, checked in by gkronber, 10 years ago

#2026 fixed a bug in the brute force solver

File size: 16.3 KB
Line 
1using System.Collections.Generic;
2using System.Diagnostics;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Grammars;
6using Attribute = HeuristicLab.Grammars.Attribute;
7
namespace CodeGenerator {
  // Code generator for the problem class.
  // Produces a single C# source file for a GPDL problem definition. The file contains:
  // the problem class (INIT code, objective function, interpreter methods, terminal constraint helpers),
  // the Tree / Grammar helper classes, and the source emitted by the brute force solver generator.
  public class ProblemCodeGen {
    // using directives placed at the top of the generated source file
    private const string usings = @"
using System.Collections.Generic;
using System.Linq;
using System;
using System.Text.RegularExpressions;
using System.Diagnostics;
";

    // Template for the generated problem, tree and grammar classes.
    // The ?XYZ? markers are placeholders that are replaced in GenerateProblemSource (most of them)
    // and in Generate (?PROBLEMNAME? and ?IDENT?).
    private const string problemTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Problem {
   
   public ?IDENT?Problem() {
      Initialize();
    }   

    private void Initialize() {
      // the following is the source code from the INIT section of the problem definition
#region INIT section
?INITSOURCE?
#endregion
    }

    private Tree _t;
    public double Evaluate(Tree _t) {
      this._t = _t;
#region objective function (MINIMIZE / MAXIMIZE section)
?FITNESSFUNCTION?
#endregion
    }
    public bool IsBetter(double a, double b) {
      return ?MAXIMIZATION? ? a > b : a < b;
    }

// additional code from the problem definition (CODE section)
#region additional code
?ADDITIONALCODE?
#endregion

#region generated source for interpretation
?INTERPRETERSOURCE?
#endregion

#region generated code for the constraints for terminals
?CONSTRAINTSSOURCE?
#endregion
  }

#region class definitions for tree
  public class Tree {
    public int altIdx;
    public Tree[] subtrees;
    protected Tree() {
      // leave subtrees uninitialized
    }
    public Tree(int altIdx, Tree[] subtrees = null) {
      this.altIdx = altIdx;
      this.subtrees = subtrees;
    }
    public int GetSize() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Sum(t=>t.GetSize());
    }
    public int GetDepth() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Max(t=>t.GetDepth());
    }
    public void PrintTree(int curState) {
      Console.Write(""{0} "", Grammar.symb[curState]);
      if(subtrees != null) {
        if(subtrees.Length==1) {
          subtrees[0].PrintTree(Grammar.transition[curState][altIdx]);
        } else {
          for(int i=0;i<subtrees.Length;i++) {
            subtrees[i].PrintTree(Grammar.transition[curState][i]);
          }
        }
      }
    }
  }

  ?TERMINALNODECLASSDEFINITIONS?
#endregion

#region helper class for the grammar representation
  public class Grammar {
    public static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
?TRANSITIONTABLE?
    };
    public static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
       { -1, 0 }, // terminals
?SUBTREECOUNTTABLE?
    };
    public static readonly string[] symb = new string[] { ?SYMBOLNAMES? };
   
  } 
#endregion
}";


    /// <summary>
    /// Generates the source code for a brute force searcher that can be compiled with a C# compiler
    /// and writes it to the file "&lt;problem name&gt;.cs" (path relative to the current working directory).
    /// </summary>
    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    /// <exception cref="System.ArgumentNullException">If <paramref name="ast"/> is null.</exception>
    public void Generate(GPDefNode ast) {
      if (ast == null) throw new System.ArgumentNullException("ast");

      var problemSourceCode = new SourceBuilder();
      problemSourceCode.AppendLine(usings);

      GenerateProblemSource(ast, problemSourceCode);
      GenerateSolvers(ast, problemSourceCode);

      // the problem name is used as the namespace and as the prefix of the generated problem class
      problemSourceCode
        .Replace("?PROBLEMNAME?", ast.Name)
        .Replace("?IDENT?", ast.Name);

      // write the source file to disk
      using (var stream = new StreamWriter(ast.Name + ".cs")) {
        stream.WriteLine(problemSourceCode.ToString());
      }
    }

    // Appends the problem template to problemSourceCode and fills all placeholders
    // except ?PROBLEMNAME? and ?IDENT? (those are replaced in Generate).
    private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      problemSourceCode
        .AppendLine(problemTemplate)
        .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
        .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
        .Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode)
        .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
        .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
        .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
        .Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()))
        // symbol names as C# string literals; the index of a name is the state index used in the tables
        .Replace("?SYMBOLNAMES?", string.Join(", ", grammar.Symbols.Select(s => "\"" + s.Name + "\"")))
        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
       ;
    }

    // Appends the source code of the solver(s) to solverSourceCode.
    private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      // alternative solver, currently disabled:
      //var randomSearchCodeGen = new RandomSearchCodeGen();
      //randomSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
      var bruteForceSearchCodeGen = new BruteForceCodeGen();
      bruteForceSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
    }

    #region create grammar instance from AST
    // should be refactored so that we can directly query the AST
    // Builds an attributed grammar from the AST. The nonterminal on the left-hand side
    // of the first rule becomes the start symbol.
    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {
      var nonTerminals = ast.NonTerminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      var terminals = ast.Terminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      // all declared symbols (loop invariant, used to check the symbols referenced in rules)
      var allSymbols = nonTerminals.Concat(terminals).ToArray();

      if (!ast.Rules.Any())
        throw new System.InvalidOperationException("The problem definition does not contain any production rules.");
      string startSymbolName = ast.Rules.First().NtSymbol;

      // create startSymbol
      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);

      // add all production rules
      foreach (var rule in ast.Rules) {
        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
        foreach (var alt in GetAlternatives(rule.Alternatives, allSymbols)) {
          g.AddProductionRule(ntSymbol, alt);
        }
        // local initialization code
        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
      }
      return g;
    }

    // Converts a formal/actual parameter string into a list of grammar attributes.
    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
      return (from fieldDef in Util.ExtractParameters(formalParameters)
              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut)))
              .ToList();
    }

    // Lazily converts each alternative of a rule into a symbol sequence.
    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
      foreach (var alt in altNode.Alternatives) {
        yield return GetSequence(alt.Sequence, allSymbols);
      }
    }

    // Converts one alternative into a sequence of symbols: symbol calls become symbols carrying
    // their actual parameters, semantic actions become SemanticSymbols holding the action code.
    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
      var l = new List<ISymbol>();
      foreach (var node in sequence) {
        var callSymbolNode = node as CallSymbolNode;
        var actionNode = node as RuleActionNode;
        if (callSymbolNode != null) {
          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
          // create a new symbol with actual parameters
          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
        } else if (actionNode != null) {
          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
        }
      }
      return new Sequence(l);
    }
    #endregion

    #region helper methods for terminal symbols
    // produces helper methods for the attributes of all terminal nodes
    private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
      var sb = new SourceBuilder();
      var terminals = symbols.OfType<TerminalNode>();
      foreach (var t in terminals) {
        GenerateConstraintMethods(t, sb);
      }
      return sb.ToString();
    }


    // Generates helper methods for the attributes of a given terminal node:
    // GetMin/GetMax for range constraints and GetAllowed for set constraints.
    // Throws InvalidOperationException if a constraint refers to an attribute the terminal does not declare.
    private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
      foreach (var c in t.Constraints) {
        if (!t.FieldDefinitions.Any(d => d.Identifier == c.Ident))
          throw new System.InvalidOperationException(
            string.Format("The constraint for terminal {0} refers to the unknown attribute {1}.", t.Ident, c.Ident));
        var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
        if (c.Type == ConstraintNodeType.Range) {
          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
        } else if (c.Type == ConstraintNodeType.Set) {
          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
        }
      }
    }
    #endregion

    // Generates one Tree subclass per terminal; the subclass stores the terminal's attribute values.
    private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
      var sb = new SourceBuilder();

      foreach (var terminal in terminals) {
        GenerateTerminalNodeClassDefinitions(terminal, sb);
      }
      return sb.ToString();
    }

    // Generates "public class <Ident>Tree : Tree" with one public field per attribute of the terminal.
    private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
      sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("public {0} {1};", att.Type, att.Identifier).AppendLine();
      }
      sb.AppendFormat("public {0}Tree() : base() {{ }}", terminal.Ident).AppendLine();
      // close the class AND the block opened by BeginBlock above (previously left unbalanced)
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter: an entry method for the start symbol plus one method per symbol.
    private string GenerateInterpreterSource(AttributedGrammar grammar) {
      var sb = new SourceBuilder();
      GenerateInterpreterStart(grammar, sb);

      // generate methods for all nonterminals and terminals using the grammar instance
      foreach (var s in grammar.NonTerminalSymbols) {
        GenerateInterpreterMethod(grammar, s, sb);
      }
      foreach (var s in grammar.TerminalSymbols) {
        GenerateTerminalInterpreterMethod(s, sb);
      }
      return sb.ToString();
    }

    // Generates the parameterless-tree entry method for the start symbol which can be called from the
    // fitness function; it forwards to the tree-based method using the _t field of the problem class.
    private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
      var s = grammar.StartSymbol;
      var attr = s.Attributes.ToArray();

      if (attr.Length == 0) {
        sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
        // no attributes: the call must not contain a dangling ", " (previously generated "S(_t, );")
        sb.AppendFormat("{0}(_t);", s.Name).AppendLine();
      } else {
        sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();
        // actual parameters are the formal parameters without the type identifier
        // (attribute type prefix, e.g. ref/out, followed by the name)
        var actualParameter = string.Join(", ", attr.Select(a => a.AttributeType + " " + a.Name));
        sb.AppendFormat("{0}(_t, {1});", s.Name, actualParameter).AppendLine();
      }
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter method for a nonterminal. With several alternatives the
    // alternative is selected by _t.altIdx; with a single sequence, the i-th non-semantic
    // symbol is interpreted on the i-th subtree and semantic actions are inlined.
    private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // generate local definitions
      sb.AppendLine(g.GetLocalDefinitions(s));

      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();
      if (altsWithSemActions.Length == 0)
        throw new System.InvalidOperationException(
          string.Format("The nonterminal {0} has no production rule.", s.Name));

      if (altsWithSemActions.Length > 1) {
        GenerateSwitchStatement(altsWithSemActions, sb);
      } else {
        int subtreeIdx = 0;
        foreach (var altSymb in altsWithSemActions[0]) {
          GenerateSourceForAction(subtreeIdx, altSymb, sb);
          // semantic actions do not correspond to a subtree
          if (!(altSymb is SemanticSymbol)) subtreeIdx++;
        }
      }
      sb.Append("}").EndBlock();
    }

    // Generates "switch(_t.altIdx)" with one case per alternative; unknown indexes throw at runtime.
    private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb) {
      sb.Append("switch(_t.altIdx) {").BeginBlock();
      // generate a case for each alternative
      int altIdx = 0;
      foreach (var alt in alts) {
        sb.AppendFormat("case {0}: {{ ", altIdx).BeginBlock();

        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
        // a way to handle this is through grammar transformation (the exemplary grammars all have the correct form)
        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
        foreach (var altSymb in alt) {
          GenerateSourceForAction(0, altSymb, sb); // index is always 0 because of the assertion above
        }
        altIdx++;
        sb.AppendLine("break;").Append("}").EndBlock();
      }
      sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
    }

    // Helper for generating calls to other symbol methods: a semantic action is emitted verbatim,
    // any other symbol becomes a call of its method on subtree idx (with its actual parameters, if any).
    private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
      var action = s as SemanticSymbol;
      if (action != null)
        sb.Append(action.Code + ";");
      else if (!s.Attributes.Any())
        sb.AppendFormat("{1}(_t.subtrees[{0}]);", idx, s.Name);
      else sb.AppendFormat("{1}(_t.subtrees[{0}], {2}); ", idx, s.Name, s.GetAttributeString());
      sb.AppendLine();
    }

    // Generates the interpreter method for a terminal: each attribute parameter is assigned
    // from the field of the same name in the terminal's generated <Name>Tree node.
    private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // each field must match a formal parameter, assign a value for each parameter
      foreach (var element in s.Attributes) {
        // direct cast: a node of the wrong type fails with InvalidCastException instead of a NullReferenceException
        sb.AppendFormat("{0} = (({1}Tree)_t).{0};", element.Name, s.Name).AppendLine();
      }
      sb.Append("}").EndBlock();
    }



    // Generates the entries of Grammar.transition. The state index of a symbol is its position in
    // grammar.Symbols (the start symbol is state 0). For a nonterminal with several alternatives there
    // is one target state per alternative; for a single sequence there is one per symbol of the sequence.
    // Terminals have no target states.
    private string GenerateTransitionTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        var targetStates = new List<int>();
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            foreach (var alt in grammar.GetAlternatives(s)) {
              // only single-symbol alternatives are supported
              Debug.Assert(alt.Count() == 1);
              targetStates.Add(allSymbols.IndexOf(alt.Single()));
            }
          } else {
            // rule is a sequence of symbols
            var seq = grammar.GetAlternatives(s).Single();
            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
          }
        }

        var targetStateString = string.Join(", ", targetStates);

        var idxOfSourceState = allSymbols.IndexOf(s);
        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", idxOfSourceState, targetStateString).AppendLine();
      }
      return sb.ToString();
    }

    // Generates the entries of Grammar.subtreeCount: 0 for terminals, 1 for nonterminals with several
    // (single-symbol) alternatives, and the sequence length for nonterminals with a single sequence.
    private string GenerateSubtreeCountTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        int subtreeCount = 0;
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
            subtreeCount = 1;
          } else {
            subtreeCount = grammar.GetAlternative(s, 0).Count();
          }
        }

        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
      }

      return sb.ToString();
    }

  }
}
Note: See TracBrowser for help on using the repository browser.