Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs @ 10415

Last change on this file since 10415 was 10415, checked in by gkronber, 10 years ago

#2026 implemented prevention of resampling of known nodes.

File size: 14.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class MonteCarloTreeSearchCodeGen {
10
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public class SearchTreeNode {
14    public int tries;
15    public double sumQuality = 0.0;
16    public double bestQuality = double.NegativeInfinity;
17    public bool done;
18    public SearchTreeNode[] children;
19    public double[] Ucb {
20      get {
21        return (from c in children
22                select ?IDENT?Solver.UCB(this, c)
23               ).ToArray();
24      }
25    }
26    public SearchTreeNode() {
27    }
28  }
29
30  public sealed class ?IDENT?Solver {
31
32    private readonly ?IDENT?Problem problem;
33    private readonly Random random;
34    private SearchTreeNode searchTree = new SearchTreeNode();
35   
36    private Tree SampleTree() {
37      var extensionsStack = new Stack<Tuple<Tree, int, int>>(); // the unfinished tree, the state, and the index of the extension point
38      var t = new Tree(-1, new Tree[1]);
39      extensionsStack.Push(Tuple.Create(t, 0, 0));
40      SampleTree(searchTree, extensionsStack);
41      return t.subtrees[0];
42    }
43
44    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int>> extensionPoints) {
45      const int RANDOM_TRIES = 1000;
46      if(extensionPoints.Count == 0) {
47        searchTree.done = true;
48        return; // nothing to do
49      }
50      var extensionPoint = extensionPoints.Pop();
51      Tree parent = extensionPoint.Item1;
52      int state = extensionPoint.Item2;
53      int subtreeIdx = extensionPoint.Item3;
54      Tree t = null;
55      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
56        t = SampleTreeRandom(state);
57        if(Grammar.subtreeCount[state] == 0) {
58          // when we produced a terminal continue filling up all other empty points
59          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
60          if(searchTree.children == null)
61            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
62          SampleTree(searchTree.children[0], extensionPoints);
63          if(searchTree.children[0].done) searchTree.done = true;
64        } else {
65          // fill up all remaining slots randomly
66          foreach(var p in extensionPoints) {
67            var pParent = p.Item1;
68            var pState = p.Item2;
69            var pIdx = p.Item3;
70            pParent.subtrees[pIdx] = SampleTreeRandom(pState);
71          }
72        }
73      } else {
74        if(Grammar.subtreeCount[state] == 1) {
75          if(searchTree.children == null) {
76            int nChildren = Grammar.transition[state].Length;
77            searchTree.children = new SearchTreeNode[nChildren];
78          }
79          Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
80          Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
81          var altIdx = SelectAlternative(searchTree);
82          t = new Tree(altIdx, new Tree[1]);
83          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0));
84          SampleTree(searchTree.children[altIdx], extensionPoints);
85        } else {
86          // multiple subtrees
87          var subtrees = new Tree[Grammar.subtreeCount[state]];
88          t = new Tree(-1, subtrees);
89          for(int i = subtrees.Length - 1; i >= 0; i--) {
90            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i));
91          }
92          SampleTree(searchTree, extensionPoints);         
93        }
94      }     
95      Debug.Assert(parent.subtrees[subtreeIdx] == null);
96      parent.subtrees[subtreeIdx] = t;
97    }
98
99    private int SelectAlternative(SearchTreeNode searchTree) {
100      // any alternative not yet explored?
101      var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
102      if(altIdx >= 0) {
103        searchTree.children[altIdx] = new SearchTreeNode();
104        return altIdx;
105      } else {
106        altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000);
107        if(altIdx >= 0) return altIdx;
108        // select the least sampled alternative
109        //altIdx = 0;
110        //int minSamples = searchTree.children[altIdx].tries;
111        //for(int idx = 1; idx < searchTree.children.Length; idx++) {
112        //  if(!searchTree.children[idx].done && searchTree.children[idx].tries < minSamples) {
113        //    minSamples = searchTree.children[idx].tries;
114        //    altIdx = idx;
115        //  }
116        //}
117        // select the alternative with the largest average
118        altIdx = 0;
119        double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
120        for(int idx = 1; idx < searchTree.children.Length; idx++) {
121          if (!searchTree.children[idx].done && UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
122            altIdx = idx;
123          }
124        }
125
126        searchTree.done = searchTree.children.All(c=>c.done);
127        return altIdx;
128      }
129    }
130
131    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
132      Debug.Assert(parent.tries >= n.tries);
133      Debug.Assert(n.tries > 0);
134      return n.sumQuality / n.tries + Math.Sqrt((10 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
135    }
136
137    private void UpdateSearchTree(Tree t, double quality) {
138      var trees = new Stack<Tree>();
139      trees.Push(t);
140      UpdateSearchTree(searchTree, trees, quality);
141    }
142
143    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
144      if(trees.Count == 0 || searchTree == null) return;
145      var t = trees.Pop();
146      if(t.altIdx == -1) {
147        // for trees with multiple sub-trees
148        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
149          trees.Push(t.subtrees[idx]);
150        }
151        UpdateSearchTree(searchTree, trees, quality);
152      } else {
153        searchTree.sumQuality += quality;
154        searchTree.tries++;
155        if(quality > searchTree.bestQuality)
156          searchTree.bestQuality = quality;
157        if(t.subtrees != null) {
158          Debug.Assert(t.subtrees.Length == 1);
159          if(searchTree.children != null) {
160            trees.Push(t.subtrees[0]);
161            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
162          }
163        } else {
164          if(searchTree.children != null) {
165            Debug.Assert(searchTree.children.Length == 1);
166            UpdateSearchTree(searchTree.children[0], trees, quality);
167          }
168        }
169      }
170    }
171
172    // same as in random search solver (could reuse random search)
173    private Tree SampleTreeRandom(int state) {
174      return SampleTreeRandom(state, 5);
175    }
176    private Tree SampleTreeRandom(int state, int maxDepth) {
177      Tree t = null;
178
179      // terminals
180      if(Grammar.subtreeCount[state] == 0) {
181        t = CreateTerminalNode(state, random, problem);
182      } else {
183        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
184        if(Grammar.subtreeCount[state] == 1) {
185          var targetStates = Grammar.transition[state];
186          var altIdx = SampleAlternative(random, state, maxDepth);
187          var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
188          t = new Tree(altIdx, new Tree[] { alternative });
189        } else {
190          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
191          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
192          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
193            subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
194          }
195          t = new Tree(-1, subtrees); // alternative index is ignored
196        }
197      }
198      return t;
199    }
200
201    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
202      switch(state) {
203        ?CREATETERMINALNODECODE?
204        default: { throw new ArgumentException(""Unknown state index"" + state); }
205      }
206    }
207
208    private int SampleAlternative(Random random, int state, int maxDepth) {
209      switch(state) {
210
211?SAMPLEALTERNATIVECODE?
212
213        default: throw new InvalidOperationException();
214      }
215    }
216
217
218    public static void Main(string[] args) {
219      // if(args.Length >= 1) ParseArguments(args);
220
221      var problem = new ?IDENT?Problem();
222      var solver = new ?IDENT?Solver(problem);
223      solver.Start();
224    }
225
226    public ?IDENT?Solver(?IDENT?Problem problem) {
227      this.problem = problem;
228      this.random = new Random();
229    }
230
231    private void Start() {
232      Console.ReadLine();
233      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
234      int n = 0;
235      long sumDepth = 0;
236      long sumSize = 0;
237      var sumF = 0.0;
238      var sw = new System.Diagnostics.Stopwatch();
239      sw.Start();
240      while (!searchTree.done) {
241
242        int steps, depth;
243        var _t = SampleTree();
244        //  _t.PrintTree(0); Console.WriteLine();
245
246        // inefficient but don't care for now
247        steps = _t.GetSize();
248        depth = _t.GetDepth();
249        var f = problem.Evaluate(_t);
250        if(?MAXIMIZATION?)
251          UpdateSearchTree(_t, f);
252        else
253          UpdateSearchTree(_t, -f);
254        n++;   
255        sumSize += steps;
256        sumDepth += depth;
257        sumF += f;
258        if (problem.IsBetter(f, bestF)) {
259          bestF = f;
260          _t.PrintTree(0); Console.WriteLine();
261          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
262        }
263        if (n % 1000 == 0) {
264          sw.Stop();
265          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
266          sumSize = 0;
267          sumDepth = 0;
268          sumF = 0.0;
269          sw.Restart();
270        }
271      }
272    }
273  }
274}";
275
276    public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
277      var solverSourceCode = new SourceBuilder();
278      solverSourceCode.Append(solverTemplate)
279        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
280        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
281        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals))
282      ;
283
284      problemSourceCode.Append(solverSourceCode.ToString());
285    }
286
287
288
289    private string GenerateSampleAlternativeSource(IGrammar grammar) {
290      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
291      var sb = new SourceBuilder();
292      int stateCount = 0;
293      foreach (var s in grammar.Symbols) {
294        sb.AppendFormat("case {0}: ", stateCount++);
295        if (grammar.IsTerminal(s)) {
296          // ignore
297        } else {
298          var terminalAltIndexes = grammar.GetAlternatives(s)
299            .Select((alt, idx) => new { alt, idx })
300            .Where((p) => p.alt.All(symb => grammar.IsTerminal(symb)))
301            .Select(p => p.idx);
302          var nonTerminalAltIndexes = grammar.GetAlternatives(s)
303            .Select((alt, idx) => new { alt, idx })
304            .Where((p) => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
305            .Select(p => p.idx);
306          var hasTerminalAlts = terminalAltIndexes.Any();
307          var hasNonTerminalAlts = nonTerminalAltIndexes.Any();
308          if (hasTerminalAlts && hasNonTerminalAlts) {
309            sb.Append("if(maxDepth <= 1) {").BeginBlock();
310            GenerateReturnStatement(terminalAltIndexes, sb);
311            sb.Append("} else {");
312            GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
313            sb.Append("}").EndBlock();
314          } else {
315            GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
316          }
317        }
318      }
319      return sb.ToString();
320    }
321    private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
322      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
323      var sb = new SourceBuilder();
324      var allSymbols = grammar.Symbols.ToList();
325      foreach (var s in grammar.Symbols) {
326        if (grammar.IsTerminal(s)) {
327          sb.AppendFormat("case {0}: {{", allSymbols.IndexOf(s)).BeginBlock();
328          sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
329          var terminal = terminals.Single(t => t.Ident == s.Name);
330          foreach (var constr in terminal.Constraints) {
331            if (constr.Type == ConstraintNodeType.Set) {
332              throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
333              // sb.Append("{").BeginBlock();
334              // sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
335              // sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
336              // sb.AppendLine("}");
337            } else {
338              throw new NotSupportedException("The MTCS solver does not support RANGE constraints.");
339            }
340          }
341          sb.AppendLine("return t;").EndBlock();
342          sb.Append("}");
343        }
344      }
345      return sb.ToString();
346    }
347    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
348      if (idxs.Count() == 1) {
349        sb.AppendFormat("return {0};", idxs.Single()).AppendLine();
350      } else {
351        var idxStr = idxs.Aggregate(string.Empty, (str, idx) => str + idx + ", ");
352        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, idxs.Count()).AppendLine();
353      }
354    }
355
356    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
357      if (nAlts > 1) {
358        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
359      } else if (nAlts == 1) {
360        sb.AppendLine("return 0; ");
361      } else {
362        sb.AppendLine("throw new InvalidProgramException();");
363      }
364    }
365  }
366}
Note: See TracBrowser for help on using the repository browser.