# source:branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs@10387

Last change on this file since 10387 was 10387, checked in by gkronber, 7 years ago

#2025 reintegrated initialization of terminal values into the random search solver.

File size: 10.4 KB
Line
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class RandomSearchCodeGen {
10
// NOTE(review): the generated Start() loops while n <= 10000, i.e. 10001 samples; the last sample
//   is never part of a 1000-sample report.
// NOTE(review): the sols/ms column divides 1000.0 by sw.ElapsedMilliseconds, which can be 0 for
//   very fast evaluations and then prints Infinity.
// NOTE(review): an unrecognized command-line switch prints usage and exits with code 0, the same
//   exit code as --help.
// NOTE(review): SampleTree increments curDepth before it calls SampleAlternative, so the root
//   symbol already samples with TerminalProbForDepth(1) = base + inc, not with the base
//   probability alone as the inline "5% of all samples" comment suggests.
/// <summary>
/// Verbatim C# source template of the generated random-search solver class "?IDENT?Solver".
/// Inside the @"..." literal every double quote of the generated code is written as "".
/// Generate substitutes ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE?;
/// ?PROBLEMNAME? and ?IDENT? are not replaced there and are presumably substituted later on the
/// combined problem source (TODO confirm with the caller of Generate).
/// The generated class uses instance fields problem and random (assigned in its constructor) and
/// names declared outside this text (Tree, Grammar.subtreeCount, Grammar.transition,
/// ?IDENT?Problem, Regex); these are assumed to be provided by the surrounding generated source.
/// </summary>
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public sealed class ?IDENT?Solver {
14    private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
15    private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
16
19
20    public Tree SampleTree(out int steps, out int depth) {
21      steps = 0;
22      depth = 0;
23      int curDepth = 0;
24      return SampleTree(0, ref steps, ref curDepth, ref depth);
25    }
26
27    private Tree SampleTree(int state, ref int steps, ref int curDepth, ref int depth) {
28      curDepth += 1;
29      steps += 1;
30      depth = Math.Max(depth, curDepth);
31      Tree t = null;
32
33      // terminals
34      if(Grammar.subtreeCount[state] == 0) {
35        t = CreateTerminalNode(state, random, problem);
36      } else {
37        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
38        if(Grammar.subtreeCount[state] == 1) {
39          var targetStates = Grammar.transition[state];
40          var altIdx = SampleAlternative(random, state, curDepth);
41          var alternative = SampleTree(targetStates[altIdx], ref steps, ref curDepth, ref depth);
42          t = new Tree(altIdx, new Tree[] { alternative });
43        } else {
44          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
45          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
46          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
47            subtrees[i] = SampleTree(Grammar.transition[state][i], ref steps, ref curDepth, ref depth);
48          }
49          t = new Tree(-1, subtrees); // alternative index is ignored
50        }
51      }
52      curDepth -=1;
53      return t;
54    }
55
56    public static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
57      switch(state) {
58        ?CREATETERMINALNODECODE?
59        default: { throw new ArgumentException(""Unknown state index"" + state); }
60      }
61    }
62
63    private int SampleAlternative(Random random, int state, int depth) {
64      switch(state) {
65
66?SAMPLEALTERNATIVECODE?
67
68        default: throw new InvalidOperationException();
69      }
70    }
71
72    private double TerminalProbForDepth(int depth) {
73      return baseTerminalProbability + depth * terminalProbabilityInc;
74    }
75
76    public static void Main(string[] args) {
77      if(args.Length >= 1) ParseArguments(args);
78
79      var problem = new ?IDENT?Problem();
80      var solver = new ?IDENT?Solver(problem);
81      solver.Start();
82    }
83    private static void ParseArguments(string[] args) {
84      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
85      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
86      var helpRegex = new Regex(@""--help|/\?"");
87
88      foreach(var arg in args) {
89        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
90        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
91        var helpMatch = helpRegex.Match(arg);
92        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
93        else if(baseTerminalProbabilityMatch.Success) {
94          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
95          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
96        } else if(terminalProbabilityIncMatch.Success) {
97           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
98           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
99        } else {
100           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
101        }
102      }
103    }
104    private static void PrintUsage() {
105      Console.WriteLine(""Find a solution using random tree search."");
106      Console.WriteLine();
107      Console.WriteLine(""Parameters:"");
108      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
109      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
110    }
111
112
113    public ?IDENT?Solver(?IDENT?Problem problem) {
114      this.problem = problem;
115      this.random = new Random();
116    }
117
118    private void Start() {
119      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
120      int n = 0;
121      long sumDepth = 0;
122      long sumSize = 0;
123      var sumF = 0.0;
124      var sw = new System.Diagnostics.Stopwatch();
125      sw.Start();
126      while (n <= 10000) {
127
128        int steps, depth;
129        var _t = SampleTree(out steps, out depth);
130        var f = problem.Evaluate(_t);
131
132        n++;
133        sumSize += steps;
134        sumDepth += depth;
135        sumF += f;
136        if (problem.IsBetter(f, bestF)) {
137          bestF = f;
138          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
139        }
140        if (n % 1000 == 0) {
141          sw.Stop();
142          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
143          sumSize = 0;
144          sumDepth = 0;
145          sumF = 0.0;
146          sw.Restart();
147        }
148      }
149    }
150  }
151}";
152
/// <summary>
/// Instantiates the solver template for the given grammar and appends the resulting C# source
/// text to <paramref name="problemSourceCode"/>.
/// </summary>
/// <param name="grammar">Grammar whose symbol order defines the generated solver's state numbering.</param>
/// <param name="terminals">Terminal definitions; every terminal symbol of the grammar needs exactly one entry whose Ident equals the symbol name.</param>
/// <param name="maximization">true when the objective is maximized; selects the initial best value in the generated Start().</param>
/// <param name="problemSourceCode">Builder that receives the generated solver source.</param>
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // Compute every substitution text up front; the generators only read the grammar and terminals.
  string maximizationLiteral = maximization ? "true" : "false";
  string sampleAlternativeCode = GenerateSampleAlternativeSource(grammar);
  string createTerminalCode = GenerateCreateTerminalCode(grammar, terminals);

  // Note: ?PROBLEMNAME? and ?IDENT? are intentionally left in the text here; presumably they are
  // replaced on the combined problem source later (TODO confirm with the caller).
  var solverSourceCode = new SourceBuilder();
  solverSourceCode
    .Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", sampleAlternativeCode)
    .Replace("?CREATETERMINALNODECODE?", createTerminalCode);

  problemSourceCode.Append(solverSourceCode.ToString());
}
163
164
165
/// <summary>
/// Generates the case sections of the solver's SampleAlternative switch: one "case i:" label per
/// grammar symbol, numbered in the order of grammar.Symbols (state 0 must be the start symbol).
/// Terminal symbols get an empty label, which stacks onto the following case section.
/// For a non-terminal, the generated code returns the index of a randomly chosen alternative.
/// When the symbol has both all-terminal alternatives and alternatives containing a non-terminal,
/// it picks an all-terminal one with probability TerminalProbForDepth(depth) and otherwise one of
/// the others, uniformly within the chosen group.
/// </summary>
/// <param name="grammar">Grammar whose symbols define the state numbering.</param>
/// <returns>C# source text for the switch's case sections.</returns>
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    // terminals never sample an alternative; their empty label is all we emit
    if (grammar.IsTerminal(s)) continue;

    // Materialize once: the index lists are enumerated several times below (emptiness check,
    // Count/Single/Join in GenerateReturnStatement), and deferred queries would re-run
    // GetAlternatives and the filters on every enumeration.
    var alternatives = grammar.GetAlternatives(s).ToArray();
    var terminalAltIndexes = alternatives
      .Select((alt, idx) => new { alt, idx })
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToArray();
    var nonTerminalAltIndexes = alternatives
      .Select((alt, idx) => new { alt, idx })
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToArray();

    if (terminalAltIndexes.Length > 0 && nonTerminalAltIndexes.Length > 0) {
      // mixed rule: bias toward terminal alternatives as the tree gets deeper
      sb.Append("if(random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
      GenerateReturnStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateReturnStatement(nonTerminalAltIndexes, sb);
      sb.Append("}").EndBlock();
    } else {
      // only one kind of alternative: choose uniformly among all of them
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
/// <summary>
/// Generates the case sections of the solver's CreateTerminalNode switch. There is one section
/// per terminal symbol, labelled with the symbol's position in grammar.Symbols. Each section
/// creates a "{Name}Tree" node and initializes every constrained field of the terminal randomly:
/// Set constraints draw one element of problem.GetAllowed{Terminal}_{Field}(); all other
/// constraints are treated as numeric ranges and draw
/// random.NextDouble() * (max - min) + min using problem.GetMin/GetMax{Terminal}_{Field}().
/// </summary>
/// <param name="grammar">Grammar whose symbols define the state numbering.</param>
/// <param name="terminals">Terminal definitions; exactly one must match each terminal symbol by name.</param>
/// <returns>C# source text for the switch's case sections.</returns>
/// <exception cref="InvalidOperationException">A terminal symbol has no, or more than one, matching definition (thrown by Single).</exception>
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  // The state index is the symbol's position in grammar.Symbols, counted with a running counter
  // just like GenerateSampleAlternativeSource does. This avoids an O(n^2) List.IndexOf per symbol,
  // which would also give equal (duplicate) symbols the same case label.
  int nextState = 0;
  foreach (var s in grammar.Symbols) {
    int state = nextState++;
    if (!grammar.IsTerminal(s)) continue;

    sb.AppendFormat("case {0}: {{", state).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
    var terminal = terminals.Single(term => term.Ident == s.Name);
    foreach (var constr in terminal.Constraints) {
      if (constr.Type == ConstraintNodeType.Set) {
        // pick one element of the problem-provided set uniformly at random
        sb.Append("{").BeginBlock();
        sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
        sb.AppendLine("}");
      } else {
        // any other constraint: sample uniformly between the problem-provided bounds
        sb.Append("{").BeginBlock();
        sb.AppendFormat(" var max = problem.GetMax{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat(" var min = problem.GetMin{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = random.NextDouble() * (max - min) + min;", constr.Ident).EndBlock();
        sb.AppendLine("}");
      }
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
/// <summary>
/// Appends a C# statement that returns one of the given alternative indexes. A single index is
/// returned directly; several indexes are chosen uniformly at random by the generated code.
/// </summary>
/// <param name="idxs">Candidate alternative indexes (enumerated exactly once).</param>
/// <param name="sb">Builder receiving the generated statement.</param>
private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // materialize once instead of enumerating idxs for Count(), Single()/Aggregate and Count() again
  var candidates = idxs.ToArray();
  if (candidates.Length == 0) {
    // Nothing to choose from. Fail loudly, as the int overload does, instead of emitting
    // "new int[] {  }[random.Next(0)]", which would throw IndexOutOfRangeException at run time.
    sb.AppendLine("throw new InvalidProgramException();");
  } else if (candidates.Length == 1) {
    sb.AppendFormat("return {0};", candidates[0]).AppendLine();
  } else {
    var idxStr = string.Join(", ", candidates);
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, candidates.Length).AppendLine();
  }
}
235
/// <summary>
/// Appends a C# statement that selects one of <paramref name="nAlts"/> alternatives. With several
/// alternatives it returns a uniform random index, with exactly one it returns 0, and with none
/// it emits a throw.
/// </summary>
/// <param name="nAlts">Number of alternatives of the symbol.</param>
/// <param name="sb">Builder receiving the generated statement.</param>
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  if (nAlts <= 0) {
    // no alternative at all: the generated code fails loudly
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }
  if (nAlts == 1) {
    // a single alternative needs no random draw
    sb.AppendLine("return 0; ");
    return;
  }
  sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
}
245  }
246}
Note: See TracBrowser for help on using the repository browser.