Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs @ 10424

Last change on this file since 10424 was 10424, checked in by gkronber, 10 years ago

#2026 maximal depth limit for random search

File size: 18.3 KB
Line 
1using System.Collections.Generic;
2using System.Diagnostics;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Grammars;
6using Attribute = HeuristicLab.Grammars.Attribute;
7
namespace CodeGenerator {
  // code generator for problem class
  // Produces one C# source file that contains the problem class (INIT code, fitness function, interpreter
  // methods, terminal constraint helpers), the Tree classes, a static Grammar table and the solver code.
  public class ProblemCodeGen {
    private const string usings = @"
using System.Collections.Generic;
using System.Linq;
using System;
using System.Text.RegularExpressions;
using System.Diagnostics;
";

    private const string problemTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Problem {
   
   public ?IDENT?Problem() {
      Initialize();
    }   

    private void Initialize() {
      // the following is the source code from the INIT section of the problem definition
#region INIT section
?INITSOURCE?
#endregion
    }

    private Tree _t;
    public double Evaluate(Tree _t) {
      this._t = _t;
#region objective function (MINIMIZE / MAXIMIZE section)
?FITNESSFUNCTION?
#endregion
    }
    public bool IsBetter(double a, double b) {
      return ?MAXIMIZATION? ? a > b : a < b;
    }

// additional code from the problem definition (CODE section)
#region additional code
?ADDITIONALCODE?
#endregion

#region generated source for interpretation
?INTERPRETERSOURCE?
#endregion

#region generated code for the constraints for terminals
?CONSTRAINTSSOURCE?
#endregion
  }

#region class definitions for tree
  public class Tree {
    public int altIdx;
    public Tree[] subtrees;
    protected Tree() {
      // leave subtrees uninitialized
    }
    public Tree(int altIdx, Tree[] subtrees = null) {
      this.altIdx = altIdx;
      this.subtrees = subtrees;
    }
    public int GetSize() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Sum(t=>t.GetSize());
    }
    public int GetDepth() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Max(t=>t.GetDepth());
    }
    public virtual void PrintTree(int curState) {
      Console.Write(""{0} "", Grammar.symb[curState]);
      if(subtrees != null) {
        if(subtrees.Length==1) {
          subtrees[0].PrintTree(Grammar.transition[curState][altIdx]);
        } else {
          for(int i=0;i<subtrees.Length;i++) {
            subtrees[i].PrintTree(Grammar.transition[curState][i]);
          }
        }
      }
    }
  }

  ?TERMINALNODECLASSDEFINITIONS?
#endregion

#region helper class for the grammar representation
  public class Grammar {
    public static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
?TRANSITIONTABLE?
    };
    public static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
       { -1, 0 }, // terminals
?SUBTREECOUNTTABLE?
    };
    public static readonly Dictionary<int, int> minDepth = new Dictionary<int, int>() {
?MINDEPTHTABLE?
    };
    public static readonly string[] symb = new string[] { ?SYMBOLNAMES? };
   
  } 
#endregion
}";


    /// <summary>
    /// Generates the source code for the problem class and the solver(s) (currently a random search)
    /// that can be compiled with a C# compiler, and writes it to the file named ast.Name + ".cs"
    /// (relative to the current working directory).
    /// </summary>
    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    /// <exception cref="System.ArgumentNullException">If <paramref name="ast"/> is null.</exception>
    public void Generate(GPDefNode ast) {
      if (ast == null) throw new System.ArgumentNullException("ast");

      var problemSourceCode = new SourceBuilder();
      problemSourceCode.AppendLine(usings);

      GenerateProblemSource(ast, problemSourceCode);
      GenerateSolvers(ast, problemSourceCode);

      // the name placeholders also occur in the solver code, so they are replaced last
      problemSourceCode
        .Replace("?PROBLEMNAME?", ast.Name)
        .Replace("?IDENT?", ast.Name);

      // write the source file to disk
      using (var stream = new StreamWriter(ast.Name + ".cs")) {
        stream.WriteLine(problemSourceCode.ToString());
      }
    }

    // Appends the problem template and fills all its placeholders (except ?PROBLEMNAME? / ?IDENT?).
    // NOTE: the Replace calls operate on the whole buffer, so text inserted by an earlier Replace
    // (e.g. user code from the problem definition) is subject to the later replacements as well.
    private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      problemSourceCode
        .AppendLine(problemTemplate)
        .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
        .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
        .Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode)
        .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
        .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
        .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
        .Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()))
        .Replace("?SYMBOLNAMES?", string.Join(", ", grammar.Symbols.Select(s => "\"" + s.Name + "\"")))
        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
        .Replace("?MINDEPTHTABLE?", GenerateMinDepthTable(grammar))
       ;
    }

    // Appends the source code of the solver(s) to solverSourceCode.
    // The grammar is built from the AST again here; the other code generators are currently disabled.
    private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      var randomSearchCodeGen = new RandomSearchCodeGen();
      randomSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
      // var bruteForceSearchCodeGen = new BruteForceCodeGen();
      // bruteForceSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
      // var mctsCodeGen = new MonteCarloTreeSearchCodeGen();
      // mctsCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
    }

    #region create grammar instance from AST
    // should be refactored so that we can directly query the AST
    // The non-terminal of the first rule in the AST becomes the start symbol.
    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {
      var nonTerminals = ast.NonTerminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      var terminals = ast.Terminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      // hoisted: the same symbol list is used for every rule below
      var allSymbols = nonTerminals.Concat(terminals).ToList();

      if (!ast.Rules.Any())
        throw new System.InvalidOperationException("The problem definition does not contain any production rules.");
      string startSymbolName = ast.Rules.First().NtSymbol;

      // create startSymbol
      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);

      // add all production rules
      foreach (var rule in ast.Rules) {
        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
        foreach (var alt in GetAlternatives(rule.Alternatives, allSymbols)) {
          g.AddProductionRule(ntSymbol, alt);
        }
        // local initialization code
        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
      }
      return g;
    }

    // Converts a formal/actual parameter string into a list of grammar attributes.
    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
      return (from fieldDef in Util.ExtractParameters(formalParameters)
              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut)))
              .ToList();
    }

    // Lazily converts each alternative of a rule into a grammar Sequence.
    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
      foreach (var alt in altNode.Alternatives) {
        yield return GetSequence(alt.Sequence, allSymbols);
      }
    }

    // Converts a rule sequence into symbols (with actual parameters) and semantic actions ("SEM" symbols).
    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
      var l = new List<ISymbol>();
      foreach (var node in sequence) {
        var callSymbolNode = node as CallSymbolNode;
        var actionNode = node as RuleActionNode;
        if (callSymbolNode != null) {
          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
          // create a new symbol with actual parameters
          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
        } else if (actionNode != null) {
          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
        }
      }
      return new Sequence(l);
    }
    #endregion

    #region helper methods for terminal symbols
    // produces helper methods for the attributes of all terminal nodes
    private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
      var sb = new SourceBuilder();
      var terminals = symbols.OfType<TerminalNode>();
      foreach (var t in terminals) {
        GenerateConstraintMethods(t, sb);
      }
      return sb.ToString();
    }


    // generates helper methods for the attributes of a given terminal node:
    // GetMin/GetMax for range constraints and GetAllowed for set constraints
    private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
      foreach (var c in t.Constraints) {
        if (!t.FieldDefinitions.Any(d => d.Identifier == c.Ident))
          throw new System.InvalidOperationException(string.Format(
            "The constraint on '{0}' of terminal '{1}' refers to an undefined attribute.", c.Ident, t.Ident));
        var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
        if (c.Type == ConstraintNodeType.Range) {
          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
        } else if (c.Type == ConstraintNodeType.Set) {
          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
        }
      }
    }
    #endregion

    // Generates one Tree subclass per terminal symbol.
    private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
      var sb = new SourceBuilder();

      foreach (var terminal in terminals) {
        GenerateTerminalNodeClassDefinitions(terminal, sb);
      }
      return sb.ToString();
    }

    // Generates the class <Ident>Tree : Tree holding one public field per terminal attribute
    // and a PrintTree override that also prints the field values.
    private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
      sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("public {0} {1};", att.Type, att.Identifier).AppendLine();
      }
      sb.AppendFormat(" public {0}Tree() : base() {{ }}", terminal.Ident).AppendLine();
      sb.AppendLine(@"
          public override void PrintTree(int curState) {
            Console.Write(""{0} "", Grammar.symb[curState]);");
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("Console.Write(\"{{0}} \", {0});", att.Identifier).AppendLine();
      }

      sb.AppendLine("}"); // end of PrintTree

      // end of class; EndBlock balances the BeginBlock above (as in all other generator methods)
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter: the entry method plus one method per non-terminal and terminal symbol.
    private string GenerateInterpreterSource(AttributedGrammar grammar) {
      var sb = new SourceBuilder();
      GenerateInterpreterStart(grammar, sb);

      // generate methods for all nonterminals and terminals using the grammar instance
      foreach (var s in grammar.NonTerminalSymbols) {
        GenerateInterpreterMethod(grammar, s, sb);
      }
      foreach (var s in grammar.TerminalSymbols) {
        GenerateTerminalInterpreterMethod(s, sb);
      }
      return sb.ToString();
    }

    // Generates the entry method named after the start symbol (with the start symbol's attributes as
    // parameters) which can be called from the fitness function; it interprets the current tree _t.
    private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
      var s = grammar.StartSymbol;
      // formal parameters of the start symbol
      var attr = s.Attributes;

      // create the method which can be called from the fitness function
      if (!attr.Any())
        sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      if (!attr.Any()) {
        // without attributes only the tree is passed (previously this emitted the invalid call "S(_t, )")
        sb.AppendFormat("{0}(_t);", s.Name).AppendLine();
      } else {
        // actual parameters are the formal parameters without type identifier,
        // each prefixed with its AttributeType (parsed from the ref/out modifier)
        var actualParameter = string.Join(", ", attr.Select(a => a.AttributeType + " " + a.Name));
        sb.AppendFormat("{0}(_t, {1});", s.Name, actualParameter).AppendLine();
      }
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter method of a non-terminal symbol: local definitions followed by either a
    // switch over _t.altIdx (several alternatives) or the calls/actions of its single sequence.
    private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();
      if (altsWithSemActions.Length == 0)
        throw new System.InvalidOperationException(string.Format(
          "There is no production rule for the non-terminal symbol '{0}'.", s.Name));

      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // generate local definitions
      sb.AppendLine(g.GetLocalDefinitions(s));

      if (altsWithSemActions.Length > 1) {
        GenerateSwitchStatement(altsWithSemActions, sb);
      } else {
        // subtree index only advances for symbols, semantic actions have no subtree
        int i = 0;
        foreach (var altSymb in altsWithSemActions.Single()) {
          GenerateSourceForAction(i, altSymb, sb);
          if (!(altSymb is SemanticSymbol)) i++;
        }
      }
      sb.Append("}").EndBlock();
    }

    // Generates a switch over _t.altIdx with one case per alternative.
    private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb) {
      sb.Append("switch(_t.altIdx) {").BeginBlock();
      // generate a case for each alternative
      int altIdx = 0;
      foreach (var alt in alts) {
        sb.AppendFormat("case {0}: {{ ", altIdx).BeginBlock();

        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
        // a way to handle this is through grammar transformation (the examplary grammars all have the correct from)
        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
        foreach (var altSymb in alt) {
          GenerateSourceForAction(0, altSymb, sb); // index is always 0 because of the assertion above
        }
        altIdx++;
        sb.AppendLine("break;").Append("}").EndBlock();
      }
      sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
    }

    // helper for generating calls to other symbol methods (or the inline code of a semantic action)
    private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
      var action = s as SemanticSymbol;
      if (action != null)
        sb.Append(action.Code + ";");
      else if (!s.Attributes.Any())
        sb.AppendFormat("{1}(_t.subtrees[{0}]);", idx, s.Name);
      else sb.AppendFormat("{1}(_t.subtrees[{0}], {2}); ", idx, s.Name, s.GetAttributeString());
      sb.AppendLine();
    }

    // Generates the interpreter method of a terminal symbol: each attribute parameter is assigned the
    // value of the corresponding field of the <Name>Tree node.
    private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // each field must match a formal parameter, assign a value for each parameter
      // (a direct cast fails with an InvalidCastException instead of a NullReferenceException)
      foreach (var element in s.Attributes) {
        sb.AppendFormat("{0} = (({1}Tree)_t).{0};", element.Name, s.Name).AppendLine();
      }
      sb.Append("}").EndBlock();
    }



    // Generates the entries of Grammar.transition: state index (= index of the symbol in grammar.Symbols)
    // -> indexes of the successor states. Symbols with several alternatives map to the (single) symbol of
    // each alternative, symbols with one alternative to the symbols of that sequence, terminals to nothing.
    private string GenerateTransitionTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        var targetStates = new List<int>();
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            foreach (var alt in grammar.GetAlternatives(s)) {
              // only single-symbol alternatives are supported
              Debug.Assert(alt.Count() == 1);
              targetStates.Add(allSymbols.IndexOf(alt.Single()));
            }
          } else {
            // rule is a sequence of symbols
            var seq = grammar.GetAlternatives(s).Single();
            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
          }
        }

        var targetStateString = string.Join(", ", targetStates);

        var idxOfSourceState = allSymbols.IndexOf(s);
        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", idxOfSourceState, targetStateString).AppendLine();
      }
      return sb.ToString();
    }

    // Generates the entries of Grammar.subtreeCount: number of subtrees of a node for each symbol
    // (0 for terminals, 1 for symbols with several alternatives, the sequence length otherwise).
    private string GenerateSubtreeCountTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        int subtreeCount = 0;
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
            subtreeCount = 1;
          } else {
            subtreeCount = grammar.GetAlternative(s, 0).Count();
          }
        }

        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
      }

      return sb.ToString();
    }

    // Generates the entries of Grammar.minDepth: the minimal depth of a tree derived from each symbol
    // (terminals have depth 1).
    private string GenerateMinDepthTable(IGrammar grammar) {
      var sb = new SourceBuilder();
      var minDepth = CalculateMinDepths(grammar);
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0}, {1} }}, ", allSymbols.IndexOf(s), minDepth[s]).AppendLine();
      }
      return sb.ToString();
    }

    // Computes the minimal derivation depth of every symbol by fixed-point iteration.
    // Values only ever decrease (min over alternatives / max over sequence children of decreasing values),
    // so iterating until nothing changes terminates and yields the true minimum. Stopping as soon as every
    // symbol merely has *a* value (the previous behavior) could keep a too-large depth for a symbol whose
    // shorter alternative became known late, and looped forever for non-productive grammars.
    // Throws InvalidOperationException if some symbol cannot derive a finite tree.
    private static Dictionary<ISymbol, int> CalculateMinDepths(IGrammar grammar) {
      var minDepth = new Dictionary<ISymbol, int>();
      foreach (var s in grammar.TerminalSymbols) {
        minDepth[s] = 1;
      }

      var nonTerminals = grammar.NonTerminalSymbols.ToList();
      bool changed;
      do {
        changed = false;
        foreach (var s in nonTerminals) {
          int depth;
          if (!TryCalculateMinDepth(grammar, s, minDepth, out depth)) continue;
          int current;
          if (!minDepth.TryGetValue(s, out current) || depth < current) {
            minDepth[s] = depth;
            changed = true;
          }
        }
      } while (changed);

      var unresolved = grammar.Symbols.Where(s => !minDepth.ContainsKey(s)).ToList();
      if (unresolved.Any()) {
        throw new System.InvalidOperationException(
          "Cannot derive a finite tree from the symbol(s): " + string.Join(", ", unresolved.Select(s => s.Name)));
      }
      return minDepth;
    }

    // Computes the min depth of non-terminal s from the currently known depths of its successors.
    // Returns false if not enough successors are known yet (or s has no production rule).
    private static bool TryCalculateMinDepth(IGrammar grammar, ISymbol s, Dictionary<ISymbol, int> minDepth, out int depth) {
      depth = 0;
      var alternatives = grammar.GetAlternatives(s).ToList();
      if (alternatives.Count == 0) return false;

      if (grammar.NumberOfAlternatives(s) > 1) {
        // alternatives (each a single symbol): one more than the shallowest known alternative
        var knownDepths = alternatives
          .Select(alt => alt.Single())
          .Where(c => minDepth.ContainsKey(c))
          .Select(c => minDepth[c])
          .ToList();
        if (knownDepths.Count == 0) return false;
        depth = knownDepths.Min() + 1;
        return true;
      }

      // sequence: one more than the deepest child, all children must be known
      var seq = alternatives.Single();
      if (!seq.All(c => minDepth.ContainsKey(c))) return false;
      depth = seq.Select(c => minDepth[c]).Max() + 1;
      return true;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.