Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs @ 10424

Last changed in revision 10424, checked in by gkronber 10 years ago

#2026: maximum depth limit for random search

File size: 12.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
/// <summary>
/// Generates the C# source of a random-search solver for a GPDL problem. The generated solver repeatedly
/// samples random syntax trees from the problem grammar (respecting a maximal tree depth), evaluates them via
/// the problem and prints every sample that improves the best quality found so far.
/// </summary>
public class RandomSearchCodeGen {

  // Template of the generated solver. Generate() substitutes ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and
  // ?CREATETERMINALNODECODE?. ?PROBLEMNAME? and ?IDENT? are NOT replaced in this class.
  // NOTE(review): presumably the caller substitutes them in problemSourceCode afterwards — confirm.
  private const string SolverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?Solver {
    private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
    private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
    private static int maxDepth = 20;

    private readonly ?IDENT?Problem problem;
    private readonly Random random;

    private Tree SampleTree(int maxDepth, out int steps, out int depth) {
      steps = 0;
      depth = 0;
      int curDepth = 0;
      return SampleTree(0, maxDepth, ref steps, ref curDepth, ref depth);
    }

    private Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {
      curDepth += 1;
      Debug.Assert(maxDepth > 0);
      steps += 1;
      depth = Math.Max(depth, curDepth);
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, maxDepth - 1, state, curDepth);
          var alternative = SampleTree(targetStates[altIdx], maxDepth - 1, ref steps, ref curDepth, ref depth);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTree(Grammar.transition[state][i], maxDepth - 1, ref steps, ref curDepth, ref depth);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      curDepth -=1;
      return t;
    }

    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
      switch(state) {
        ?CREATETERMINALNODECODE?
        default: { throw new ArgumentException(""Unknown state index "" + state); }
      }
    }

    private int SampleAlternative(Random random, int maxDepth, int state, int depth) {
      switch(state) {

?SAMPLEALTERNATIVECODE?

        default: throw new InvalidOperationException();
      }
    }

    private double TerminalProbForDepth(int depth) {
      if(depth>=maxDepth) return 1.0;
      return baseTerminalProbability + depth * terminalProbabilityInc;
    }

    public static void Main(string[] args) {
      if(args.Length >= 1) ParseArguments(args);

      var problem = new ?IDENT?Problem();
      var solver = new ?IDENT?Solver(problem);
      solver.Start();
    }
    private static void ParseArguments(string[] args) {
      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");

      var helpRegex = new Regex(@""--help|/\?"");
     
      foreach(var arg in args) {
        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
        var maxDepthMatch = maxDepthRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
        else if(baseTerminalProbabilityMatch.Success) {
          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
        } else if(terminalProbabilityIncMatch.Success) {
           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
        } else if(maxDepthMatch.Success) {
           maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
        } else {
           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(1);
        }
      }
    }
    private static void PrintUsage() {
      Console.WriteLine(""Find a solution using random tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]"");
      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
    }


    public ?IDENT?Solver(?IDENT?Problem problem) {
      this.problem = problem;
      this.random = new Random();
    }

    private void Start() {
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      long n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (true) {

        int steps, depth;
        var _t = SampleTree(maxDepth, out steps, out depth);
        Debug.Assert(depth <= maxDepth);
        // _t.PrintTree(0); Console.WriteLine();
        var f = problem.Evaluate(_t);
 
        n++;   
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          _t.PrintTree(0); Console.WriteLine();
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";

  /// <summary>
  /// Instantiates the solver template for the given grammar and appends the resulting source to
  /// <paramref name="problemSourceCode"/>.
  /// </summary>
  /// <param name="grammar">Grammar of the problem; its first symbol must be the start symbol.</param>
  /// <param name="terminals">Terminal declarations; each terminal symbol of the grammar must match exactly one entry by name.</param>
  /// <param name="maximization">True if larger quality values are better.</param>
  /// <param name="problemSourceCode">Builder the generated solver source is appended to.</param>
  public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
    var solverSourceCode = new SourceBuilder();
    solverSourceCode.Append(SolverTemplate)
      .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
      .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
      .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals));

    problemSourceCode.Append(solverSourceCode.ToString());
  }

  /// <summary>
  /// Generates the <c>case</c> sections of the solver's <c>SampleAlternative</c> switch, one per non-terminal state.
  /// State indexes are the positions of the symbols in <c>grammar.Symbols</c>.
  /// </summary>
  private string GenerateSampleAlternativeSource(IGrammar grammar) {
    Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
    var sb = new SourceBuilder();
    int nextState = 0;
    foreach (var s in grammar.Symbols) {
      int state = nextState++;
      // SampleAlternative is only called for states with exactly one subtree, never for terminals.
      // An empty "case N:" label would stack onto the following case section, so terminal states
      // are omitted and end up in the template's default branch (which throws) instead.
      if (grammar.IsTerminal(s)) continue;

      sb.AppendFormat("case {0}: ", state);
      // materialize once: the index lists are enumerated several times below
      var alternatives = grammar.GetAlternatives(s).ToArray();
      var terminalAltIndexes = alternatives
        .Select((alt, idx) => new { alt, idx })
        .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
        .Select(p => p.idx)
        .ToArray();
      var nonTerminalAltIndexes = alternatives
        .Select((alt, idx) => new { alt, idx })
        .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
        .Select(p => p.idx)
        .ToArray();

      if (terminalAltIndexes.Length > 0 && nonTerminalAltIndexes.Length > 0) {
        // mixed rule: pick a terminal alternative with a depth-dependent probability (always at the depth limit)
        sb.Append("if(maxDepth <= 1 || random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
        GenerateSampleTerminalAlternativesStatement(terminalAltIndexes, sb);
        sb.Append("} else {");
        GenerateSampleNonterminalAlternativesStatement(nonTerminalAltIndexes, sb);
        sb.Append("}").EndBlock();
      } else {
        GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
      }
    }
    return sb.ToString();
  }

  /// <summary>
  /// Generates the <c>case</c> sections of the solver's <c>CreateTerminalNode</c> switch: one per terminal state,
  /// creating the terminal's tree node and initializing each constrained field randomly (uniform choice from the
  /// allowed set, or a uniform value in [min, max)).
  /// </summary>
  /// <exception cref="InvalidOperationException">A terminal symbol matches zero or several entries of <paramref name="terminals"/>.</exception>
  private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
    Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
    var sb = new SourceBuilder();
    var terminalList = terminals.ToList(); // searched once per terminal symbol
    int nextState = 0;
    foreach (var s in grammar.Symbols) {
      // same numbering as GenerateSampleAlternativeSource: position in grammar.Symbols
      int state = nextState++;
      if (!grammar.IsTerminal(s)) continue;

      sb.AppendFormat("case {0}: {{", state).BeginBlock();
      sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
      var terminal = terminalList.Single(tn => tn.Ident == s.Name);
      foreach (var constr in terminal.Constraints) {
        if (constr.Type == ConstraintNodeType.Set) {
          sb.Append("{").BeginBlock();
          sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
          sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
          sb.AppendLine("}");
        } else {
          sb.Append("{").BeginBlock();
          sb.AppendFormat(" var max = problem.GetMax{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
          sb.AppendFormat(" var min = problem.GetMin{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
          sb.AppendFormat("t.{0} = random.NextDouble() * (max - min) + min;", constr.Ident).EndBlock();
          sb.AppendLine("}");
        }
      }
      sb.AppendLine("return t;").EndBlock();
      sb.Append("}");
    }
    return sb.ToString();
  }

  /// <summary>
  /// Emits a statement returning one of the given (terminal-only) alternative indexes, chosen uniformly at random.
  /// </summary>
  /// <param name="idxs">Non-empty list of alternative indexes.</param>
  private void GenerateSampleTerminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
    var idxArr = idxs.ToArray();
    if (idxArr.Length == 1) {
      sb.AppendFormat("return {0};", idxArr[0]).AppendLine();
    } else {
      sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", string.Join(", ", idxArr), idxArr.Length).AppendLine();
    }
  }

  /// <summary>
  /// Emits a block returning one of the given non-terminal alternative indexes whose target state fits into the
  /// remaining depth (<c>Grammar.minDepth[target] &lt;= maxDepth</c>). If none fits, it falls back to the rule's
  /// other alternatives, i.e. the terminal-only ones.
  /// </summary>
  /// <param name="idxs">Non-empty list of alternative indexes.</param>
  private void GenerateSampleNonterminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
    var idxArr = idxs.ToArray();
    if (idxArr.Length == 1) {
      sb.AppendFormat("return {0};", idxArr[0]).AppendLine();
    } else {
      sb.AppendLine("{");
      sb.AppendFormat("var allIdx = new int[] {{ {0} }}; ", string.Join(", ", idxArr)).AppendLine();
      sb.AppendFormat(
        "var allowedIdx = (from idx in allIdx let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();")
        .AppendLine();
      sb.AppendLine(
        "if(allowedIdx.Length==0) { allowedIdx = Enumerable.Range(0, Grammar.transition[state].Length).Except(allIdx).ToArray(); } ")
        .AppendLine();
      sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
      sb.AppendLine("}");
    }
  }

  /// <summary>
  /// Emits the alternative selection for a rule whose alternatives are all terminal-only or all non-terminal.
  /// With several alternatives, one whose target state fits into the remaining depth is chosen uniformly.
  /// </summary>
  /// <param name="nAlts">Number of alternatives of the rule.</param>
  private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
    if (nAlts > 1) {
      sb.AppendLine("{");
      sb.AppendFormat(
        "var allowedIdx = (from idx in Enumerable.Range(0, {0}) let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();", nAlts)
        .AppendLine();
      // If no alternative fits into the remaining depth (e.g. --maxDepth smaller than the grammar requires),
      // allowedIdx is empty and allowedIdx[random.Next(0)] would throw IndexOutOfRangeException.
      // Fall back to the alternatives with the smallest minDepth so that sampling can still proceed.
      sb.AppendFormat(
        "if(allowedIdx.Length == 0) {{ var minAltDepth = Enumerable.Range(0, {0}).Min(i => Grammar.minDepth[Grammar.transition[state][i]]); allowedIdx = Enumerable.Range(0, {0}).Where(i => Grammar.minDepth[Grammar.transition[state][i]] == minAltDepth).ToArray(); }}", nAlts)
        .AppendLine();
      sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
      sb.AppendLine("}");
    } else if (nAlts == 1) {
      sb.AppendLine("return 0; ");
    } else {
      sb.AppendLine("throw new InvalidProgramException();");
    }
  }
}
282}
Note: See TracBrowser for help on using the repository browser.