Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs @ 15071

Last change on this file since 15071 was 10427, checked in by gkronber, 11 years ago

#2026 integrated max depth into MCTS solver

File size: 12.4 KB
RevLine 
[10062]1using System;
2using System.Collections.Generic;
[10067]3using System.Diagnostics;
[10062]4using System.Linq;
[10335]5using System.Text;
[10067]6using HeuristicLab.Grammars;
[10062]7
8namespace CodeGenerator {
[10080]9  public class RandomSearchCodeGen {
[10062]10
/// <summary>
/// C# source template for the generated random-search solver class.
/// </summary>
/// <remarks>
/// Placeholders are written as ?NAME?. Generate substitutes ?MAXIMIZATION?,
/// ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE?. ?PROBLEMNAME? and ?IDENT?
/// are not substituted anywhere in this class; presumably they are resolved on the
/// builder that Generate appends to (not visible here, confirm against the caller).
///
/// Changes relative to the previous template text:
///  * line-number / revision residue that had been embedded in the literal is removed,
///    so the emitted solver is valid C# again;
///  * the throughput figure divides by Stopwatch.Elapsed.TotalMilliseconds instead of
///    the integer ElapsedMilliseconds, which is 0 for batches faster than 1 ms and
///    made the output read "Infinity";
///  * an unknown command line switch now terminates with exit code 1 instead of 0.
/// </remarks>
private readonly string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public sealed class ?IDENT?RandomSearchSolver {
    private double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
    private double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
    private static int maxDepth = 20;

    private readonly ?IDENT?Problem problem;
    private readonly Random random;

    public Tree SampleTree(int maxDepth, out int steps, out int depth) {
      steps = 0;
      depth = 0;
      int curDepth = 0;
      return SampleTree(0, maxDepth, ref steps, ref curDepth, ref depth);
    }

    public Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {
      curDepth += 1;
      Debug.Assert(maxDepth > 0);
      steps += 1;
      depth = Math.Max(depth, curDepth);
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, maxDepth - 1, state, curDepth);
          var alternative = SampleTree(targetStates[altIdx], maxDepth - 1, ref steps, ref curDepth, ref depth);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTree(Grammar.transition[state][i], maxDepth - 1, ref steps, ref curDepth, ref depth);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      curDepth -=1;
      return t;
    }

    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
      switch(state) {
        ?CREATETERMINALNODECODE?
        default: { throw new ArgumentException(""Unknown state index "" + state); }
      }
    }

    private int SampleAlternative(Random random, int maxDepth, int state, int depth) {
      switch(state) {

?SAMPLEALTERNATIVECODE?

        default: throw new InvalidOperationException();
      }
    }

    private double TerminalProbForDepth(int depth) {
      if(depth>=maxDepth) return 1.0;
      return baseTerminalProbability + depth * terminalProbabilityInc;
    }

    private void ParseArguments(string[] args) {
      var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
      var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");

      var helpRegex = new Regex(@""--help|/\?"");
     
      foreach(var arg in args) {
        var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
        var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
        var maxDepthMatch = maxDepthRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
        else if(baseTerminalProbabilityMatch.Success) {
          baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
        } else if(terminalProbabilityIncMatch.Success) {
           terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
        } else if(maxDepthMatch.Success) {
           maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
           if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
        } else {
           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(1);
        }
      }
    }
    private void PrintUsage() {
      Console.WriteLine(""Find a solution using random tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]"");
      Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
      Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
    }


    public ?IDENT?RandomSearchSolver(?IDENT?Problem problem, string[] args) {
      if(args.Length >= 1) ParseArguments(args);

      this.problem = problem;
      this.random = new Random();
    }

    public void Start() {
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (true) {

        int steps, depth;
        var _t = SampleTree(maxDepth, out steps, out depth);
        Debug.Assert(depth <= maxDepth);
        // _t.PrintTree(0); Console.WriteLine();
        var f = problem.Evaluate(_t);
 
        n++;   
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          _t.PrintTree(0); Console.WriteLine();
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.Elapsed.TotalMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";
160
/// <summary>
/// Instantiates the solver template for the given grammar and appends the resulting
/// source text to <paramref name="problemSourceCode"/>.
/// </summary>
/// <param name="grammar">Grammar from which the alternative-sampling and terminal-creation code is generated.</param>
/// <param name="terminals">Terminal definitions used for the terminal-creation code.</param>
/// <param name="maximization">Substituted for ?MAXIMIZATION? as the literal "true" or "false".</param>
/// <param name="problemSourceCode">Builder that receives the finished solver source.</param>
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // Placeholder values; the two generators run in the same relative order as before.
  string maximizationLiteral = maximization ? "true" : "false";
  string sampleAlternativeCode = GenerateSampleAlternativeSource(grammar);
  string createTerminalCode = GenerateCreateTerminalCode(grammar, terminals);

  var solverSourceCode = new SourceBuilder();
  // Replacements are applied one after another, in this order, on the accumulated text.
  solverSourceCode
    .Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", sampleAlternativeCode)
    .Replace("?CREATETERMINALNODECODE?", createTerminalCode);

  string solverText = solverSourceCode.ToString();
  problemSourceCode.Append(solverText);
}
171
172
[10335]173
/// <summary>
/// Builds the case sections that replace ?SAMPLEALTERNATIVECODE? in the solver
/// template, i.e. the body of the generated SampleAlternative switch.
/// </summary>
/// <param name="grammar">Grammar whose symbols are numbered in enumeration order, starting at 0.</param>
/// <returns>The generated source text.</returns>
/// <remarks>
/// For a terminal symbol only the case label is emitted, so that label is stacked onto
/// whatever follows it in the generated switch. The alternative index lists are
/// materialized once per symbol; previously each of them was a deferred query that
/// re-walked grammar.GetAlternatives every time it was inspected.
/// </remarks>
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    if (grammar.IsTerminal(s)) {
      // nothing to sample for a terminal; only the label is written
      continue;
    }

    // Pair every alternative with its position once, then classify.
    var indexedAlternatives = grammar.GetAlternatives(s)
      .Select((alt, idx) => new { alt, idx })
      .ToList();
    // alternatives consisting solely of terminal symbols
    var terminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToList();
    // alternatives containing at least one non-terminal symbol
    var nonTerminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToList();

    if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
      // Mixed rule: the generated code picks a terminal alternative when the depth
      // budget is used up or with the depth-dependent terminal probability.
      sb.Append("if(maxDepth <= 1 || random.NextDouble() < TerminalProbForDepth(depth)) {").BeginBlock();
      GenerateSampleTerminalAlternativesStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateSampleNonterminalAlternativesStatement(nonTerminalAltIndexes, sb);
      sb.Append("}").EndBlock();
    } else {
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
/// <summary>
/// Builds the case sections that replace ?CREATETERMINALNODECODE? in the solver
/// template: one section per terminal symbol that creates the terminal's tree node
/// and assigns a random value to each constrained field.
/// </summary>
/// <param name="grammar">Grammar whose symbols are numbered in enumeration order, starting at 0.</param>
/// <param name="terminals">Terminal definitions; exactly one must match each terminal symbol by name.</param>
/// <returns>The generated source text.</returns>
/// <exception cref="InvalidOperationException">
/// No terminal definition, or more than one, matches a terminal symbol of the grammar.
/// </exception>
/// <remarks>
/// The case label is the symbol's running position, the same numbering
/// GenerateSampleAlternativeSource uses, instead of a per-symbol IndexOf lookup.
/// </remarks>
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  // enumerate the caller's sequence once; it is searched for every terminal symbol
  var terminalNodes = terminals.ToList();
  int stateIndex = -1;
  foreach (var s in grammar.Symbols) {
    stateIndex++;
    if (!grammar.IsTerminal(s)) continue;

    sb.AppendFormat("case {0}: {{", stateIndex).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();

    var matches = terminalNodes.Where(t => t.Ident == s.Name).ToList();
    if (matches.Count != 1) {
      throw new InvalidOperationException(
        string.Format("Expected exactly one terminal definition for symbol '{0}' but found {1}.", s.Name, matches.Count));
    }
    var terminal = matches[0];

    foreach (var constr in terminal.Constraints) {
      if (constr.Type == ConstraintNodeType.Set) {
        // set constraint: pick one of the allowed elements uniformly
        sb.Append("{").BeginBlock();
        sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
        sb.AppendLine("}");
      } else {
        // any other constraint type: sample a value between the problem's min and max
        sb.Append("{").BeginBlock();
        sb.AppendFormat(" var max = problem.GetMax{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat(" var min = problem.GetMin{0}_{1}();", terminal.Ident, constr.Ident).AppendLine();
        sb.AppendFormat("t.{0} = random.NextDouble() * (max - min) + min;", constr.Ident).EndBlock();
        sb.AppendLine("}");
      }
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
/// <summary>
/// Emits the statement that selects one of the terminal alternatives: a plain return
/// for a single index, otherwise a uniform random pick from an inline int array.
/// </summary>
/// <param name="idxs">Indexes of the terminal alternatives of the current symbol.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
/// <remarks>
/// The sequence is materialized once; it used to be enumerated up to three times,
/// which re-ran the caller's deferred query on every pass.
/// </remarks>
private void GenerateSampleTerminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  var indexes = idxs as IList<int> ?? idxs.ToList();
  if (indexes.Count == 1) {
    sb.AppendFormat("return {0};", indexes[0]).AppendLine();
  } else {
    // Every index is followed by ", " (including the last one); the trailing comma is
    // legal inside a C# array initializer and keeps the emitted text unchanged.
    var idxStr = string.Concat(indexes.Select(idx => idx + ", "));
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, indexes.Count).AppendLine();
  }
}
/// <summary>
/// Emits the statement that selects one of the non-terminal alternatives.
/// </summary>
/// <param name="idxs">Indexes of the non-terminal alternatives of the current symbol.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
/// <remarks>
/// A single index is returned directly. For several indexes the generated code keeps
/// those whose target state has Grammar.minDepth not exceeding maxDepth; if none
/// remain it falls back to every alternative of the state that is not in the list,
/// and then picks uniformly at random.
/// The sequence is materialized once instead of being enumerated repeatedly.
/// </remarks>
private void GenerateSampleNonterminalAlternativesStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  var indexes = idxs as IList<int> ?? idxs.ToList();
  if (indexes.Count == 1) {
    sb.AppendFormat("return {0};", indexes[0]).AppendLine();
  } else {
    // Every index is followed by ", "; the trailing comma is valid in the emitted
    // array initializer and keeps the generated text unchanged.
    var idxStr = string.Concat(indexes.Select(idx => idx + ", "));
    sb.AppendLine("{");
    sb.AppendFormat("var allIdx = new int[] {{ {0} }}; ", idxStr).AppendLine();
    sb.AppendFormat(
      "var allowedIdx = (from idx in allIdx let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();")
      .AppendLine();
    sb.AppendLine(
      "if(allowedIdx.Length==0) { allowedIdx = Enumerable.Range(0, Grammar.transition[state].Length).Except(allIdx).ToArray(); } ")
      .AppendLine();
    sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
    sb.AppendLine("}");
  }
}
[10062]260
/// <summary>
/// Emits the selection code for a symbol whose alternatives are not a mix of
/// terminal and non-terminal ones.
/// </summary>
/// <param name="nAlts">Number of alternatives of the symbol.</param>
/// <param name="sb">Builder that receives the generated code.</param>
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  // no alternatives at all: the generated code throws when this case is reached
  if (nAlts < 1) {
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }

  // exactly one alternative: nothing to choose
  if (nAlts == 1) {
    sb.AppendLine("return 0; ");
    return;
  }

  // several alternatives: the generated code keeps those whose target state has
  // Grammar.minDepth <= maxDepth and picks one of them at random
  sb.AppendLine("{");
  sb.AppendFormat(
    "var allowedIdx = (from idx in Enumerable.Range(0, {0}) let targetState = Grammar.transition[state][idx] where Grammar.minDepth[targetState] <= maxDepth select idx).ToArray();", nAlts)
    .AppendLine();
  sb.AppendLine("return allowedIdx[random.Next(allowedIdx.Length)];");
  sb.AppendLine("}");
}
276  }
277}
Note: See TracBrowser for help on using the repository browser.