Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs

Last change on this file was 10427, checked in by gkronber, 11 years ago

#2026 integrated max depth into MCTS solver

File size: 12.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class RandomSearchCodeGen {
10
// C# source template of the generated random-search solver class.
// ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE? are substituted in Generate().
// ?PROBLEMNAME? and ?IDENT? are not replaced anywhere in this class
// (presumably the caller that owns the problem source fills them in -- TODO confirm).
// The template text is a verbatim string: every character between the quotes ends up in the
// generated file, so it must not contain anything but valid C# and the placeholders above.
// The generated program exits with a non-zero code when it encounters an unknown command line switch.
private readonly string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?RandomSearchSolver {
    private double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
    private double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
    private static int maxDepth = 20;

    private readonly ?IDENT?Problem problem;
    private readonly Random random;

    public Tree SampleTree(int maxDepth, out int steps, out int depth) {
      steps = 0;
      depth = 0;
      int curDepth = 0;
      return SampleTree(0, maxDepth, ref steps, ref curDepth, ref depth);
    }

    public Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {
      curDepth += 1;
      Debug.Assert(maxDepth > 0);
      steps += 1;
      depth = Math.Max(depth, curDepth);
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, maxDepth - 1, state, curDepth);
          var alternative = SampleTree(targetStates[altIdx], maxDepth - 1, ref steps, ref curDepth, ref depth);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTree(Grammar.transition[state][i], maxDepth - 1, ref steps, ref curDepth, ref depth);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      curDepth -=1;
      return t;
    }

    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
      switch(state) {
        ?CREATETERMINALNODECODE?
        default: { throw new ArgumentException(""Unknown state index "" + state); }
      }
    }

    private int SampleAlternative(Random random, int maxDepth, int state, int depth) {
      switch(state) {

?SAMPLEALTERNATIVECODE?

        default: throw new InvalidOperationException();
      }
    }

    private double TerminalProbForDepth(int depth) {
      if(depth>=maxDepth) return 1.0;
      return baseTerminalProbability + depth * terminalProbabilityInc;
    }

    private void ParseArguments(string[] args) {
      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");

      var helpRegex = new Regex(@""--help|/\?"");
     
      foreach(var arg in args) {
        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
        var maxDepthMatch = maxDepthRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
        else if(baseTerminalProbabilityMatch.Success) {
          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
        } else if(terminalProbabilityIncMatch.Success) {
           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
        } else if(maxDepthMatch.Success) {
           maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
        } else {
           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(1);
        }
      }
    }
    private void PrintUsage() {
      Console.WriteLine(""Find a solution using random tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]"");
      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
    }


    public ?IDENT?RandomSearchSolver(?IDENT?Problem problem, string[] args) {
      if(args.Length >= 1) ParseArguments(args);

      this.problem = problem;
      this.random = new Random();
    }

    public void Start() {
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (true) {

        int steps, depth;
        var _t = SampleTree(maxDepth, out steps, out depth);
        Debug.Assert(depth <= maxDepth);
        // _t.PrintTree(0); Console.WriteLine();
        var f = problem.Evaluate(_t);
 
        n++;   
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          _t.PrintTree(0); Console.WriteLine();
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";
160
// Renders the random-search solver source from solverTemplate and appends it to problemSourceCode.
//   grammar           - grammar for which the alternative-sampling switch cases are generated
//   terminals         - terminal definitions used to generate the terminal-node factory cases
//   maximization      - inserted as the C# literal "true"/"false" for the ?MAXIMIZATION? placeholder
//   problemSourceCode - builder that receives the rendered solver source
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // Compute the three replacement texts up front; the placeholders are still replaced in the same order.
  string maximizationLiteral = maximization.ToString().ToLowerInvariant();
  string sampleAlternativeCode = GenerateSampleAlternativeSource(grammar);
  string createTerminalCode = GenerateCreateTerminalCode(grammar, terminals);

  var renderedSolver = new SourceBuilder();
  renderedSolver.Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", sampleAlternativeCode)
    .Replace("?CREATETERMINALNODECODE?", createTerminalCode);

  problemSourceCode.Append(renderedSolver.ToString());
}
171
172
173
// Generates the "case <state>: ..." entries that replace ?SAMPLEALTERNATIVECODE? in the solver template.
// States are numbered by the position of the symbol in grammar.Symbols (the start symbol must come first).
// For terminal symbols only the bare case label is emitted, so in the generated switch it shares the
// body of whatever label follows it.
// For non-terminal symbols:
//   - if the symbol has both purely-terminal alternatives and alternatives containing a non-terminal,
//     the generated code picks a terminal alternative when maxDepth <= 1 or with probability
//     TerminalProbForDepth(depth), and a non-terminal alternative otherwise;
//   - otherwise the generated code picks among all alternatives (see GenerateReturnStatement).
// Returns the generated source text.
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    if (grammar.IsTerminal(s)) {
      continue; // nothing to sample for terminals
    }

    // Query and classify the alternatives exactly once; the index lists are materialized because
    // they are inspected here and enumerated again by the statement generators below.
    var indexedAlternatives = grammar.GetAlternatives(s)
      .Select((alt, idx) => new { alt, idx })
      .ToArray();
    var terminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToArray();
    var nonTerminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToArray();

    var hasTerminalAlts = terminalAltIndexes.Length > 0;
    var hasNonTerminalAlts = nonTerminalAltIndexes.Length > 0;
    if (hasTerminalAlts && hasNonTerminalAlts) {
      sb.Append("if(maxDepth <= 1 || random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
      GenerateSampleTerminalAlternativesStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateSampleNonterminalAlternativesStatement(nonTerminalAltIndexes, sb);
      sb.Append("}").EndBlock();
    } else {
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
// Generates the "case <state>: { ... return t; }" entries that replace ?CREATETERMINALNODECODE?
// in the solver template. One case is emitted per terminal symbol; the state number is the position
// of the symbol in grammar.Symbols, the same numbering GenerateSampleAlternativeSource uses.
// For every constraint of the terminal the generated code assigns a random value to the field named
// after the constraint: a random element of problem.GetAllowed<Terminal>_<Field>() for Set constraints,
// and a uniformly sampled double between problem.GetMin...() and problem.GetMax...() for all others.
// Throws InvalidOperationException (from Single) when a terminal symbol has no, or more than one,
// matching entry in terminals.
// Returns the generated source text.
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  // materialize once: the sequence is searched for every terminal symbol
  var terminalNodes = terminals.ToList();
  int nextState = 0;
  foreach (var s in grammar.Symbols) {
    // positional state index; avoids a linear IndexOf lookup per symbol
    int state = nextState++;
    if (!grammar.IsTerminal(s)) {
      continue;
    }

    sb.AppendFormat("case {0}: {{", state).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
    var terminal = terminalNodes.Single(t => t.Ident == s.Name);
    foreach (var constr in terminal.Constraints) {
      if (constr.Type == ConstraintNodeType.Set) {
        sb.Append("{").BeginBlock();
        sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
        sb.AppendLine("}");
      } else {
        sb.Append("{").BeginBlock();
        sb.AppendFormat(" var max = problem.GetMax{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat(" var min = problem.GetMin{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = random.NextDouble() * (max - min) + min;", constr.Ident).EndBlock();
        sb.AppendLine("}");
      }
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
// Emits a return statement that selects one of the given (terminal) alternative indexes.
// With exactly one index the generated code returns it directly; otherwise it returns a uniformly
// random element of an inline int[] holding all indexes (random is a variable of the generated code).
//   idxs - alternative indexes to choose from; enumerated exactly once
//   sb   - builder receiving the generated statement
private void GenerateSampleTerminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // materialize once: idxs may be a deferred query and is needed for both the count and the elements
  var indexes = idxs.ToArray();
  if (indexes.Length == 1) {
    sb.AppendFormat("return {0};", indexes[0]).AppendLine();
  } else {
    // every index is followed by ", " (a trailing comma is legal in a C# array initializer)
    var idxStr = string.Concat(indexes.Select(idx => idx + ", "));
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, indexes.Length).AppendLine();
  }
}
// Emits code that selects one of the given (non-terminal) alternative indexes.
// With exactly one index the generated code returns it directly. Otherwise the generated block
//   1. keeps only the indexes whose target state satisfies Grammar.minDepth[targetState] <= maxDepth,
//   2. if none is left, falls back to all alternatives of the state that are NOT in the given list,
//   3. returns a uniformly random element of the remaining indexes.
// (state, maxDepth, random and Grammar are names of the generated code, not of this generator.)
//   idxs - alternative indexes to choose from; enumerated exactly once
//   sb   - builder receiving the generated statements
private void GenerateSampleNonterminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // materialize once: idxs may be a deferred query and is needed for both the count and the elements
  var indexes = idxs.ToArray();
  if (indexes.Length == 1) {
    sb.AppendFormat("return {0};", indexes[0]).AppendLine();
    return;
  }

  // every index is followed by ", " (a trailing comma is legal in a C# array initializer)
  var idxStr = string.Concat(indexes.Select(idx => idx + ", "));
  sb.AppendLine("{");
  sb.AppendFormat("var allIdx = new int[] {{ {0} }}; ", idxStr).AppendLine();
  sb.AppendFormat(
    "var allowedIdx = (from idx in allIdx let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();")
    .AppendLine();
  sb.AppendLine(
    "if(allowedIdx.Length==0) { allowedIdx = Enumerable.Range(0, Grammar.transition[state].Length).Except(allIdx).ToArray(); } ")
    .AppendLine();
  sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
  sb.AppendLine("}");
}
260
// Emits the statement that selects an alternative for a symbol with nAlts alternatives:
//   nAlts < 1  -> generated code throws InvalidProgramException
//   nAlts == 1 -> generated code returns index 0
//   nAlts > 1  -> generated code returns a random index among those alternatives whose target state
//                 satisfies Grammar.minDepth[targetState] <= maxDepth
// NOTE(review): in the last case the generated code indexes allowedIdx without checking whether the
// depth filter left any element.
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  if (nAlts < 1) {
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }

  if (nAlts == 1) {
    sb.AppendLine("return 0; ");
    return;
  }

  // more than one alternative: filter by remaining depth, then pick uniformly at random
  sb.AppendLine("{");
  sb.AppendFormat(
    "var allowedIdx = (from idx in Enumerable.Range(0, {0}) let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();", nAlts)
    .AppendLine();
  sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
  sb.AppendLine("}");
}
276  }
277}
Note: See TracBrowser for help on using the repository browser.