Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs @ 10086

Last change on this file since 10086 was 10086, checked in by gkronber, 10 years ago

#2026 worked on random search solver (now all examples are working)

File size: 14.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.IO;
5using System.Linq;
6using System.Text;
7using HeuristicLab.Grammars;
8using Attribute = HeuristicLab.Grammars.Attribute;
9
namespace CodeGenerator {
  /// <summary>
  /// Translates a parsed GPDL problem definition into a stand-alone C# program.
  /// The program repeatedly samples random derivations of the problem grammar,
  /// evaluates each one with the problem's fitness function and prints every
  /// improvement (plus a throughput line every 1000 samples) to the console.
  /// </summary>
  public class RandomSearchCodeGen {

    // using directives emitted at the top of every generated source file
    private readonly string usings = @"
using System.Collections.Generic;
using System.Linq;
using System;
";

    // Skeleton of the generated solver. The ?PLACEHOLDER? tokens are substituted in
    // GenerateProblem (and ?PROBLEMNAME? in Generate). SolverState.Prepare() re-creates the
    // PRNG from a fixed per-sample seed, so repeated calls of the start symbol within one
    // evaluation of the fitness function follow the same path through the grammar.
    private readonly string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Solver {
    public sealed class SolverState {
      private int currentSeed;
      public Random random;
      public int depth;
      public int steps;
      public int maxDepth;
      public SolverState(int seed) {
        this.currentSeed = seed;
      }

      public SolverState IncDepth() {
        depth++;
        if(depth>maxDepth) maxDepth = depth;
        return this;
      }
      public SolverState Prepare() {
        this.random = new Random(currentSeed);
        depth = 0;
        maxDepth = 0;
        steps = 0;
        return this;
      }
    }

    public static void Main(string[] args) {
      var solver = new ?IDENT?Solver();
      solver.Start();
    }

    public ?IDENT?Solver() {
      Initialize();
    }   

    private void Initialize() {
      ?INITCODE?
    }

    private bool IsBetter(double a, double b) {
      return ?MAXIMIZATION? ? a > b : a < b;
    }

    private double TerminalProbForDepth(int depth) {
      const double baseProb = 0.05;  // 5% of all samples are only a terminal node
      const double probIncPerLevel = 0.05; // for each level the probability to sample a terminal grows by 5%
      return baseProb + depth * probIncPerLevel;
    }

    private SolverState _state;
    private void Start() {
      var seedRandom = new Random();
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (true) {

        // must make sure that calling the start-symbol multiple times in the fitness function always leads to the same path through the grammar
        // so we use a PRNG for generating seeds for a separate PRNG that is reset each time the start symbol is called
       
        _state = new SolverState(seedRandom.Next());

        var f = Calculate();

        n++;
        if (IsBetter(f, bestF)) {
          bestF = f;
          Console.WriteLine(""{0}\t{1}\t(depth={2})"", n, bestF, _state.maxDepth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\t{1}\t{2}\t({3:0.00} sols/ms)"", n, bestF, f, 1000.0 / sw.ElapsedMilliseconds);
          sw.Reset();
          sw.Start();
        }
      }
    }

    public double Calculate() {
      ?FITNESSFUNCTION?
    }

    ?ADDITIONALCODE?

    ?INTERPRETERSOURCE?

    ?CONSTRAINTSSOURCE?
  }
}";


    /// <summary>
    /// Generates the source code for a random search solver that can be compiled with a C# compiler
    /// and writes it to the file "&lt;problem name&gt;.cs" (a relative path).
    /// </summary>
    /// <param name="ast">An abstract syntax tree for a GPDL file</param>
    public void Generate(GPDefNode ast) {
      var problemSourceCode = new StringBuilder();
      problemSourceCode.AppendLine(usings);

      GenerateProblem(ast, problemSourceCode);

      // the namespace placeholder is replaced last, over the whole generated text
      problemSourceCode.Replace("?PROBLEMNAME?", ast.Name);

      // write the source file to disk
      using (var stream = new StreamWriter(ast.Name + ".cs")) {
        stream.WriteLine(problemSourceCode.ToString());
      }
    }

    // Fills the solver template with the problem-specific code fragments from the AST and the
    // methods generated from the grammar, then appends the result to problemSourceCode.
    private void GenerateProblem(GPDefNode ast, StringBuilder problemSourceCode) {
      var grammar = CreateGrammarFromAst(ast);
      var problemClassCode =
        solverTemplate
          .Replace("?MAXIMIZATION?", ast.FitnessFunctionNode.Maximization.ToString().ToLowerInvariant())
          .Replace("?IDENT?", ast.Name)
          .Replace("?FITNESSFUNCTION?", ast.FitnessFunctionNode.SrcCode)
          .Replace("?INTERPRETERSOURCE?", GenerateInterpreterSource(grammar))
          .Replace("?INITCODE?", ast.InitCodeNode.SrcCode)
          .Replace("?ADDITIONALCODE?", ast.ClassCodeNode.SrcCode)
          .Replace("?CONSTRAINTSSOURCE?", GenerateConstraintMethods(ast.Terminals))
          ;

      problemSourceCode.AppendLine(problemClassCode).AppendLine();
    }

    #region create grammar instance from AST
    // Builds an AttributedGrammar from the AST. The start symbol is the non-terminal on the
    // left-hand side of the first rule in the AST.
    private AttributedGrammar CreateGrammarFromAst(GPDefNode ast) {

      var nonTerminals = ast.NonTerminals.Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters))).ToArray();
      var terminals = ast.Terminals.Select(t => new Symbol(t.Ident, GetSymbolAttributes(t.FormalParameters))).ToArray();
      string startSymbolName = ast.Rules.First().NtSymbol;

      // create startSymbol
      var startSymbol = nonTerminals.Single(s => s.Name == startSymbolName);
      var g = new AttributedGrammar(startSymbol, nonTerminals, terminals);

      // add all production rules
      foreach (var rule in ast.Rules) {
        var ntSymbol = nonTerminals.Single(s => s.Name == rule.NtSymbol);
        foreach (var alt in GetAlternatives(rule.Alternatives, nonTerminals.Concat(terminals))) {
          g.AddProductionRule(ntSymbol, alt);
        }
        // local initialization code
        if (!string.IsNullOrEmpty(rule.LocalCode)) g.AddLocalDefinitions(ntSymbol, rule.LocalCode);
      }
      return g;
    }

    // Converts a formal/actual parameter string into a materialized list of grammar attributes.
    private IEnumerable<IAttribute> GetSymbolAttributes(string formalParameters) {
      return (from fieldDef in Util.ExtractParameters(formalParameters)
              select new Attribute(fieldDef.Identifier, fieldDef.Type, AttributeType.Parse(fieldDef.RefOrOut))).
        ToList();
    }

    // Lazily converts every alternative of a rule into a grammar Sequence.
    private IEnumerable<Sequence> GetAlternatives(AlternativesNode altNode, IEnumerable<ISymbol> allSymbols) {
      foreach (var alt in altNode.Alternatives) {
        yield return GetSequence(alt.Sequence, allSymbols);
      }
    }

    // Converts one alternative into a Sequence: symbol calls become Symbols carrying their actual
    // parameters, semantic actions become SemanticSymbols named "SEM".
    private Sequence GetSequence(IEnumerable<RuleExprNode> sequence, IEnumerable<ISymbol> allSymbols) {
      Debug.Assert(sequence.All(s => s is CallSymbolNode || s is RuleActionNode));
      var l = new List<ISymbol>();
      foreach (var node in sequence) {
        var callSymbolNode = node as CallSymbolNode;
        var actionNode = node as RuleActionNode;
        if (callSymbolNode != null) {
          Debug.Assert(allSymbols.Any(s => s.Name == callSymbolNode.Ident));
          // create a new symbol with actual parameters
          l.Add(new Symbol(callSymbolNode.Ident, GetSymbolAttributes(callSymbolNode.ActualParameter)));
        } else if (actionNode != null) {
          l.Add(new SemanticSymbol("SEM", actionNode.SrcCode));
        }
      }
      return new Sequence(l);
    }
    #endregion

    #region helper methods for terminal symbols
    // produces helper methods for the attributes of all terminal nodes
    private string GenerateConstraintMethods(List<SymbolNode> symbols) {
      var sb = new StringBuilder();
      var terminals = symbols.OfType<TerminalNode>();
      foreach (var t in terminals) {
        sb.AppendLine(GenerateConstraintMethods(t));
      }
      return sb.ToString();
    }

    // Generates helper methods for the constrained attributes of a given terminal node:
    //  - Range constraints: GetMax/GetMin return the bound expressions, GetRandom draws
    //    NextDouble() * (max - min) + min.
    //  - Set constraints: GetAllowed returns the set expression, GetRandom picks one element
    //    uniformly at random.
    private string GenerateConstraintMethods(TerminalNode t) {
      var sb = new StringBuilder();
      foreach (var c in t.Constraints) {
        var fieldType = t.FieldDefinitions.First(d => d.Identifier == c.Ident).Type;
        if (c.Type == ConstraintNodeType.Range) {
          sb.AppendFormat("public {0} GetMax{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMaxExpression).AppendLine();
          sb.AppendFormat("public {0} GetMin{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.RangeMinExpression).AppendLine();
          // NOTE(review): the sampling expression is a double, so the generated method only
          // compiles when fieldType is double — confirm whether other range types must be supported
          sb.AppendFormat("public {0} GetRandom{1}_{2}(SolverState _state) {{ return _state.random.NextDouble() * (GetMax{1}_{2}() - GetMin{1}_{2}()) + GetMin{1}_{2}(); }}", fieldType, t.Ident, c.Ident).AppendLine();
        } else if (c.Type == ConstraintNodeType.Set) {
          sb.AppendFormat("public IEnumerable<{0}> GetAllowed{1}_{2}() {{ return {3}; }}", fieldType, t.Ident, c.Ident, c.SetExpression).AppendLine();
          sb.AppendFormat("public {0} GetRandom{1}_{2}(SolverState _state) {{ var tmp = GetAllowed{1}_{2}().ToArray(); return tmp[_state.random.Next(tmp.Length)]; }}", fieldType, t.Ident, c.Ident).AppendLine();
        }
      }
      return sb.ToString();
    }
    #endregion

    // Generates one interpreter method per non-terminal and per terminal symbol of the grammar.
    private string GenerateInterpreterSource(AttributedGrammar grammar) {
      var sb = new StringBuilder();

      // generate methods for all nonterminals and terminals using the grammar instance
      foreach (var s in grammar.NonTerminalSymbols) {
        sb.AppendLine(GenerateInterpreterMethod(grammar, s));
      }
      foreach (var s in grammar.TerminalSymbols) {
        sb.AppendLine(GenerateTerminalInterpreterMethod(s));
      }
      return sb.ToString();
    }

    // Generates the interpreter method for the non-terminal s. The method randomly selects one of
    // s's alternatives and emits its symbol calls and semantic actions in order. For the start
    // symbol an additional entry method without the SolverState parameter is emitted. That method
    // can be called from the fitness function, and it calls _state.Prepare() before descending.
    // Throws InvalidOperationException if s has no alternatives.
    private string GenerateInterpreterMethod(AttributedGrammar g, ISymbol s) {
      var sb = new StringBuilder();

      // if this is the start symbol we additionally have to create the method which can be called from the fitness function
      if (g.StartSymbol.Equals(s)) {
        if (!s.Attributes.Any())
          sb.AppendFormat("private void {0}() {{", s.Name);
        else
          sb.AppendFormat("private void {0}({1}) {{", s.Name, s.GetAttributeString());

        // The actual parameters are the formal parameters of the start symbol with the type
        // removed and the ref/out modifier kept. _state.Prepare() is always the first argument.
        // Joining the list avoids emitting "S(_state.Prepare(), );" (invalid C#) when the start
        // symbol has no attributes.
        var actualParameters = new[] { "_state.Prepare()" }
          .Concat(g.StartSymbol.Attributes.Select(a => a.AttributeType + " " + a.Name));
        sb.AppendFormat("{0}({1});", g.StartSymbol.Name, string.Join(", ", actualParameters)).AppendLine();
        sb.AppendLine("}");
      }

      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(SolverState _state) {{", s.Name);
      else
        sb.AppendFormat("private void {0}(SolverState _state, {1}) {{", s.Name, s.GetAttributeString());

      // generate local definitions
      sb.AppendLine(g.GetLocalDefinitions(s));

      var altsWithSemActions = g.GetAlternativesWithSemanticActions(s).ToArray();
      if (altsWithSemActions.Length == 0)
        throw new InvalidOperationException(string.Format("The non-terminal symbol {0} has no alternatives.", s.Name));

      if (altsWithSemActions.Length == 1) {
        // a single alternative needs no random choice: emit its calls/actions in sequence
        foreach (var altSymb in altsWithSemActions[0]) {
          sb.AppendLine(GenerateSourceForAction(altSymb));
        }
      } else {
        // materialized once; each collection is tested and enumerated several times below
        var terminalAlts = altsWithSemActions.Where(alt => alt.Count(g.IsNonTerminal) == 0).ToArray();
        var nonTerminalAlts = altsWithSemActions.Where(alt => alt.Count(g.IsNonTerminal) > 0).ToArray();
        bool hasTerminalAlts = terminalAlts.Any();
        bool hasNonTerminalAlts = nonTerminalAlts.Any();

        // Here we need to bias the selection of alternatives (terminal vs non-terminal alternatives),
        // so that terminal alternatives are chosen with a controlled probability. This avoids:
        // 1) creating the same small trees all the time (terminal probability too high), and
        // 2) creating very big trees by recursing too deep, which leads to a stack overflow
        //    (terminal probability too low).
        // First we decide between the terminal and the non-terminal class using
        // TerminalProbForDepth(depth), which grows linearly with the current depth.
        // Then we choose an alternative within that class uniformly at random.
        if (hasTerminalAlts && hasNonTerminalAlts) {
          sb.AppendLine("if(_state.random.NextDouble() < TerminalProbForDepth(_state.depth)) {");
          // terminals
          sb.AppendLine("// terminals ");
          GenerateSwitchStatement(sb, terminalAlts);
          sb.AppendLine("} else {");
          // non-terminals
          sb.AppendLine("// non-terminals ");
          GenerateSwitchStatement(sb, nonTerminalAlts);
          sb.AppendLine("}");
        } else if (hasTerminalAlts) {
          sb.AppendLine("// terminals ");
          GenerateSwitchStatement(sb, terminalAlts);
        } else if (hasNonTerminalAlts) {
          sb.AppendLine("// non-terminals ");
          GenerateSwitchStatement(sb, nonTerminalAlts);
        }
      }
      sb.AppendLine("}");
      return sb.ToString();
    }

    // Emits a switch over _state.random.Next(count) with one case per alternative. The default
    // case throws, although it is unreachable because Next(count) returns a value in [0, count).
    private void GenerateSwitchStatement(StringBuilder sb, IEnumerable<Sequence> alts) {
      sb.AppendFormat("switch(_state.random.Next({0})) {{", alts.Count());
      // generate a case for each alternative
      int i = 0;
      foreach (var alt in alts) {
        sb.AppendFormat("case {0}: {{ ", i).AppendLine();

        // this only works for alternatives with a single non-terminal symbol (ignoring semantic symbols) so far!
        // a way to handle this is through grammar transformation (the exemplary grammars all have the correct form)
        Debug.Assert(alt.Count(symb => !(symb is SemanticSymbol)) == 1);
        foreach (var altSymb in alt) {
          sb.AppendLine(GenerateSourceForAction(altSymb));
        }
        i++;
        sb.AppendLine("break;").AppendLine("}");
      }
      sb.AppendLine("default: throw new System.InvalidOperationException();").AppendLine("}");

    }

    // Helper for generating a single statement of an alternative. A semantic action is emitted
    // verbatim followed by ";". A symbol becomes a call that increments the depth on entry
    // (IncDepth) and decrements it again after the call returns.
    private string GenerateSourceForAction(ISymbol s) {
      var action = s as SemanticSymbol;
      if (action != null) {
        return action.Code + ";";
      } else {
        if (!s.Attributes.Any())
          return string.Format("{0}(_state.IncDepth()); _state.depth--;", s.Name);
        else
          return string.Format("{0}(_state.IncDepth(), {1}); _state.depth--;", s.Name, s.GetAttributeString());
      }
    }

    // Generates the interpreter method for a terminal symbol. Each attribute is assigned a value
    // drawn by the matching GetRandom<terminal>_<attribute>(_state) helper produced by
    // GenerateConstraintMethods. This assumes every attribute of the terminal has a constraint.
    private string GenerateTerminalInterpreterMethod(ISymbol s) {
      var sb = new StringBuilder();
      // if the terminal symbol has attributes then we must sample values for these attributes
      if (!s.Attributes.Any())
        sb.AppendFormat("private void {0}(SolverState _state) {{", s.Name);
      else
        sb.AppendFormat("private void {0}(SolverState _state, {1}) {{", s.Name, s.GetAttributeString());

      // each field must match a formal parameter, assign a value for each parameter
      foreach (var element in s.Attributes) {
        sb.AppendFormat("{{ {0} = GetRandom{1}_{0}(_state); }} ", element.Name, s.Name);
      }
      sb.AppendLine("}");
      return sb.ToString();
    }
  }
}
Note: See TracBrowser for help on using the repository browser.