Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs

Last change on this file was 10427, checked in by gkronber, 10 years ago

#2026 integrated max depth into MCTS solver

File size: 18.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.IO;
5using System.Linq;
6using HeuristicLab.Grammars;
7using Attribute = HeuristicLab.Grammars.Attribute;
8
9namespace CodeGenerator {
10  // code generator for problem class
11  public class ProblemCodeGen {
// Using directives for the generated source file; Generate() appends this text to the
// source builder before anything else.
12    private const string usings = @"
13using System.Collections.Generic;
14using System.Linq;
15using System;
16using System.Text.RegularExpressions;
17using System.Diagnostics;
18";
19
// Template for the generated problem class, the tree node classes and the grammar helper class.
// Tokens of the form ?NAME? are placeholders: ?PROBLEMNAME? and ?IDENT? are replaced by the
// problem name in Generate(); all other placeholders are replaced in GenerateProblemSource().
20    private const string problemTemplate = @"
21namespace ?PROBLEMNAME? {
22
23
24
25
26  public sealed class ?IDENT?Problem {
27    public static void Main(string[] args) {
28      var problem = new ?IDENT?Problem();
29      var solver = new ?IDENT?MonteCarloTreeSearchSolver(problem, args);
30      solver.Start();
31    }
32   
33   public ?IDENT?Problem() {
34      Initialize();
35    }   
36
37    private void Initialize() {
38      // the following is the source code from the INIT section of the problem definition
39#region INIT section
40?INITSOURCE?
41#endregion
42    }
43
44    private Tree _t;
45    public double Evaluate(Tree _t) {
46      this._t = _t;
47#region objective function (MINIMIZE / MAXIMIZE section)
48?FITNESSFUNCTION?
49#endregion
50    }
51    public bool IsBetter(double a, double b) {
52      return ?MAXIMIZATION? ? a > b : a < b;
53    }
54
55// additional code from the problem definition (CODE section)
56#region additional code
57?ADDITIONALCODE?
58#endregion
59
60#region generated source for interpretation
61?INTERPRETERSOURCE?
62#endregion
63
64#region generated code for the constraints for terminals
65?CONSTRAINTSSOURCE?
66#endregion
67  }
68
69#region class definitions for tree
70  public class Tree {
71    public int altIdx;
72    public Tree[] subtrees;
73    protected Tree() {
74      // leave subtrees uninitialized
75    }
76    public Tree(int altIdx, Tree[] subtrees = null) {
77      this.altIdx = altIdx;
78      this.subtrees = subtrees;
79    }
80    public int GetSize() {
81      if(subtrees==null) return 1;
82      else return 1 + subtrees.Sum(t=>t.GetSize());
83    }
84    public int GetDepth() {
85      if(subtrees==null) return 1;
86      else return 1 + subtrees.Max(t=>t.GetDepth());
87    }
88    public virtual void PrintTree(int curState) {
89      Console.Write(""{0} "", Grammar.symb[curState]);
90      if(subtrees != null) {
91        if(subtrees.Length==1) {
92          subtrees[0].PrintTree(Grammar.transition[curState][altIdx]);
93        } else {
94          for(int i=0;i<subtrees.Length;i++) {
95            subtrees[i].PrintTree(Grammar.transition[curState][i]);
96          }
97        }
98      }
99    }
100  }
101
102  ?TERMINALNODECLASSDEFINITIONS?
103#endregion
104
105#region helper class for the grammar representation
106  public class Grammar {
107    public static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
108?TRANSITIONTABLE?
109    };
110    public static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
111       { -1, 0 }, // terminals
112?SUBTREECOUNTTABLE?
113    };
114    public static readonly Dictionary<int, int> minDepth = new Dictionary<int, int>() {
115?MINDEPTHTABLE?
116    };
117    public static readonly string[] symb = new string[] { ?SYMBOLNAMES? };
118   
119  } 
120#endregion
121}";
122
123
/// <summary>
/// Generates one C# source file for the given problem definition. The file starts with the
/// using directives, followed by the problem source and the source of the solvers, and is
/// written to the file named after the problem (ast.Name + ".cs").
/// </summary>
/// <param name="ast">An abstract syntax tree for a GPDL file</param>
public void Generate(GPDefNode ast) {
  var source = new SourceBuilder();
  source.AppendLine(usings);

  GenerateProblemSource(ast, source);
  GenerateSolvers(ast, source);

  // both placeholders stand for the name of the problem
  source.Replace("?PROBLEMNAME?", ast.Name);
  source.Replace("?IDENT?", ast.Name);

  // write the source file to disk
  var fileName = ast.Name + ".cs";
  using (var writer = new StreamWriter(fileName)) {
    writer.WriteLine(source.ToString());
  }
}
144
// Appends the problem template to the given builder and substitutes every placeholder
// whose value is derived from the problem definition or from its grammar.
private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
  var grammar = CreateGrammarFromAst(ast);
  problemSourceCode.AppendLine(problemTemplate);

  // code sections taken verbatim from the problem definition
  problemSourceCode.Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode);
  problemSourceCode.Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant());
  problemSourceCode.Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode);
  problemSourceCode.Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode);

  // generated code
  problemSourceCode.Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar));
  problemSourceCode.Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals));
  problemSourceCode.Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()));

  // every symbol name becomes a quoted array element followed by ", "
  var quotedSymbolNames = string.Concat(grammar.Symbols.Select(s => "\"" + s.Name + "\", "));
  problemSourceCode.Replace("?SYMBOLNAMES?", quotedSymbolNames);

  // grammar tables
  problemSourceCode.Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar));
  problemSourceCode.Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar));
  problemSourceCode.Replace("?MINDEPTHTABLE?", GenerateMinDepthTable(grammar));
}
162
// Appends the source code of the random search, brute force and Monte Carlo tree search
// solvers to the given builder (in this order). When one generator throws, the exception
// message is written to the console and the remaining generators are still executed.
private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
  var grammar = CreateGrammarFromAst(ast);

  var generators = new Action[] {
    () => new RandomSearchCodeGen().Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode),
    () => new BruteForceCodeGen().Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode),
    () => new MonteCarloTreeSearchCodeGen().Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode)
  };

  foreach (var generate in generators) {
    try {
      generate();
    }
    catch (Exception e) {
      Console.WriteLine(e.Message);
    }
  }
}
189
190    #region create grammar instance from AST
// Builds an attributed grammar from the symbols and rules of the AST.
// The non-terminal of the first rule is used as start symbol.
// should be refactored so that we can directly query the AST
private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {
  var ntSymbols = (from nt in ast.NonTerminals
                   select new Symbol(nt.Ident, GetSymbolAttributes(nt.FormalParameters))).ToArray();
  var tSymbols = (from t in ast.Terminals
                  select new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters))).ToArray();

  // a rule may only refer to exactly one declared non-terminal
  Func<string, Symbol> findNonTerminal = name => ntSymbols.Single(s => s.Name == name);

  var firstRuleSymbolName = ast.Rules.First().NtSymbol;
  var grammar = new AttributedGrammar(findNonTerminal(firstRuleSymbolName), ntSymbols, tSymbols);

  // add all production rules
  var allSymbols = ntSymbols.Concat(tSymbols);
  foreach (var rule in ast.Rules) {
    var ruleSymbol = findNonTerminal(rule.NtSymbol);
    foreach (var alternative in GetAlternatives(rule.Alternatives, allSymbols)) {
      grammar.AddProductionRule(ruleSymbol, alternative);
    }

    // local initialization code
    if (string.IsNullOrEmpty(rule.LocalCode)) continue;
    grammar.AddLocalDefinitions(ruleSymbol, rule.LocalCode);
  }
  return grammar;
}
217
// Converts a parameter list given as text into a list of attributes (one per extracted parameter).
private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
  return Util.ExtractParameters(formalParameters)
    .Select(p => new Attribute(p.Identifier, p.Type, AttributeType.Parse(p.RefOrOut)))
    .ToList();
}
223
// Lazily maps each alternative of the given node to its symbol sequence.
private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
  return altNode.Alternatives.Select(alternative => GetSequence(alternative.Sequence, allSymbols));
}
229
// Translates the nodes of a rule expression into a sequence of grammar symbols:
// symbol calls become symbols carrying the actual parameters, rule actions become
// semantic symbols named "SEM" that carry the source code of the action.
private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
  Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
  var symbols = new List<ISymbol>();
  foreach (var node in sequence) {
    if (node is CallSymbolNode) {
      var call = (CallSymbolNode)node;
      // only declared symbols may be called
      Debug.Assert(allSymbols.Any(s => s.Name == call.Ident));
      // create a new symbol with actual parameters
      symbols.Add(new Symbol(call.Ident, GetSymbolAttributes(call.ActualParameter)));
      continue;
    }
    if (node is RuleActionNode) {
      var ruleAction = (RuleActionNode)node;
      symbols.Add(new SemanticSymbol("SEM", ruleAction.SrcCode));
    }
  }
  return new Sequence(symbols);
}
246    #endregion
247
248    #region helper methods for terminal symbols
// Returns the source of the constraint helper methods for all terminal nodes among the given symbols.
private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
  var builder = new SourceBuilder();
  foreach (var terminal in symbols.OfType<TerminalNode>()) {
    GenerateConstraintMethods(terminal, builder);
  }
  return builder.ToString();
}
258
259
// Appends the constraint helper methods of one terminal node:
// GetMax.../GetMin... for a range constraint and GetAllowed... for a set constraint.
private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
  foreach (var constraint in t.Constraints) {
    // the generated methods return values of the type of the constrained field
    var constrainedField = t.FieldDefinitions.First(d => d.Identifier == constraint.Ident);
    var fieldType = constrainedField.Type;

    if (constraint.Type == ConstraintNodeType.Set) {
      sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, constraint.Ident, constraint.SetExpression).AppendLine();
      continue;
    }
    if (constraint.Type != ConstraintNodeType.Range) continue;

    sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, constraint.Ident, constraint.RangeMaxExpression).AppendLine();
    sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, constraint.Ident, constraint.RangeMinExpression).AppendLine();
  }
}
272    #endregion
273
// Returns the source of the tree node classes for all given terminals.
private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
  var builder = new SourceBuilder();
  foreach (var terminalNode in terminals) {
    GenerateTerminalNodeClassDefinitions(terminalNode, builder);
  }
  return builder.ToString();
}
282
// Appends the tree node class of a single terminal: a class derived from Tree with one
// public field per field definition, a parameterless constructor and a PrintTree override
// that writes the symbol name followed by the value of every field.
private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
  var fields = terminal.FieldDefinitions;

  // class header and fields
  sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
  foreach (var field in fields) {
    sb.AppendFormat("public {0} {1};", field.Type, field.Identifier).AppendLine();
  }
  sb.AppendFormat(" public {0}Tree() : base() {{ }}", terminal.Ident).AppendLine();

  // PrintTree override
  sb.AppendLine(@"
          public override void PrintTree(int curState) {
            Console.Write(""{0} "", Grammar.symb[curState]);");
  foreach (var field in fields) {
    sb.AppendFormat("Console.Write(\"{{0}} \", {0});", field.Identifier).AppendLine();
  }

  // closes the PrintTree method
  sb.AppendLine("}");

  // closes the class
  sb.AppendLine("}");
}
300
// Returns the source of the interpreter: the entry method for the start symbol followed by
// one method per non-terminal symbol and one method per terminal symbol of the grammar.
private string GenerateInterpreterSource(AttributedGrammar grammar) {
  var builder = new SourceBuilder();
  GenerateInterpreterStart(grammar, builder);

  // generate methods for all nonterminals and terminals using the grammar instance
  foreach (var nonTerminal in grammar.NonTerminalSymbols) {
    GenerateInterpreterMethod(grammar, nonTerminal, builder);
  }
  foreach (var terminal in grammar.TerminalSymbols) {
    GenerateTerminalInterpreterMethod(terminal, builder);
  }
  return builder.ToString();
}
314
// Appends the entry method of the interpreter, i.e. the method that can be called from the
// fitness function. It has the name and the formal parameters of the start symbol and
// forwards to the interpreter method of the start symbol, passing the tree (_t) first.
//
// For a start symbol without attributes the forwarding call is emitted as "Name(_t);".
// Emitting "Name(_t, );" (tree argument followed by an empty parameter list) would
// produce source code that does not compile.
private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
  var s = grammar.StartSymbol;

  // formal parameters of the start symbol
  var attr = s.Attributes;
  bool hasAttributes = attr.Any();

  // create the method which can be called from the fitness function
  if (!hasAttributes)
    sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
  else
    sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

  if (hasAttributes) {
    // actual parameters are the same as the formal parameters only without type identifier
    string actualParameter = string.Join(", ", attr.Select(a => a.AttributeType + " " + a.Name));
    sb.AppendFormat("{0}(_t, {1});", s.Name, actualParameter).AppendLine();
  } else {
    // no parameters to forward: only the tree is passed on
    sb.AppendFormat("{0}(_t);", s.Name).AppendLine();
  }
  sb.AppendLine("}").EndBlock();
}
335
// Appends the interpreter method for the given symbol. The method takes the tree and
// the attributes of the symbol as parameters. A symbol with more than one alternative is
// translated into a switch statement, otherwise the symbols of the single alternative are
// handled one after the other.
private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
  if (s.Attributes.Any())
    sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();
  else
    sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();

  // generate local definitions
  sb.AppendLine(g.GetLocalDefinitions(s));

  var alternatives = g.GetAlternativesWithSemanticActions(s).ToArray();
  if (alternatives.Length > 1) {
    GenerateSwitchStatement(alternatives, sb);
  } else {
    // semantic actions do not consume a subtree, so the subtree index only advances for other symbols
    int subtreeIdx = 0;
    foreach (var symbol in alternatives.Single()) {
      GenerateSourceForAction(subtreeIdx, symbol, sb);
      if (symbol is SemanticSymbol) continue;
      subtreeIdx++;
    }
  }
  sb.Append("}").EndBlock();
}
358
// Appends a switch statement over _t.altIdx with one case per alternative (numbered from 0)
// and a default case that throws an InvalidOperationException.
private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb) {
  sb.Append("switch(_t.altIdx) {").BeginBlock();

  int caseLabel = 0;
  foreach (var alternative in alts) {
    sb.AppendFormat("case {0}: {{ ", caseLabel).BeginBlock();
    caseLabel++;

    // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
    // a way to handle this is through grammar transformation (the exemplary grammars all have the correct form)
    Debug.Assert(alternative.Count(symb => !(symb is SemanticSymbol)) == 1);

    // the subtree index is always 0 because of the assertion above
    foreach (var symbol in alternative) {
      GenerateSourceForAction(0, symbol, sb);
    }
    sb.AppendLine("break;").Append("}").EndBlock();
  }

  sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
}
377
// Appends the source for one symbol of an alternative: the code of a semantic action
// (terminated with ';') or a call of the interpreter method of the symbol for the
// subtree with the given index.
private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
  if (s is SemanticSymbol) {
    sb.Append(((SemanticSymbol)s).Code + ";");
  } else if (s.Attributes.Any()) {
    sb.AppendFormat("{1}(_t.subtrees[{0}], {2}); ", idx, s.Name, s.GetAttributeString());
  } else {
    sb.AppendFormat("{1}(_t.subtrees[{0}]);", idx, s.Name);
  }
  sb.AppendLine();
}
388
// Appends the interpreter method for a terminal symbol. The method takes the tree and the
// attributes of the symbol as parameters and assigns to each attribute the value of the
// equally named field of the terminal's tree node (the tree is accessed as <Name>Tree).
private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
  if (!s.Attributes.Any())
    sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
  else
    sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

  // each field must match a formal parameter, assign a value for each parameter
  foreach (var element in s.Attributes) {
    sb.AppendFormat("{0} = (_t as {1}Tree).{0};", element.Name, s.Name).AppendLine();
  }
  sb.Append("}").EndBlock();
}
403
404
405
// Returns the entries of the transition table. The state index of a symbol is its index in
// grammar.Symbols. A terminal has no target states; a symbol with several alternatives has
// one target state per alternative; a symbol with a single alternative has one target
// state per symbol of that sequence.
private string GenerateTransitionTable(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();

  // state idx = idx of the corresponding symbol in the grammar
  var allSymbols = grammar.Symbols.ToList();
  foreach (var s in grammar.Symbols) {
    var targetStates = new List<int>();
    if (!grammar.IsTerminal(s)) {
      if (grammar.NumberOfAlternatives(s) > 1) {
        foreach (var alt in grammar.GetAlternatives(s)) {
          // only single-symbol alternatives are supported
          Debug.Assert(alt.Count() == 1);
          targetStates.Add(allSymbols.IndexOf(alt.Single()));
        }
      } else {
        // rule is a sequence of symbols
        foreach (var symb in grammar.GetAlternatives(s).Single()) {
          targetStates.Add(allSymbols.IndexOf(symb));
        }
      }
    }

    // every state is followed by ", " (also the last one)
    var targetStateString = string.Concat(targetStates.Select(state => state + ", "));
    var idxOfSourceState = allSymbols.IndexOf(s);

    sb.AppendFormat("// {0}", s).AppendLine();
    sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", idxOfSourceState, targetStateString).AppendLine();
  }
  return sb.ToString();
}
// Returns the entries of the subtree count table: 0 for a terminal, 1 for a symbol with
// several (single-symbol) alternatives and the length of the sequence for a symbol with
// a single alternative.
private string GenerateSubtreeCountTable(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();

  // state idx = idx of the corresponding symbol in the grammar
  var allSymbols = grammar.Symbols.ToList();
  foreach (var s in grammar.Symbols) {
    int subtreeCount;
    if (grammar.IsTerminal(s)) {
      subtreeCount = 0;
    } else if (grammar.NumberOfAlternatives(s) > 1) {
      Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
      subtreeCount = 1;
    } else {
      subtreeCount = grammar.GetAlternative(s, 0).Count();
    }

    var stateIdx = allSymbols.IndexOf(s);
    sb.AppendFormat("// {0}", s).AppendLine();
    sb.AppendFormat("{{ {0} , {1} }},", stateIdx, subtreeCount).AppendLine();
  }

  return sb.ToString();
}
// Returns the entries of the minimal depth table (state index -> minimal depth of a tree
// rooted in that symbol). Terminals have depth 1. A symbol with several alternatives is one
// level deeper than its shallowest alternative, a symbol with a single alternative is one
// level deeper than the deepest symbol of its sequence.
//
// The depths are determined by a fixed-point iteration: passes over the non-terminals are
// repeated until no depth changes anymore. Stopping as soon as every symbol has *some*
// depth is not sufficient, because a symbol that was assigned early (when only a part of
// its alternatives was known) would keep an over-estimated depth.
//
// Throws an InvalidOperationException when the depth of a symbol cannot be determined
// (instead of iterating forever).
private string GenerateMinDepthTable(IGrammar grammar) {
  var minDepth = new Dictionary<ISymbol, int>();
  foreach (var s in grammar.TerminalSymbols) {
    minDepth[s] = 1;
  }

  bool changed;
  do {
    changed = false;
    foreach (var s in grammar.NonTerminalSymbols) {
      int candidate;
      if (grammar.NumberOfAlternatives(s) > 1) {
        // alternatives: one level deeper than the shallowest alternative known so far
        var knownDepths = grammar.GetAlternatives(s)
          .Select(alt => alt.Single())
          .Where(c => minDepth.ContainsKey(c))
          .Select(c => minDepth[c])
          .ToList();
        if (knownDepths.Count == 0) continue;
        candidate = knownDepths.Min() + 1;
      } else {
        // sequences: all symbols must be known, one level deeper than the deepest of them
        var children = grammar.GetAlternatives(s).Single().ToList();
        if (!children.All(c => minDepth.ContainsKey(c))) continue;
        candidate = children.Select(c => minDepth[c]).Max() + 1;
      }

      // depths can only decrease from pass to pass, so the iteration terminates
      int current;
      if (!minDepth.TryGetValue(s, out current) || candidate < current) {
        minDepth[s] = candidate;
        changed = true;
      }
    }
  } while (changed);

  var sb = new SourceBuilder();
  var allSymbols = grammar.Symbols.ToList();
  foreach (var s in allSymbols) {
    int depth;
    if (!minDepth.TryGetValue(s, out depth)) {
      throw new InvalidOperationException(
        string.Format("The minimal depth of symbol {0} cannot be determined.", s));
    }
    sb.AppendFormat("// {0}", s).AppendLine();
    sb.AppendFormat("{{ {0}, {1} }}, ", allSymbols.IndexOf(s), depth).AppendLine();
  }
  return sb.ToString();
}
492  }
493}
Note: See TracBrowser for help on using the repository browser.