Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs @ 10409

Last change on this file since 10409 was 10409, checked in by gkronber, 10 years ago

#2026 prepare for inclusion of terminals into the search tree

File size: 15.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class MonteCarloTreeSearchCodeGen {
10
// C# source template for the generated Monte Carlo tree search solver.
// Generate() substitutes the placeholders ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE?.
// ?PROBLEMNAME? and ?IDENT? are not replaced anywhere in this class
// (presumably the caller substitutes them in the combined problem source - TODO confirm).
// The template is a verbatim string literal, so "" stands for a single double quote in the generated code
// and every comment inside the literal ends up in the generated source.
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public class SearchTreeNode {
14    public int tries;
15    public double sumQuality = 0.0;
16    public double bestQuality = double.NegativeInfinity;
17    public bool ready;
18    public SearchTreeNode[] children;
19
20    public SearchTreeNode() {
21    }
22  }
23
24  public sealed class ?IDENT?Solver {
25    // private static double baseTerminalProbability = 0.05; // 5% of all samples are only a terminal node
26    // private static double terminalProbabilityInc = 0.05; // for each level the probability to sample a terminal grows by 5%
27
28    private readonly ?IDENT?Problem problem;
29    private readonly Random random;
30    private SearchTreeNode searchTree = new SearchTreeNode();
31   
32    private Tree SampleTree() {
33      var extensionsStack = new Stack<Tuple<Tree, int, int>>(); // the unfinished tree, the state, and the index of the extension point
34      var t = new Tree(-1, new Tree[1]);
35      extensionsStack.Push(Tuple.Create(t, 0, 0));
36      SampleTree(searchTree, extensionsStack);
37      return t.subtrees[0];
38    }
39
40    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int>> extensionPoints) {
41      const int RANDOM_TRIES = 100;
42      if(extensionPoints.Count == 0) return; // nothing to do
43      var extensionPoint = extensionPoints.Pop();
44      Tree parent = extensionPoint.Item1;
45      int state = extensionPoint.Item2;
46      int subtreeIdx = extensionPoint.Item3;
47      Tree t = null;
48     
49      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
50        searchTree.tries++; // could be moved to UpdateTree
51        t = SampleTreeRandom(state);
52        if(Grammar.subtreeCount[state] == 0) {
53          // when we produced a terminal continue filling up all other empty points
54          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
55          if(searchTree.children == null)
56            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
57          SampleTree(searchTree.children[0], extensionPoints);
58        } else {
59          // fill up all remaining slots randomly
60          foreach(var p in extensionPoints) {
61            var pParent = p.Item1;
62            var pState = p.Item2;
63            var pIdx = p.Item3;
64            pParent.subtrees[pIdx] = SampleTreeRandom(pState);
65          }
66        }
67      } else {
68        if(Grammar.subtreeCount[state] == 1) {
69          searchTree.tries++; // could be moved to updateTree
70          if(searchTree.children == null) {
71            int nChildren = Grammar.transition[state].Length;
72            searchTree.children = new SearchTreeNode[nChildren];
73          }
74          Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
75          Debug.Assert(searchTree.tries - RANDOM_TRIES - 1 == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
76          var altIdx = SelectAlternative(searchTree);
77          t = new Tree(altIdx, new Tree[1]);
78          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0));
79          SampleTree(searchTree.children[altIdx], extensionPoints);
80        } else {
81          // multiple subtrees
82          var subtrees = new Tree[Grammar.subtreeCount[state]];
83          t = new Tree(-1, subtrees);
84          for(int i = subtrees.Length - 1; i >= 0; i--) {
85            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i));
86          }
87          SampleTree(searchTree, extensionPoints);
88        }
89      }
90      Debug.Assert(parent.subtrees[subtreeIdx] == null);
91      parent.subtrees[subtreeIdx] = t;
92    }
93
94    private int SelectAlternative(SearchTreeNode searchTree) {
95      // any alternative not yet explored?
96      var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
97      if(altIdx >= 0) {
98        searchTree.children[altIdx] = new SearchTreeNode();
99        return altIdx;
100      } else {
101        // select the least sampled alternative
102        altIdx = 0;
103        int minSamples = searchTree.children[altIdx].tries;
104        for(int idx = 1; idx < searchTree.children.Length; idx++) {
105          if(searchTree.children[idx].tries < minSamples) {
106            minSamples = searchTree.children[idx].tries;
107            altIdx = idx;
108          }
109        }
110        // select the alternative with the largest average
111        // altIdx = 0;
112        // double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
113        // for(int idx = 1; idx < searchTree.children.Length; idx++) {
114        //   if (UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
115        //     altIdx = idx;
116        //   }
117        // }
118        return altIdx;
119      }
120    }
121
122    private double UCB(SearchTreeNode parent, SearchTreeNode n) {
123      Debug.Assert(parent.tries >= n.tries);
124      Debug.Assert(n.tries > 0);
125      return n.sumQuality / n.tries + Math.Sqrt((2 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
126    }
127
128    private void UpdateSearchTree(Tree t, double quality) {
129      var trees = new Stack<Tree>();
130      trees.Push(t);
131      UpdateSearchTree(searchTree, trees, quality);
132    }
133
134    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
135      if(trees.Count == 0 || searchTree == null) return;
136      var t = trees.Pop();
137      searchTree.sumQuality += quality;
138      if(quality > searchTree.bestQuality)
139        searchTree.bestQuality = quality;
140      if(searchTree.children == null) return;
141      if(t.subtrees == null) {
142        Debug.Assert(searchTree.children.Length == 1);
143        UpdateSearchTree(searchTree.children[0], trees, quality);
144      } else if(t.altIdx==-1) {
145        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
146          trees.Push(t.subtrees[idx]);
147        }
148        UpdateSearchTree(searchTree, trees, quality);
149      } else {
150        Debug.Assert(t.subtrees.Length == 1);
151        trees.Push(t.subtrees[0]);
152        UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
153      }
154    }
155
156    // same as in random search solver (could reuse random search)
157    private Tree SampleTreeRandom(int state) {
158      return SampleTreeRandom(state, 5);
159    }
160    private Tree SampleTreeRandom(int state, int maxDepth) {
161      Tree t = null;
162
163      // terminals
164      if(Grammar.subtreeCount[state] == 0) {
165        t = CreateTerminalNode(state, random, problem);
166      } else {
167        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
168        if(Grammar.subtreeCount[state] == 1) {
169          var targetStates = Grammar.transition[state];
170          var altIdx = SampleAlternative(random, state, maxDepth);
171          var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
172          t = new Tree(altIdx, new Tree[] { alternative });
173        } else {
174          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
175          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
176          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
177            subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
178          }
179          t = new Tree(-1, subtrees); // alternative index is ignored
180        }
181      }
182      return t;
183    }
184
185    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
186      switch(state) {
187        ?CREATETERMINALNODECODE?
188        default: { throw new ArgumentException(""Unknown state index"" + state); }
189      }
190    }
191
192    private int SampleAlternative(Random random, int state, int maxDepth) {
193      switch(state) {
194
195?SAMPLEALTERNATIVECODE?
196
197        default: throw new InvalidOperationException();
198      }
199    }
200
201    //private double TerminalProbForDepth(int depth) {
202    //  return baseTerminalProbability + depth * terminalProbabilityInc;
203    //}
204
205    public static void Main(string[] args) {
206      // if(args.Length >= 1) ParseArguments(args);
207
208      var problem = new ?IDENT?Problem();
209      var solver = new ?IDENT?Solver(problem);
210      solver.Start();
211    }
212    //private static void ParseArguments(string[] args) {
213    //  var baseTerminalProbabilityRegex = new Regex(@""--terminalProbBase=(?<prob>.+)"");
214    //  var terminalProbabilityIncRegex = new Regex(@""--terminalProbInc=(?<prob>.+)"");
215    //  var helpRegex = new Regex(@""--help|/\?"");
216    // 
217    //  foreach(var arg in args) {
218    //    var baseTerminalProbabilityMatch = baseTerminalProbabilityRegex.Match(arg);
219    //    var terminalProbabilityIncMatch = terminalProbabilityIncRegex.Match(arg);
220    //    var helpMatch = helpRegex.Match(arg);
221    //    if(helpMatch.Success) { PrintUsage(); Environment.Exit(0); }
222    //    else if(baseTerminalProbabilityMatch.Success) {
223    //      baseTerminalProbability = double.Parse(baseTerminalProbabilityMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
224    //      if(baseTerminalProbability < 0.0 || baseTerminalProbability > 1.0) throw new ArgumentException(""base terminal probability must lie in range [0.0 ... 1.0]"");
225    //    } else if(terminalProbabilityIncMatch.Success) {
226    //       terminalProbabilityInc = double.Parse(terminalProbabilityIncMatch.Groups[""prob""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
227    //       if(terminalProbabilityInc < 0.0 || terminalProbabilityInc > 1.0) throw new ArgumentException(""terminal probability increment must lie in range [0.0 ... 1.0]"");
228    //    } else {
229    //       Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
230    //    }
231    //  }
232    //}
233    //private static void PrintUsage() {
234    //  Console.WriteLine(""Find a solution using monte carlo tree search."");
235    //  Console.WriteLine();
236    //  Console.WriteLine(""Parameters:"");
237    //  Console.WriteLine(""\t--terminalProbBase=<prob>\tSets the probability of sampling a terminal alternative in a rule [Default: 0.05]"");
238    //  Console.WriteLine(""\t--terminalProbInc=<prob>\tSets the increment for the probability of sampling a terminal alternative for each level in the syntax tree [Default: 0.05]"");
239    //}
240
241
242    public ?IDENT?Solver(?IDENT?Problem problem) {
243      this.problem = problem;
244      this.random = new Random();
245    }
246
247    private void Start() {
248      Console.ReadLine();
249      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
250      int n = 0;
251      long sumDepth = 0;
252      long sumSize = 0;
253      var sumF = 0.0;
254      var sw = new System.Diagnostics.Stopwatch();
255      sw.Start();
256      while (true) {
257
258        int steps, depth;
259        var _t = SampleTree();
260        //  _t.PrintTree(0); Console.WriteLine();
261
262        // inefficient but don't care for now
263        steps = _t.GetSize();
264        depth = _t.GetDepth();
265        var f = problem.Evaluate(_t);
266        if(?MAXIMIZATION?)
267          UpdateSearchTree(_t, f);
268        else
269          UpdateSearchTree(_t, -f);
270        n++;   
271        sumSize += steps;
272        sumDepth += depth;
273        sumF += f;
274        if (problem.IsBetter(f, bestF)) {
275          bestF = f;
276          _t.PrintTree(0); Console.WriteLine();
277          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
278        }
279        if (n % 1000 == 0) {
280          sw.Stop();
281          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
282          sumSize = 0;
283          sumDepth = 0;
284          sumF = 0.0;
285          sw.Restart();
286        }
287      }
288    }
289  }
290}";
291
// Instantiates the solver template for the given grammar and appends the resulting
// source text to problemSourceCode.
//   grammar           - grammar used to generate the SampleAlternative switch cases
//   terminals         - terminal definitions used to generate the CreateTerminalNode switch cases
//   maximization      - written into the template as the lower-case literal "true" or "false"
//   problemSourceCode - builder that receives the generated solver source
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // bool.ToString() yields "True"/"False"; the generated C# code needs the lower-case keywords
  string maximizationLiteral = maximization.ToString().ToLowerInvariant();

  var solverSourceCode = new SourceBuilder();
  solverSourceCode.Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
    .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals));

  string generatedSolver = solverSourceCode.ToString();
  problemSourceCode.Append(generatedSolver);
}
302
303
304
// Generates the case labels and bodies for the switch in the template's SampleAlternative method.
// One "case <n>: " label is emitted per grammar symbol, where n is the position of the symbol
// in grammar.Symbols (the start symbol is asserted to come first).
//  - terminal symbols get a label without a body
//  - symbols that have both terminal-only alternatives and alternatives containing a non-terminal
//    get code that picks a terminal-only alternative when maxDepth <= 1 and otherwise an
//    alternative containing a non-terminal
//  - all other symbols get code that picks among all of their alternatives
// Returns the generated source fragment.
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    if (grammar.IsTerminal(s)) {
      // ignore: only the label is emitted for terminals
      continue;
    }

    // Materialize the indexed alternatives and both index lists once; the deferred queries
    // would otherwise be re-evaluated by every Any()/Count()/Single() call made on them.
    var indexedAlternatives = grammar.GetAlternatives(s)
      .Select((alt, idx) => new { alt, idx })
      .ToList();
    var terminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToList();
    var nonTerminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToList();

    if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
      sb.Append("if(maxDepth <= 1) {").BeginBlock();
      GenerateReturnStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateReturnStatement(nonTerminalAltIndexes, sb);
      sb.Append("}").EndBlock();
    } else {
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
// Generates the case blocks for the switch in the template's CreateTerminalNode method.
// For every terminal symbol a block "case <n>: { ... return t; }" is emitted, where n is the
// position of the symbol in grammar.Symbols (the start symbol is asserted to come first).
// The block creates a <Name>Tree and, for each SET constraint of the matching terminal definition,
// assigns a random element of problem.GetAllowed<Terminal>_<Constraint>() to the field named after
// the constraint.
// Throws NotSupportedException for any constraint that is not a SET constraint.
// Enumerable.Single throws InvalidOperationException when terminals does not contain exactly one
// definition whose Ident equals the symbol name.
// Returns the generated source fragment.
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  var allSymbols = grammar.Symbols.ToList();
  // The list position is the state index; using the loop index avoids a linear IndexOf per symbol
  // and matches the positional numbering used for the SampleAlternative cases.
  for (int state = 0; state < allSymbols.Count; state++) {
    var s = allSymbols[state];
    if (!grammar.IsTerminal(s)) {
      continue;
    }
    sb.AppendFormat("case {0}: {{", state).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
    var terminal = terminals.Single(t => t.Ident == s.Name);
    foreach (var constr in terminal.Constraints) {
      if (constr.Type != ConstraintNodeType.Set) {
        throw new NotSupportedException("The MCTS solver does not support RANGE constraints.");
      }
      sb.Append("{").BeginBlock();
      sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
      sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
      sb.AppendLine("}");
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
// Appends a statement to sb that returns one of the given alternative indexes.
//  - no index:       "throw new InvalidProgramException();" (same as the int overload emits for zero alternatives)
//  - exactly one:    "return <idx>;"
//  - more than one:  a return statement that indexes an int[] of the candidates with random.Next(<count>)
private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // Enumerate the sequence exactly once; it may be a deferred query.
  var candidates = idxs.ToList();
  if (candidates.Count == 0) {
    // without this guard the emitted code would index into an empty array
    sb.AppendLine("throw new InvalidProgramException();");
  } else if (candidates.Count == 1) {
    sb.AppendFormat("return {0};", candidates[0]).AppendLine();
  } else {
    // every index is followed by ", " (a trailing comma is valid in a C# array initializer)
    var idxStr = string.Concat(candidates.Select(idx => idx + ", "));
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, candidates.Count).AppendLine();
  }
}
370
// Appends a statement to sb that returns an alternative index in the range [0, nAlts).
//  - nAlts < 1:  "throw new InvalidProgramException();"
//  - nAlts == 1: "return 0; "
//  - nAlts > 1:  "return random.Next(<nAlts>);"
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  if (nAlts < 1) {
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }
  if (nAlts == 1) {
    sb.AppendLine("return 0; ");
    return;
  }
  sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
}
380  }
381}
Note: See TracBrowser for help on using the repository browser.