Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs @ 10426

Last change on this file since 10426 was 10426, checked in by gkronber, 11 years ago

#2026 generate code for all solvers

File size: 14.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class MonteCarloTreeSearchCodeGen {
10
// Source-code template for the generated Monte Carlo tree search solver, stored as a C# verbatim
// string (which is why quotes inside the template are doubled).
// Generate() substitutes the placeholders ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and
// ?CREATETERMINALNODECODE?. The placeholders ?PROBLEMNAME? and ?IDENT? are NOT replaced anywhere in
// this class; presumably the caller resolves them after Generate() has appended the text - TODO confirm.
// The template refers to Tree, Grammar.subtreeCount, Grammar.transition and ?IDENT?Problem without
// defining them, so those are expected to come from code outside this template.
// NOTE(review): the generated Start() begins with Console.ReadLine(), i.e. the solver waits for a line
// on standard input before searching - looks like a debugging leftover, confirm before relying on it.
// NOTE(review): in the generated SelectAlternative the local bestAverage is assigned but never read.
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public class SearchTreeNode {
14    public int tries;
15    public double sumQuality = 0.0;
16    public double bestQuality = double.NegativeInfinity;
17    public bool done;
18    public SearchTreeNode[] children;
19    // only for debugging
20    public double[] Ucb {
21      get {
22        return (from c in children
23                select ?IDENT?MonteCarloTreeSearchSolver.UCB(this, c)
24               ).ToArray();
25      }
26    }
27    public SearchTreeNode() {
28    }
29  }
30
31  public sealed class ?IDENT?MonteCarloTreeSearchSolver {
32    private int maxDepth = 20;
33
34    private readonly ?IDENT?Problem problem;
35    private readonly Random random;
36    private SearchTreeNode searchTree = new SearchTreeNode();
37   
38    private Tree SampleTree(int maxDepth) {
39      var extensionsStack = new Stack<Tuple<Tree, int, int, int>>(); // the unfinished tree, the state, the index of the extension point and the maximal depth of a tree inserted at that point
40      var t = new Tree(-1, new Tree[1]);
41      extensionsStack.Push(Tuple.Create(t, 0, 0, maxDepth));
42      SampleTree(searchTree, extensionsStack);
43      return t.subtrees[0];
44    }
45
46    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) {
47      const int RANDOM_TRIES = 1000;
48      if(extensionPoints.Count == 0) {
49        searchTree.done = true;
50        return; // nothing to do
51      }
52      var extensionPoint = extensionPoints.Pop();
53      Tree parent = extensionPoint.Item1;
54      int state = extensionPoint.Item2;
55      int subtreeIdx = extensionPoint.Item3;
56      int maxDepth = extensionPoint.Item4;
57      Debug.Assert(maxDepth >= 1);
58      Tree t = null;
59      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
60        t = SampleTreeRandom(state);
61        if(Grammar.subtreeCount[state] == 0) {
62          // when we produced a terminal continue filling up all other empty points
63          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
64          if(searchTree.children == null)
65            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
66          SampleTree(searchTree.children[0], extensionPoints);
67          if(searchTree.children[0].done) searchTree.done = true;
68        } else {
69          // fill up all remaining slots randomly
70          foreach(var p in extensionPoints) {
71            var pParent = p.Item1;
72            var pState = p.Item2;
73            var pIdx = p.Item3;
74            pParent.subtrees[pIdx] = SampleTreeRandom(pState);
75          }
76        }
77      } else {
78        if(Grammar.subtreeCount[state] == 1) {
79          if(searchTree.children == null) {
80            int nChildren = Grammar.transition[state].Length;
81            searchTree.children = new SearchTreeNode[nChildren];
82          }
83          Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
84          Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
85          var altIdx = SelectAlternative(searchTree);
86          t = new Tree(altIdx, new Tree[1]);
87          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
88          SampleTree(searchTree.children[altIdx], extensionPoints);
89        } else {
90          // multiple subtrees
91          var subtrees = new Tree[Grammar.subtreeCount[state]];
92          t = new Tree(-1, subtrees);
93          for(int i = subtrees.Length - 1; i >= 0; i--) {
94            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
95          }
96          SampleTree(searchTree, extensionPoints);         
97        }
98      }     
99      Debug.Assert(parent.subtrees[subtreeIdx] == null);
100      parent.subtrees[subtreeIdx] = t;
101    }
102
103    private int SelectAlternative(SearchTreeNode searchTree) {
104      // any alternative not yet explored?
105      var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
106      if(altIdx >= 0) {
107        searchTree.children[altIdx] = new SearchTreeNode();
108        return altIdx;
109      } else {
110        altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000);
111        if(altIdx >= 0) return altIdx;
112        // select the least sampled alternative
113        //altIdx = 0;
114        //int minSamples = searchTree.children[altIdx].tries;
115        //for(int idx = 1; idx < searchTree.children.Length; idx++) {
116        //  if(!searchTree.children[idx].done && searchTree.children[idx].tries < minSamples) {
117        //    minSamples = searchTree.children[idx].tries;
118        //    altIdx = idx;
119        //  }
120        //}
121        // select the alternative with the largest average
122        altIdx = 0;
123        double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
124        for(int idx = 1; idx < searchTree.children.Length; idx++) {
125          if (!searchTree.children[idx].done && UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
126            altIdx = idx;
127          }
128        }
129
130        searchTree.done = searchTree.children.All(c=>c.done);
131        return altIdx;
132      }
133    }
134
135    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
136      Debug.Assert(parent.tries >= n.tries);
137      Debug.Assert(n.tries > 0);
138      return n.sumQuality / n.tries + Math.Sqrt((10 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
139    }
140
141    private void UpdateSearchTree(Tree t, double quality) {
142      var trees = new Stack<Tree>();
143      trees.Push(t);
144      UpdateSearchTree(searchTree, trees, quality);
145    }
146
147    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
148      if(trees.Count == 0 || searchTree == null) return;
149      var t = trees.Pop();
150      if(t.altIdx == -1) {
151        // for trees with multiple sub-trees
152        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
153          trees.Push(t.subtrees[idx]);
154        }
155        UpdateSearchTree(searchTree, trees, quality);
156      } else {
157        searchTree.sumQuality += quality;
158        searchTree.tries++;
159        if(quality > searchTree.bestQuality)
160          searchTree.bestQuality = quality;
161        if(t.subtrees != null) {
162          Debug.Assert(t.subtrees.Length == 1);
163          if(searchTree.children != null) {
164            trees.Push(t.subtrees[0]);
165            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
166          }
167        } else {
168          if(searchTree.children != null) {
169            Debug.Assert(searchTree.children.Length == 1);
170            UpdateSearchTree(searchTree.children[0], trees, quality);
171          }
172        }
173      }
174    }
175
176    // same as in random search solver (could reuse random search)
177    private Tree SampleTreeRandom(int state) {
178      return SampleTreeRandom(state, 5);
179    }
180    private Tree SampleTreeRandom(int state, int maxDepth) {
181      Tree t = null;
182
183      // terminals
184      if(Grammar.subtreeCount[state] == 0) {
185        t = CreateTerminalNode(state, random, problem);
186      } else {
187        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
188        if(Grammar.subtreeCount[state] == 1) {
189          var targetStates = Grammar.transition[state];
190          var altIdx = SampleAlternative(random, state, maxDepth);
191          var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
192          t = new Tree(altIdx, new Tree[] { alternative });
193        } else {
194          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
195          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
196          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
197            subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
198          }
199          t = new Tree(-1, subtrees); // alternative index is ignored
200        }
201      }
202      return t;
203    }
204
205    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
206      switch(state) {
207        ?CREATETERMINALNODECODE?
208        default: { throw new ArgumentException(""Unknown state index"" + state); }
209      }
210    }
211
212    private int SampleAlternative(Random random, int state, int maxDepth) {
213      switch(state) {
214
215?SAMPLEALTERNATIVECODE?
216
217        default: throw new InvalidOperationException();
218      }
219    }
220
221    public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) {
222      if(args.Length > 0 ) throw new ArgumentException(""Arguments ""+args.Aggregate("""", (str, s) => str+ "" "" + s)+"" are not supported "");
223      this.problem = problem;
224      this.random = new Random();
225    }
226
227    public void Start() {
228      Console.ReadLine();
229      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
230      int n = 0;
231      long sumDepth = 0;
232      long sumSize = 0;
233      var sumF = 0.0;
234      var sw = new System.Diagnostics.Stopwatch();
235      sw.Start();
236      while (!searchTree.done) {
237
238        int steps, depth;
239        var _t = SampleTree(maxDepth);
240        //  _t.PrintTree(0); Console.WriteLine();
241
242        // inefficient but don't care for now
243        steps = _t.GetSize();
244        depth = _t.GetDepth();
245        Debug.Assert(depth <= maxDepth);
246        var f = problem.Evaluate(_t);
247        if(?MAXIMIZATION?)
248          UpdateSearchTree(_t, f);
249        else
250          UpdateSearchTree(_t, -f);
251        n++;   
252        sumSize += steps;
253        sumDepth += depth;
254        sumF += f;
255        if (problem.IsBetter(f, bestF)) {
256          bestF = f;
257          _t.PrintTree(0); Console.WriteLine();
258          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
259        }
260        if (n % 1000 == 0) {
261          sw.Stop();
262          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
263          sumSize = 0;
264          sumDepth = 0;
265          sumF = 0.0;
266          sw.Restart();
267        }
268      }
269    }
270  }
271}";
272
/// <summary>
/// Fills in the solver template for the given grammar and appends the resulting source text
/// to <paramref name="problemSourceCode"/>.
/// </summary>
/// <param name="grammar">Grammar from which the SampleAlternative switch cases are generated.</param>
/// <param name="terminals">Terminal descriptions used to generate the CreateTerminalNode switch cases.</param>
/// <param name="maximization">Written into the template as the lower-case literal "true" or "false".</param>
/// <param name="problemSourceCode">Builder that receives the finished solver source.</param>
/// <remarks>
/// Only ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE? are substituted here;
/// any other placeholder in the template is passed through unchanged.
/// </remarks>
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // compute the three substitution values up front so the replacement chain below reads as a plain mapping
  string maximizationLiteral = maximization.ToString().ToLowerInvariant();
  string sampleAlternativeCode = GenerateSampleAlternativeSource(grammar);
  string createTerminalCode = GenerateCreateTerminalCode(grammar, terminals);

  var solverSourceCode = new SourceBuilder();
  solverSourceCode.Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", sampleAlternativeCode)
    .Replace("?CREATETERMINALNODECODE?", createTerminalCode);

  problemSourceCode.Append(solverSourceCode.ToString());
}
283
284
285
/// <summary>
/// Generates the <c>case</c> sections that replace ?SAMPLEALTERNATIVECODE? inside the
/// <c>SampleAlternative</c> switch of the solver template.
/// </summary>
/// <param name="grammar">Grammar whose symbols are numbered by their position in <c>grammar.Symbols</c>.</param>
/// <returns>The generated source text, one <c>case</c> label per symbol.</returns>
/// <remarks>
/// A terminal symbol only gets a bare <c>case N:</c> label with no statement. For a non-terminal
/// that has both all-terminal alternatives and alternatives containing a non-terminal, the
/// generated code restricts the choice to the all-terminal alternatives when
/// <c>maxDepth &lt;= 1</c>; otherwise it picks uniformly among all alternatives.
/// </remarks>
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    if (grammar.IsTerminal(s)) {
      // terminals have no alternatives to choose from; only the label is emitted
      continue;
    }

    // Materialize once: the index lists are inspected here and then enumerated several more
    // times by GenerateReturnStatement, so leaving them as deferred queries would re-run
    // GetAlternatives and the terminal/non-terminal predicates on every pass.
    var indexedAlternatives = grammar.GetAlternatives(s)
      .Select((alt, idx) => new { alt, idx })
      .ToList();
    var terminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToList();
    var nonTerminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToList();

    if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
      sb.Append("if(maxDepth <= 1) {").BeginBlock();
      GenerateReturnStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
      sb.Append("}").EndBlock();
    } else {
      // only one kind of alternative: every alternative index is eligible at any depth
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
/// <summary>
/// Generates the <c>case</c> sections that replace ?CREATETERMINALNODECODE? inside the
/// <c>CreateTerminalNode</c> switch of the solver template.
/// </summary>
/// <param name="grammar">Grammar whose symbols are numbered by their position in <c>grammar.Symbols</c>.</param>
/// <param name="terminals">Terminal descriptions; exactly one entry must have an <c>Ident</c> equal to each terminal symbol's name.</param>
/// <returns>Source text with one <c>case N: { var t = new &lt;Name&gt;Tree(); return t; }</c> section per terminal symbol.</returns>
/// <exception cref="NotImplementedException">A terminal has a constraint of type <c>Set</c>.</exception>
/// <exception cref="NotSupportedException">A terminal has a constraint of any other type.</exception>
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  var allSymbols = grammar.Symbols.ToList();
  // Walk the snapshot by position: the position is the state number, so no per-symbol IndexOf
  // scan is needed, and the numbering matches the running counter used for the
  // SampleAlternative cases.
  for (int state = 0; state < allSymbols.Count; state++) {
    var s = allSymbols[state];
    if (!grammar.IsTerminal(s)) continue;

    sb.AppendFormat("case {0}: {{", state).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
    var terminal = terminals.Single(t => t.Ident == s.Name);
    // any constraint on a terminal is currently rejected, so only attribute-free terminals get through
    foreach (var constr in terminal.Constraints) {
      if (constr.Type == ConstraintNodeType.Set) {
        throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
        // sb.Append("{").BeginBlock();
        // sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
        // sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
        // sb.AppendLine("}");
      } else {
        throw new NotSupportedException("The MCTS solver does not support RANGE constraints.");
      }
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
/// <summary>
/// Appends a generated statement that returns one of the given alternative indexes.
/// </summary>
/// <param name="idxs">Alternative indexes the generated code may return.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
/// <remarks>
/// One index yields <c>return i;</c>. Several indexes yield a statement that picks a random
/// element of an inline <c>int[]</c>. An empty sequence yields a <c>throw</c> statement, the same
/// line the <c>int</c> overload emits when there are no alternatives.
/// </remarks>
private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // snapshot once; callers may pass a deferred query and it is inspected several times below
  var indexes = idxs.ToList();
  if (indexes.Count == 0) {
    // an empty array indexed with random.Next(0) could only fail at run time, so emit an explicit throw
    sb.AppendLine("throw new InvalidProgramException();");
  } else if (indexes.Count == 1) {
    sb.AppendFormat("return {0};", indexes[0]).AppendLine();
  } else {
    // every index is followed by ", "; the trailing comma is legal in a C# array initializer
    var idxStr = indexes.Aggregate(string.Empty, (str, idx) => str + idx + ", ");
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, indexes.Count).AppendLine();
  }
}
352
/// <summary>
/// Appends a generated statement that returns an alternative index in the range [0, nAlts).
/// </summary>
/// <param name="nAlts">Number of alternatives of the symbol.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
/// <remarks>
/// Fewer than one alternative yields a <c>throw</c> statement, exactly one yields
/// <c>return 0;</c>, and more than one yields a <c>random.Next(nAlts)</c> call.
/// </remarks>
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  if (nAlts < 1) {
    // nothing can be returned for a symbol without alternatives
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }
  if (nAlts == 1) {
    // single alternative: no random draw needed
    sb.AppendLine("return 0; ");
    return;
  }
  sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
}
362  }
363}
Note: See TracBrowser for help on using the repository browser.