Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs @ 10384

Last change on this file since 10384 was 10384, checked in by gkronber, 10 years ago

#2026 fixed random search code generation

File size: 13.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
namespace CodeGenerator {
  /// <summary>
  /// Generates the C# source code of a stand-alone random-search solver for a grammar.
  /// The generated solver endlessly samples random syntax trees, evaluates each one with the
  /// generated problem class, prints every improvement of the best quality found so far and
  /// prints aggregated statistics every 1000 samples.
  /// </summary>
  /// <remarks>
  /// Every grammar symbol is mapped to a numeric state: the position of the symbol in
  /// <c>grammar.Symbols</c>. The start symbol must be the first symbol (state 0) because the
  /// generated sampler always starts in state 0. Non-terminals must either have several
  /// alternatives that each consist of a single symbol, or exactly one alternative (a sequence).
  /// The placeholders <c>?PROBLEMNAME?</c> and <c>?IDENT?</c> are not replaced here; presumably
  /// the caller substitutes them in the combined problem source (TODO confirm).
  /// </remarks>
  public class RandomSearchCodeGen {

    private const string SolverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Solver {
    private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
    private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%

    // state -> target states (one per alternative, or one per symbol of a sequence)
    private static readonly Dictionary<int, int[]> transition = new Dictionary<int, int[]>() {
?TRANSITIONTABLE?
    };
    // state -> number of sub-trees of a node in this state (0 for terminals)
    private static readonly Dictionary<int, int> subtreeCount = new Dictionary<int, int>() {
       { -1, 0 }, // terminals
?SUBTREECOUNTTABLE?
    };
    private static readonly string[] symb = new string[] { ?SYMBOLNAMES? };

    public sealed class SolverState {
      public int curDepth; // depth of the node that is currently sampled (the root has depth 1)
      public int steps;    // number of nodes sampled so far (= size of the tree)
      public int depth;    // maximum depth reached so far
      private readonly ?IDENT?Problem problem;
      private readonly Random random;

      public SolverState(?IDENT?Problem problem, int seed) {
        this.problem = problem;
        this.random = new Random(seed);
      }

      public Tree SampleTree() {
        return SampleTree(0);
      }

      private Tree SampleTree(int state) {
        curDepth += 1;
        steps += 1;
        depth = Math.Max(depth, curDepth);
        Tree t = null;

        // terminals
        if(subtreeCount[state] == 0) {
          t = CreateTerminalNode(state);
        } else {
          t = new Tree(state, -1, subtreeCount[state]);
          if(subtreeCount[state] < 1) throw new ArgumentException();
          // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
          if(subtreeCount[state] == 1) {
            var targetStates = transition[state];
            var i = SampleAlternative(random, state);
            t.altIdx = i;
            t.subtrees.Add(SampleTree(targetStates[i]));
          } else {
            // if the symbol contains only one sequence we must create sub-trees for each symbol in the sequence
            for(int i = 0; i < subtreeCount[state]; i++) {
              t.subtrees.Add(SampleTree(transition[state][i]));
            }
          }
        }
        curDepth -= 1;
        return t;
      }

      private Tree CreateTerminalNode(int state) {
        switch(state) {
          ?CREATETERMINALNODECODE?
          default: { throw new ArgumentException(""Unknown state index "" + state); }
        }
      }

      private int SampleAlternative(Random random, int state) {
        switch(state) {

?SAMPLEALTERNATIVECODE?

          default: throw new InvalidOperationException();
        }
      }

      // the probability to choose a terminal alternative grows linearly with the depth in the tree
      private double TerminalProbForDepth(int depth) {
        return baseTerminalProbability + depth * terminalProbabilityInc;
      }
    }

    public static void Main(string[] args) {
      if(args.Length >= 1) ParseArguments(args);

      var problem = new ?IDENT?Problem();
      var solver = new ?IDENT?Solver(problem);
      solver.Start();
    }

    private static void ParseArguments(string[] args) {
      var baseTerminalProbabilityRegex = new System.Text.RegularExpressions.Regex(@""--terminalProbBase=(?<prob>.+)"");
      var terminalProbabilityIncRegex = new System.Text.RegularExpressions.Regex(@""--terminalProbInc=(?<prob>.+)"");
      var helpRegex = new System.Text.RegularExpressions.Regex(@""--help|/\?"");

      foreach(var arg in args) {
        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
        else if(baseTerminalProbabilityMatch.Success) {
          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
        } else if(terminalProbabilityIncMatch.Success) {
          terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
        } else {
          // an unknown switch is an error -> signal failure with a non-zero exit code
          Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(1);
        }
      }
    }

    private static void PrintUsage() {
      Console.WriteLine(""Find a solution using random tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
    }


    private readonly ?IDENT?Problem problem;
    public ?IDENT?Solver(?IDENT?Problem problem) {
      this.problem = problem;
    }

    private void Start() {
      var seedRandom = new Random();
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (true) {

        var _state = new SolverState(problem, seedRandom.Next());
        var _t = _state.SampleTree();
        var f = problem.Evaluate(_t);

        n++;
        sumSize += _state.steps;
        sumDepth += _state.depth;
        sumF += f;
        if (IsBetter(f, bestF)) {
          bestF = f;
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, _state.steps, _state.depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sw.Reset();
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Start();
        }
      }
    }

    private bool IsBetter(double a, double b) {
      return ?MAXIMIZATION? ? a > b : a < b;
    }
  }
}";

    /// <summary>
    /// Instantiates the solver template for <paramref name="grammar"/> and appends the resulting
    /// solver source to <paramref name="problemSourceCode"/>.
    /// </summary>
    /// <param name="grammar">Grammar whose symbols become the sampler states; the start symbol must come first.</param>
    /// <param name="maximization">true if larger evaluation results are better, false if smaller ones are.</param>
    /// <param name="problemSourceCode">Builder that receives the generated solver source.</param>
    public void Generate(IGrammar grammar, bool maximization, SourceBuilder problemSourceCode) {
      var solverSourceCode = new SourceBuilder();
      solverSourceCode.Append(SolverTemplate)
        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
        .Replace("?SYMBOLNAMES?", grammar.Symbols.Select(s => s.Name).Aggregate(string.Empty, (str, symb) => str + "\"" + symb + "\", "))
        .Replace("?TRANSITIONTABLE?", GenerateTransitionTable(grammar))
        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar))
        .Replace("?SUBTREECOUNTTABLE?", GenerateSubtreeCountTable(grammar))
        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar));

      problemSourceCode.Append(solverSourceCode.ToString());
    }

    /// <summary>
    /// Generates one <c>case</c> per terminal symbol for the generated <c>CreateTerminalNode</c> switch;
    /// each case constructs the symbol-specific <c>&lt;Name&gt;Tree</c> from the solver's random and problem.
    /// </summary>
    private string GenerateCreateTerminalCode(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols.Where(sym => grammar.IsTerminal(sym))) {
        sb.AppendFormat("case {0}: {{ return new {1}Tree(random, problem); }}", allSymbols.IndexOf(s), s.Name).AppendLine();
      }
      return sb.ToString();
    }

    /// <summary>
    /// Generates the entries of the generated <c>transition</c> dictionary (state index -> target states).
    /// For a symbol with several alternatives the targets are the (single) symbols of the alternatives;
    /// for a sequence rule they are the symbols of the sequence; terminals get an empty array.
    /// </summary>
    private string GenerateTransitionTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        var targetStates = new List<int>();
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            foreach (var alt in grammar.GetAlternatives(s)) {
              // only single-symbol alternatives are supported
              Debug.Assert(alt.Count() == 1);
              targetStates.Add(allSymbols.IndexOf(alt.Single()));
            }
          } else {
            // rule is a sequence of symbols
            var seq = grammar.GetAlternatives(s).Single();
            targetStates.AddRange(seq.Select(symb => allSymbols.IndexOf(symb)));
          }
        }

        var targetStateString = targetStates.Aggregate(string.Empty, (str, state) => str + state + ", ");
        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , new int[] {{ {1} }} }},", allSymbols.IndexOf(s), targetStateString).AppendLine();
      }
      return sb.ToString();
    }

    /// <summary>
    /// Generates the entries of the generated <c>subtreeCount</c> dictionary (state index -> number of sub-trees).
    /// Terminals have 0 sub-trees, symbols with several alternatives have 1 (the chosen alternative)
    /// and a sequence rule has one sub-tree per symbol of the sequence.
    /// </summary>
    private string GenerateSubtreeCountTable(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();

      // state idx = idx of the corresponding symbol in the grammar
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        int subtreeCount = 0;
        if (!grammar.IsTerminal(s)) {
          if (grammar.NumberOfAlternatives(s) > 1) {
            Debug.Assert(grammar.GetAlternatives(s).All(alt => alt.Count() == 1));
            subtreeCount = 1;
          } else {
            subtreeCount = grammar.GetAlternative(s, 0).Count();
          }
        }

        sb.AppendFormat("// {0}", s).AppendLine();
        sb.AppendFormat("{{ {0} , {1} }},", allSymbols.IndexOf(s), subtreeCount).AppendLine();
      }
      return sb.ToString();
    }

    /// <summary>
    /// Generates the <c>case</c> bodies of the generated <c>SampleAlternative</c> switch, one per non-terminal.
    /// If a rule has both alternatives made only of terminals and alternatives containing non-terminals,
    /// the generated code picks the terminal group with probability <c>TerminalProbForDepth(curDepth)</c>
    /// and then an alternative uniformly within the chosen group; otherwise it picks uniformly among all alternatives.
    /// </summary>
    private string GenerateSampleAlternativeSource(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      int stateCount = 0;
      foreach (var s in grammar.Symbols) {
        int state = stateCount++;
        // Terminals have no alternatives and must never reach SampleAlternative. Emitting no label lets such a
        // call hit 'default: throw' instead of falling through to the next symbol's case (stacked labels).
        if (grammar.IsTerminal(s)) continue;

        var alternatives = grammar.GetAlternatives(s).ToList();
        var terminalAltIndexes = new List<int>();
        var nonTerminalAltIndexes = new List<int>();
        for (int idx = 0; idx < alternatives.Count; idx++) {
          if (alternatives[idx].All(symb => grammar.IsTerminal(symb))) terminalAltIndexes.Add(idx);
          if (alternatives[idx].Any(symb => grammar.IsNonTerminal(symb))) nonTerminalAltIndexes.Add(idx);
        }

        sb.AppendFormat("case {0}: ", state);
        if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
          // use the depth of the current node (not the maximum depth reached so far in the whole tree)
          sb.Append("if(random.NextDouble() < TerminalProbForDepth(curDepth)) {").BeginBlock();
          GenerateReturnStatement(terminalAltIndexes, sb);
          sb.Append("} else {");
          GenerateReturnStatement(nonTerminalAltIndexes, sb);
          sb.Append("}").EndBlock();
        } else {
          GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
        }
      }
      return sb.ToString();
    }

    /// <summary>
    /// Emits a statement returning one of <paramref name="idxs"/> (uniformly at random when there are several).
    /// </summary>
    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
      var idxList = idxs.ToList(); // enumerate the (possibly lazy) sequence only once
      if (idxList.Count == 1) {
        sb.AppendFormat("return {0};", idxList[0]).AppendLine();
      } else {
        var idxStr = idxList.Aggregate(string.Empty, (str, idx) => str + idx + ", ");
        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, idxList.Count).AppendLine();
      }
    }

    /// <summary>
    /// Emits a statement returning a uniformly random alternative index in [0, <paramref name="nAlts"/>);
    /// a rule without alternatives yields a throw statement.
    /// </summary>
    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
      if (nAlts > 1) {
        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
      } else if (nAlts == 1) {
        sb.AppendLine("return 0; ");
      } else {
        sb.AppendLine("throw new InvalidProgramException();");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.