Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs @ 10415

Last change on this file since 10415 was 10415, checked in by gkronber, 11 years ago

#2026 implemented prevention of resampling of known nodes.

File size: 16.9 KB
Line 
1using System.Collections.Generic;
2using System.Diagnostics;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Grammars;
6using Attribute = HeuristicLab.Grammars.Attribute;
7
namespace CodeGenerator {
  /// <summary>
  /// Code generator for the problem class. Translates the abstract syntax tree of a GPDL problem
  /// definition into a single C# source file that contains the problem class, the tree classes,
  /// the grammar tables and the generated solver.
  /// </summary>
  public class ProblemCodeGen {
    // using directives emitted at the top of the generated source file
    private const string usings = @"
using System.Collections.Generic;
using System.Linq;
using System;
using System.Text.RegularExpressions;
using System.Diagnostics;
";

    // Template of the generated problem source. The ?NAME? placeholders are filled in by
    // GenerateProblemSource, except ?PROBLEMNAME? and ?IDENT?, which are replaced in Generate.
    private const string problemTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Problem {
   
   public ?IDENT?Problem() {
      Initialize();
    }   

    private void Initialize() {
      // the following is the source code from the INIT section of the problem definition
#region INIT section
?INITSOURCE?
#endregion
    }

    private Tree _t;
    public double Evaluate(Tree _t) {
      this._t = _t;
#region objective function (MINIMIZE / MAXIMIZE section)
?FITNESSFUNCTION?
#endregion
    }
    public bool IsBetter(double a, double b) {
      return ?MAXIMIZATION? ? a > b : a < b;
    }

// additional code from the problem definition (CODE section)
#region additional code
?ADDITIONALCODE?
#endregion

#region generated source for interpretation
?INTERPRETERSOURCE?
#endregion

#region generated code for the constraints for terminals
?CONSTRAINTSSOURCE?
#endregion
  }

#region class definitions for tree
  public class Tree {
    public int altIdx;
    public Tree[] subtrees;
    protected Tree() {
      // leave subtrees uninitialized
    }
    public Tree(int altIdx, Tree[] subtrees = null) {
      this.altIdx = altIdx;
      this.subtrees = subtrees;
    }
    public int GetSize() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Sum(t=>t.GetSize());
    }
    public int GetDepth() {
      if(subtrees==null) return 1;
      else return 1 + subtrees.Max(t=>t.GetDepth());
    }
    public virtual void PrintTree(int curState) {
      Console.Write(""{0} "", Grammar.symb[curState]);
      if(subtrees != null) {
        if(subtrees.Length==1) {
          subtrees[0].PrintTree(Grammar.transition[curState][altIdx]);
        } else {
          for(int i=0;i<subtrees.Length;i++) {
            subtrees[i].PrintTree(Grammar.transition[curState][i]);
          }
        }
      }
    }
  }

  ?TERMINALNODECLASSDEFINITIONS?
#endregion

#region helper class for the grammar representation
  public class Grammar {
    public static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
?TRANSITIONTABLE?
    };
    public static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
       { -1, 0 }, // terminals
?SUBTREECOUNTTABLE?
    };
    public static readonly string[] symb = new string[] { ?SYMBOLNAMES? };
   
  } 
#endregion
}";


    /// <summary>
    /// Generates the source code for the problem class together with a solver for it
    /// (currently a Monte-Carlo tree search) and writes it to the file ast.Name + ".cs"
    /// relative to the current working directory.
    /// </summary>
    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    /// <exception cref="System.ArgumentNullException">If <paramref name="ast"/> is null.</exception>
    public void Generate(GPDefNode ast) {
      if (ast == null) throw new System.ArgumentNullException("ast");

      var problemSourceCode = new SourceBuilder();
      problemSourceCode.AppendLine(usings);

      GenerateProblemSource(ast, problemSourceCode);
      GenerateSolvers(ast, problemSourceCode);

      // the name placeholders are replaced after the solver code has been appended,
      // so any occurrence in the solver code is substituted as well
      problemSourceCode
        .Replace("?PROBLEMNAME?", ast.Name)
        .Replace("?IDENT?", ast.Name);

      // write the source file to disk
      using (var stream = new StreamWriter(ast.Name + ".cs")) {
        stream.WriteLine(problemSourceCode.ToString());
      }
    }

    // Appends the problem template and fills in every placeholder that is derived from the
    // AST or the grammar (all except ?PROBLEMNAME? and ?IDENT?).
    private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      // quoted symbol names, in grammar order, for the Grammar.symb array of the generated code
      var symbolNames = string.Join(", ", grammar.Symbols.Select(s => "\"" + s.Name + "\""));
      problemSourceCode
        .AppendLine(problemTemplate)
        .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
        .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
        .Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode)
        .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
        .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
        .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
        .Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()))
        .Replace("?SYMBOLNAMES?", symbolNames)
        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
       ;
    }

    // Appends the source code of the solver(s) for the problem. Only the MCTS solver is
    // generated at the moment; the random search and brute force generators are disabled.
    private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      // var randomSearchCodeGen = new RandomSearchCodeGen();
      // randomSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
      // var bruteForceSearchCodeGen = new BruteForceCodeGen();
      // bruteForceSearchCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
      var mctsCodeGen = new MonteCarloTreeSearchCodeGen();
      mctsCodeGen.Generate(grammar, ast.Terminals.OfType<TerminalNode>(), ast.FitnessFunctionNode.Maximization, solverSourceCode);
    }

    #region create grammar instance from AST
    // should be refactored so that we can directly query the AST
    // Builds an attributed grammar from the AST. The non-terminal of the first rule becomes
    // the start symbol; every rule contributes its alternatives and its optional local code.
    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {

      var nonTerminals = ast.NonTerminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      var terminals = ast.Terminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      string startSymbolName = ast.Rules.First().NtSymbol;

      // create startSymbol
      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);

      // add all production rules
      foreach (var rule in ast.Rules) {
        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
        foreach (var alt in GetAlternatives(rule.Alternatives, nonTerminals.Concat(terminals))) {
          g.AddProductionRule(ntSymbol, alt);
        }
        // local initialization code
        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
      }
      return g;
    }

    // Converts a formal (or actual) parameter list into grammar attributes (name, type, ref/out kind).
    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
      return (from fieldDef in Util.ExtractParameters(formalParameters)
              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut)))
              .ToList();
    }

    // Lazily converts each alternative of a rule into a symbol sequence.
    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
      foreach (var alt in altNode.Alternatives) {
        yield return GetSequence(alt.Sequence, allSymbols);
      }
    }

    // Converts one alternative into a sequence: symbol calls become symbols carrying their actual
    // parameters, semantic actions become SemanticSymbols named "SEM".
    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
      var l = new List<ISymbol>();
      foreach (var node in sequence) {
        var callSymbolNode = node as CallSymbolNode;
        var actionNode = node as RuleActionNode;
        if (callSymbolNode != null) {
          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
          // create a new symbol with actual parameters
          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
        } else if (actionNode != null) {
          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
        }
      }
      return new Sequence(l);
    }
    #endregion

    #region helper methods for terminal symbols
    // produces helper methods for the attributes of all terminal nodes
    private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
      var sb = new SourceBuilder();
      var terminals = symbols.OfType<TerminalNode>();
      foreach (var t in terminals) {
        GenerateConstraintMethods(t, sb);
      }
      return sb.ToString();
    }


    // Generates helper methods for the constrained attributes of a given terminal node:
    // GetMin/GetMax methods for range constraints and a GetAllowed method for set constraints.
    // Throws InvalidOperationException if a constraint refers to a field the terminal does not declare.
    private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
      foreach (var c in t.Constraints) {
        if (!t.FieldDefinitions.Any(d => d.Identifier == c.Ident)) {
          throw new System.InvalidOperationException(string.Format(
            "The constraint for '{0}' in terminal '{1}' does not refer to a field of the terminal.", c.Ident, t.Ident));
        }
        var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
        if (c.Type == ConstraintNodeType.Range) {
          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
        } else if (c.Type == ConstraintNodeType.Set) {
          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
        }
      }
    }
    #endregion

    // Generates one tree class per terminal symbol.
    private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
      var sb = new SourceBuilder();

      foreach (var terminal in terminals) {
        GenerateTerminalNodeClassDefinitions(terminal, sb);
      }
      return sb.ToString();
    }

    // Emits the class "<Ident>Tree : Tree" for a terminal: one public field per field definition,
    // a parameterless constructor and a PrintTree override that also prints the field values.
    private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
      sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("public {0} {1};", att.Type, att.Identifier).AppendLine();
      }
      sb.AppendFormat(" public {0}Tree() : base() {{ }}", terminal.Ident).AppendLine();
      sb.AppendLine(@"
          public override void PrintTree(int curState) {
            Console.Write(""{0} "", Grammar.symb[curState]);");
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("Console.Write(\"{{0}} \", {0});", att.Identifier).AppendLine();
      }

      // closes PrintTree
      sb.AppendLine("}");

      // closes the class; EndBlock balances the BeginBlock of the class header
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter: an entry method for the start symbol plus one method per
    // non-terminal and per terminal symbol of the grammar.
    private string GenerateInterpreterSource(AttributedGrammar grammar) {
      var sb = new SourceBuilder();
      GenerateInterpreterStart(grammar, sb);

      // generate methods for all nonterminals and terminals using the grammar instance
      foreach (var s in grammar.NonTerminalSymbols) {
        GenerateInterpreterMethod(grammar, s, sb);
      }
      foreach (var s in grammar.TerminalSymbols) {
        GenerateTerminalInterpreterMethod(s, sb);
      }
      return sb.ToString();
    }

    // Generates the parameterless (apart from the start symbol's attributes) entry method that
    // the fitness function calls; it forwards to the start symbol's interpreter method with the field _t.
    private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
      var s = grammar.StartSymbol;
      // create the method which can be called from the fitness function
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // get formal parameters of start symbol
      var attr = s.Attributes;

      // actual parameters are the formal parameters without the type identifier
      // (only the attribute kind, parsed from the ref/out modifier, and the name are passed on)
      if (attr.Any()) {
        var actualParameters = string.Join(", ", attr.Select(a => a.AttributeType + " " + a.Name));
        sb.AppendFormat("{0}(_t, {1});", s.Name, actualParameters).AppendLine();
      } else {
        // without attributes a trailing ", " would make the generated call "S(_t, );" invalid
        sb.AppendFormat("{0}(_t);", s.Name).AppendLine();
      }
      sb.AppendLine("}").EndBlock();
    }

    // Generates the interpreter method for a non-terminal: its local definitions followed by
    // either a switch over _t.altIdx (several alternatives) or the straight-line sequence.
    private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // generate local definitions
      sb.AppendLine(g.GetLocalDefinitions(s));

      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();

      if (altsWithSemActions.Length > 1) {
        GenerateSwitchStatement(altsWithSemActions, sb);
      } else {
        // subtree index advances only for real symbols, not for semantic actions
        int i = 0;
        foreach (var altSymb in altsWithSemActions.Single()) {
          GenerateSourceForAction(i, altSymb, sb);
          if (!(altSymb is SemanticSymbol)) i++;
        }
      }
      sb.Append("}").EndBlock();
    }

    // Generates a switch over _t.altIdx with one case per alternative.
    private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb) {
      sb.Append("switch(_t.altIdx) {").BeginBlock();
      // generate a case for each alternative
      int altIdx = 0;
      foreach (var alt in alts) {
        sb.AppendFormat("case {0}: {{ ", altIdx).BeginBlock();

        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
        // a way to handle this is through grammar transformation (the exemplary grammars all have the correct form)
        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
        foreach (var altSymb in alt) {
          GenerateSourceForAction(0, altSymb, sb); // index is always 0 because of the assertion above
        }
        altIdx++;
        sb.AppendLine("break;").Append("}").EndBlock();
      }
      sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
    }

    // helper for generating calls to other symbol methods: semantic actions are emitted verbatim,
    // symbols become a call on subtree idx (with actual parameters if the symbol has attributes)
    private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
      var action = s as SemanticSymbol;
      if (action != null)
        sb.Append(action.Code + ";");
      else if (!s.Attributes.Any())
        sb.AppendFormat("{1}(_t.subtrees[{0}]);", idx, s.Name);
      else sb.AppendFormat("{1}(_t.subtrees[{0}], {2}); ", idx, s.Name, s.GetAttributeString());
      sb.AppendLine();
    }

    // Generates the interpreter method for a terminal: each attribute parameter receives the
    // value stored in the corresponding field of the terminal's tree node.
    private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
      // if the terminal symbol has attributes then we must sample values for these attributes
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(Tree _t) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(Tree _t, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // each field must match a formal parameter, assign a value for each parameter;
      // a direct cast fails with InvalidCastException (instead of a NullReferenceException) on a wrong node type
      foreach (var element in s.Attributes) {
        sb.AppendFormat("{0} = (({1}Tree)_t).{0};", element.Name, s.Name).AppendLine();
      }
      sb.Append("}").EndBlock();
    }



    // Generates the entries of Grammar.transition: for every symbol (state = index in the grammar)
    // the target states, i.e. one per alternative for choices or one per symbol for sequences.
    private string GenerateTransitionTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        // terminals have no target states
        var targetStates = new List<int>();
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            foreach (var alt in grammar.GetAlternatives(s)) {
              // only single-symbol alternatives are supported
              Debug.Assert(alt.Count() == 1);
              targetStates.Add(allSymbols.IndexOf(alt.Single()));
            }
          } else {
            // rule is a sequence of symbols
            var seq = grammar.GetAlternatives(s).Single();
            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
          }
        }

        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", allSymbols.IndexOf(s), string.Join(", ", targetStates)).AppendLine();
      }
      return sb.ToString();
    }

    // Generates the entries of Grammar.subtreeCount: 0 for terminals, 1 for choices between
    // single-symbol alternatives, and the sequence length for single-alternative rules.
    private string GenerateSubtreeCountTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        int subtreeCount = 0;
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
            subtreeCount = 1;
          } else {
            subtreeCount = grammar.GetAlternative(s, 0).Count();
          }
        }

        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
      }

      return sb.ToString();
    }

  }
}
Note: See TracBrowser for help on using the repository browser.