Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs @ 10373

Last change on this file since 10373 was 10338, checked in by gkronber, 11 years ago

#2026 changing the way terminal symbols are handled (work in progress)

File size: 14.1 KB
Line 
1using System.Collections.Generic;
2using System.Diagnostics;
3using System.IO;
4using System.Linq;
5using HeuristicLab.Grammars;
6using Attribute = HeuristicLab.Grammars.Attribute;
7
namespace CodeGenerator {
  /// <summary>
  /// Code generator for the problem class: turns the AST of a GPDL problem definition into C# source
  /// containing the problem class (interpreter methods and terminal-constraint helpers) followed by the solver source.
  /// </summary>
  public class ProblemCodeGen {
    // using directives emitted at the top of the generated file
    private const string usings = @"
using System.Collections.Generic;
using System.Linq;
using System;
using System.Text.RegularExpressions;
";

    // skeleton of the generated problem; each ?NAME? token is substituted in GenerateProblemSource / Generate
    private const string problemTemplate = @"
namespace ?PROBLEMNAME? {
  public class Tree {
    public int altIdx;
    public List<Tree> subtrees;
    protected Tree() {
      // leave subtrees uninitialized
    }
    public Tree(int state, int altIdx, int numSubTrees) {
      subtrees = new List<Tree>();
      this.altIdx = altIdx;
    }
  }

  ?TERMINALNODECLASSDEFINITIONS?

  // generic interface for communication from problem interpretation to solver
  public interface ISolverState {
    void Reset();
    int PeekNextAlternative(); // in alternative nodes returns the index of the alternative that should be followed
    void Follow(int idx); // next: derive the NT symbol with index=idx;
    void Unwind(); // finished with deriving the NT symbol
  }

  public sealed class ?IDENT?Problem {
   
   public ?IDENT?Problem() {
      Initialize();
    }   

    private void Initialize() {
      // the following is the source code from the INIT section of the problem definition
#region INIT section
?INITSOURCE?
#endregion
    }

    private ISolverState _state;
    public double Evaluate(ISolverState _state) {
      this._state = _state;
#region objective function (MINIMIZE / MAXIMIZE section)
?FITNESSFUNCTION?
#endregion
    }

// additional code from the problem definition (CODE section)
#region additional code
?ADDITIONALCODE?
#endregion

#region generated source for interpretation
?INTERPRETERSOURCE?
#endregion

#region generated code for the constraints for terminals
?CONSTRAINTSSOURCE?
#endregion
  }
}";


    /// <summary>
    /// Generates the C# source for the problem described by <paramref name="ast"/> (problem class followed by
    /// the random-search solver) and writes it to the file "&lt;problem name&gt;.cs", resolved relative to the
    /// current working directory.
    /// </summary>
    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    /// <exception cref="System.ArgumentNullException"><paramref name="ast"/> is null.</exception>
    public void Generate(GPDefNode ast) {
      if (ast == null) throw new System.ArgumentNullException("ast");

      var problemSourceCode = new SourceBuilder();
      problemSourceCode.AppendLine(usings);

      GenerateProblemSource(ast, problemSourceCode);
      GenerateSolvers(ast, problemSourceCode);

      // substituted last because generated fragments (e.g. terminal tree classes) also contain ?IDENT?
      problemSourceCode
        .Replace("?PROBLEMNAME?", ast.Name)
        .Replace("?IDENT?", ast.Name);

      // write the source file to disk
      using (var stream = new StreamWriter(ast.Name + ".cs")) {
        stream.WriteLine(problemSourceCode.ToString());
      }
    }

    // appends the problem template to problemSourceCode and fills in all problem-specific placeholders
    // (the substitution order is kept stable because user code could itself contain placeholder text)
    private void GenerateProblemSource(GPDefNode ast, SourceBuilder problemSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      problemSourceCode
        .AppendLine(problemTemplate)
        .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
        .Replace("?INITSOURCE?", ast.InitCodeNode.SrcCode)
        .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
        .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
        .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
        .Replace("?TERMINALNODECLASSDEFINITIONS?", GenerateTerminalNodeClassDefinitions(ast.Terminals.OfType<TerminalNode>()))
       ;
    }

    // appends the solver source (currently only random search) to solverSourceCode
    private void GenerateSolvers(GPDefNode ast, SourceBuilder solverSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      var randomSearchCodeGen = new RandomSearchCodeGen();
      randomSearchCodeGen.Generate(grammar, ast.FitnessFunctionNode.Maximization, solverSourceCode);
      // alternative solver, currently disabled:
      //var bruteForceSearchCodeGen = new BruteForceCodeGen();
      //bruteForceSearchCodeGen.Generate(grammar, ast.FitnessFunctionNode.Maximization, solverSourceCode);
    }

    #region create grammar instance from AST
    // should be refactored so that we can directly query the AST
    // Builds an AttributedGrammar whose start symbol is the nonterminal of the first rule in the AST.
    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {

      var nonTerminals = ast.NonTerminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();
      var terminals = ast.Terminals
        .Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters)))
        .ToArray();

      if (!ast.Rules.Any())
        throw new System.InvalidOperationException("The problem definition does not contain any production rules; cannot determine the start symbol.");
      string startSymbolName = ast.Rules.First().NtSymbol;

      // create startSymbol
      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);

      // loop-invariant: the set of all known symbols is the same for every rule
      var allSymbols = nonTerminals.Concat(terminals).ToArray();

      // add all production rules
      foreach (var rule in ast.Rules) {
        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
        foreach (var alt in GetAlternatives(rule.Alternatives, allSymbols)) {
          g.AddProductionRule(ntSymbol, alt);
        }
        // local initialization code
        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
      }
      return g;
    }

    // parses a formal/actual parameter string into grammar attributes (identifier, type, ref/out kind)
    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
      return (from fieldDef in Util.ExtractParameters(formalParameters)
              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut)))
              .ToList();
    }

    // lazily converts every alternative of a rule into a grammar Sequence
    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
      foreach (var alt in altNode.Alternatives) {
        yield return GetSequence(alt.Sequence, allSymbols);
      }
    }

    // converts one alternative into a Sequence of symbol calls (with actual parameters) and semantic actions
    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
      var l = new List<ISymbol>();
      foreach (var node in sequence) {
        var callSymbolNode = node as CallSymbolNode;
        var actionNode = node as RuleActionNode;
        if (callSymbolNode != null) {
          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
          // create a new symbol with actual parameters
          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
        } else if (actionNode != null) {
          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
        }
      }
      return new Sequence(l);
    }
    #endregion

    #region helper methods for terminal symbols
    // produces helper methods for the attributes of all terminal nodes
    private string GenerateConstraintMethods(IEnumerable<SymbolNode> symbols) {
      var sb = new SourceBuilder();
      var terminals = symbols.OfType<TerminalNode>();
      foreach (var t in terminals) {
        GenerateConstraintMethods(t, sb);
      }
      return sb.ToString();
    }


    // generates helper methods for the attributes of a given terminal node:
    // GetMax/GetMin for range constraints, GetAllowed for set constraints
    private void GenerateConstraintMethods(TerminalNode t, SourceBuilder sb) {
      foreach (var c in t.Constraints) {
        var constraint = c; // avoid capturing the loop variable in the lambda below
        var fieldDef = t.FieldDefinitions.FirstOrDefault(d => d.Identifier == constraint.Ident);
        if (fieldDef == null)
          throw new System.InvalidOperationException(string.Format(
            "Terminal '{0}' has a constraint for '{1}', but no field with that name is declared.", t.Ident, constraint.Ident));
        var fieldType = fieldDef.Type;
        if (c.Type == ConstraintNodeType.Range) {
          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
          //sb.AppendFormat("public {0} Get{1}_{2}(ISolverState _state) {{ _state. }}", fieldType, t.Ident, c.Ident, )
        } else if (c.Type == ConstraintNodeType.Set) {
          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
        }
      }
    }
    #endregion

    // generates one tree class per terminal node
    private string GenerateTerminalNodeClassDefinitions(IEnumerable<TerminalNode> terminals) {
      var sb = new SourceBuilder();

      foreach (var terminal in terminals) {
        GenerateTerminalNodeClassDefinitions(terminal, sb);
      }
      return sb.ToString();
    }

    // generates "<Terminal>Tree : Tree" with one field per attribute; its constructor samples each constrained
    // field, either uniformly from the allowed set or uniformly from [min, max) via random.NextDouble()
    private void GenerateTerminalNodeClassDefinitions(TerminalNode terminal, SourceBuilder sb) {
      sb.AppendFormat("public class {0}Tree : Tree {{", terminal.Ident).BeginBlock();
      foreach (var att in terminal.FieldDefinitions) {
        sb.AppendFormat("{0} {1};", att.Type, att.Identifier).AppendLine();
      }
      // ?IDENT? is substituted with the problem name in Generate
      sb.AppendFormat(" public {0}Tree(Random random, ?IDENT?Problem problem) : base() {{", terminal.Ident).BeginBlock();
      foreach (var constr in terminal.Constraints) {
        if (constr.Type == ConstraintNodeType.Set) {
          sb.AppendLine("{").BeginBlock();
          sb.AppendFormat(" var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
          sb.AppendFormat("{0} = elements[random.Next(elements.Length)]; ", constr.Ident).AppendLine();
          sb.AppendLine("}").EndBlock();
        } else {
          sb.AppendLine("{").BeginBlock();
          sb.AppendFormat(" var max = problem.GetMax{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
          sb.AppendFormat(" var min = problem.GetMin{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
          // statement must be terminated, otherwise the generated class does not compile
          sb.AppendFormat("{0} = random.NextDouble() * (max - min) + min;", constr.Ident).AppendLine();
          sb.AppendLine("}").EndBlock();
        }
      }
      sb.AppendLine("}").EndBlock();
      sb.AppendLine("}").EndBlock();
    }

    // generates the entry method for the start symbol plus one interpreter method per (non)terminal symbol
    private string GenerateInterpreterSource(AttributedGrammar grammar) {
      var sb = new SourceBuilder();
      GenerateInterpreterStart(grammar, sb);

      // generate methods for all nonterminals and terminals using the grammar instance
      foreach (var s in grammar.NonTerminalSymbols) {
        GenerateInterpreterMethod(grammar, s, sb);
      }
      foreach (var s in grammar.TerminalSymbols) {
        GenerateTerminalInterpreterMethod(s, sb);
      }
      return sb.ToString();
    }

    // generates the parameterless-state entry method for the start symbol (callable from the fitness function);
    // it resets the solver state and forwards to the interpreter method of the start symbol
    private void GenerateInterpreterStart(AttributedGrammar grammar, SourceBuilder sb) {
      var s = grammar.StartSymbol;
      // create the method which can be called from the fitness function
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}() {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // actual parameters are the formal parameters without the type identifier (ref/out modifier + name)
      string actualParameter = string.Join(", ", s.Attributes.Select(a => a.AttributeType + " " + a.Name));

      sb.AppendLine("_state.Reset();");
      // without attributes only _state is passed (avoids generating the invalid call "S(_state, );")
      if (actualParameter.Length == 0)
        sb.AppendFormat("{0}(_state);", s.Name).AppendLine();
      else
        sb.AppendFormat("{0}(_state, {1});", s.Name, actualParameter).AppendLine();
      sb.AppendLine("}").EndBlock();
    }

    // generates the interpreter method for a nonterminal symbol: local definitions followed by either a switch
    // over the alternatives (selected via _state.PeekNextAlternative()) or the single alternative inline
    private void GenerateInterpreterMethod(AttributedGrammar g, ISymbol s, SourceBuilder sb) {
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(ISolverState _state) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(ISolverState _state, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // generate local definitions
      sb.AppendLine(g.GetLocalDefinitions(s));

      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();

      if (altsWithSemActions.Length == 0)
        throw new System.InvalidOperationException(string.Format("Nonterminal '{0}' has no alternatives.", s.Name));

      if (altsWithSemActions.Length > 1) {
        GenerateSwitchStatement(altsWithSemActions, sb);
      } else {
        // child index counts only real symbols; semantic actions do not occupy a subtree slot
        int i = 0;
        foreach (var altSymb in altsWithSemActions[0]) {
          GenerateSourceForAction(i, altSymb, sb);
          if (!(altSymb is SemanticSymbol)) i++;
        }
      }
      sb.Append("}").EndBlock();
    }

    // generates "switch(_state.PeekNextAlternative())" with one case per alternative
    private void GenerateSwitchStatement(IEnumerable<Sequence> alts, SourceBuilder sb) {
      sb.Append("switch(_state.PeekNextAlternative()) {").BeginBlock();
      // generate a case for each alternative
      int i = 0;
      foreach (var alt in alts) {
        sb.AppendFormat("case {0}: {{ ", i).BeginBlock();

        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols)!
        // a way to handle this is through grammar transformation (the exemplary grammars all have the correct form)
        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
        foreach (var altSymb in alt) {
          GenerateSourceForAction(0, altSymb, sb); // index is always 0 because of the assertion above
        }
        i++;
        sb.AppendLine("break;").Append("}").EndBlock();
      }
      sb.AppendLine("default: throw new System.InvalidOperationException();").Append("}").EndBlock();
    }

    // helper for generating calls to other symbol methods: semantic actions are emitted verbatim,
    // symbol calls are wrapped in _state.Follow(idx) / _state.Unwind()
    private void GenerateSourceForAction(int idx, ISymbol s, SourceBuilder sb) {
      var action = s as SemanticSymbol;
      if (action != null)
        sb.Append(action.Code + ";");
      else if (!s.Attributes.Any())
        sb.AppendFormat("_state.Follow({0}); {1}(_state); _state.Unwind();", idx, s.Name);
      else sb.AppendFormat("_state.Follow({0}); {1}(_state, {2}); _state.Unwind();", idx, s.Name, s.GetAttributeString());
      sb.AppendLine();
    }

    // generates the interpreter method for a terminal symbol, assigning a value to each attribute
    // NOTE(review): the emitted calls Get<Terminal>_<attr>(_state) are not produced by GenerateConstraintMethods
    // (the corresponding line there is commented out) - work in progress, generated code will not compile yet
    private void GenerateTerminalInterpreterMethod(ISymbol s, SourceBuilder sb) {
      // if the terminal symbol has attributes then we must sample values for these attributes
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(ISolverState _state) {{", s.Name).BeginBlock();
      else
        sb.AppendFormat("private void {0}(ISolverState _state, {1}) {{", s.Name, s.GetAttributeString()).BeginBlock();

      // each field must match a formal parameter, assign a value for each parameter
      int i = 0;
      foreach (var element in s.Attributes) {
        sb.AppendFormat("_state.Follow({0});", i++).AppendLine();
        sb.AppendFormat("{0} = Get{1}_{0}(_state);", element.Name, s.Name).AppendLine();
        sb.AppendFormat("_state.Unwind();").AppendLine();
      }
      sb.Append("}").EndBlock();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.