Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs @ 10411

Last change on this file since 10411 was 10411, checked in by gkronber, 10 years ago

#2026 refactoring

File size: 13.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class MonteCarloTreeSearchCodeGen {
10
/// <summary>
/// Verbatim-string template for the source code of the generated solver.
/// It declares a SearchTreeNode class and a ?IDENT?Solver class inside namespace ?PROBLEMNAME?.
/// Placeholders are written as ?NAME?. <c>Generate</c> in this class substitutes
/// ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE?.
/// ?IDENT? and ?PROBLEMNAME? are not replaced anywhere in this class
/// (presumably by whoever owns the SourceBuilder passed to Generate - TODO confirm).
/// Doubled quotes ("") inside the template are escaped quotes of the verbatim string literal.
/// Everything between the opening @" and the closing "; is emitted text, so comments in there
/// are part of the generated program, not of this generator.
/// </summary>
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public class SearchTreeNode {
14    public int tries;
15    public double sumQuality = 0.0;
16    public double bestQuality = double.NegativeInfinity;
17    public bool ready;
18    public SearchTreeNode[] children;
19
20    public SearchTreeNode() {
21    }
22  }
23
24  public sealed class ?IDENT?Solver {
25
26    private readonly ?IDENT?Problem problem;
27    private readonly Random random;
28    private SearchTreeNode searchTree = new SearchTreeNode();
29   
30    private Tree SampleTree() {
31      var extensionsStack = new Stack<Tuple<Tree, int, int>>(); // the unfinished tree, the state, and the index of the extension point
32      var t = new Tree(-1, new Tree[1]);
33      extensionsStack.Push(Tuple.Create(t, 0, 0));
34      SampleTree(searchTree, extensionsStack);
35      return t.subtrees[0];
36    }
37
38    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int>> extensionPoints) {
39      const int RANDOM_TRIES = 100;
40      if(extensionPoints.Count == 0) return; // nothing to do
41      var extensionPoint = extensionPoints.Pop();
42      Tree parent = extensionPoint.Item1;
43      int state = extensionPoint.Item2;
44      int subtreeIdx = extensionPoint.Item3;
45      Tree t = null;
46     
47      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
48        t = SampleTreeRandom(state);
49        if(Grammar.subtreeCount[state] == 0) {
50          // when we produced a terminal continue filling up all other empty points
51          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
52          if(searchTree.children == null)
53            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
54          SampleTree(searchTree.children[0], extensionPoints);
55        } else {
56          // fill up all remaining slots randomly
57          foreach(var p in extensionPoints) {
58            var pParent = p.Item1;
59            var pState = p.Item2;
60            var pIdx = p.Item3;
61            pParent.subtrees[pIdx] = SampleTreeRandom(pState);
62          }
63        }
64      } else {
65        if(Grammar.subtreeCount[state] == 1) {
66          if(searchTree.children == null) {
67            int nChildren = Grammar.transition[state].Length;
68            searchTree.children = new SearchTreeNode[nChildren];
69          }
70          Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
71          Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
72          var altIdx = SelectAlternative(searchTree);
73          t = new Tree(altIdx, new Tree[1]);
74          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0));
75          SampleTree(searchTree.children[altIdx], extensionPoints);
76        } else {
77          // multiple subtrees
78          var subtrees = new Tree[Grammar.subtreeCount[state]];
79          t = new Tree(-1, subtrees);
80          for(int i = subtrees.Length - 1; i >= 0; i--) {
81            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i));
82          }
83          SampleTree(searchTree, extensionPoints);
84        }
85      }
86      Debug.Assert(parent.subtrees[subtreeIdx] == null);
87      parent.subtrees[subtreeIdx] = t;
88    }
89
90    private int SelectAlternative(SearchTreeNode searchTree) {
91      // any alternative not yet explored?
92      var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
93      if(altIdx >= 0) {
94        searchTree.children[altIdx] = new SearchTreeNode();
95        return altIdx;
96      } else {
97        // select the least sampled alternative
98        altIdx = 0;
99        int minSamples = searchTree.children[altIdx].tries;
100        for(int idx = 1; idx < searchTree.children.Length; idx++) {
101          if(searchTree.children[idx].tries < minSamples) {
102            minSamples = searchTree.children[idx].tries;
103            altIdx = idx;
104          }
105        }
106        // select the alternative with the largest average
107        // altIdx = 0;
108        // double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
109        // for(int idx = 1; idx < searchTree.children.Length; idx++) {
110        //   if (UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
111        //     altIdx = idx;
112        //   }
113        // }
114        return altIdx;
115      }
116    }
117
118    private double UCB(SearchTreeNode parent, SearchTreeNode n) {
119      Debug.Assert(parent.tries >= n.tries);
120      Debug.Assert(n.tries > 0);
121      return n.sumQuality / n.tries + Math.Sqrt((2 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
122    }
123
124    private void UpdateSearchTree(Tree t, double quality) {
125      var trees = new Stack<Tree>();
126      trees.Push(t);
127      UpdateSearchTree(searchTree, trees, quality);
128    }
129
130    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
131      if(trees.Count == 0 || searchTree == null) return;
132      var t = trees.Pop();
133      if(t.altIdx == -1) {
134        // for trees with multiple sub-trees
135        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
136          trees.Push(t.subtrees[idx]);
137        }
138        UpdateSearchTree(searchTree, trees, quality);
139      } else {
140        searchTree.sumQuality += quality;
141        searchTree.tries++;
142        if(quality > searchTree.bestQuality)
143          searchTree.bestQuality = quality;
144        if(t.subtrees != null) {
145          Debug.Assert(t.subtrees.Length == 1);
146          if(searchTree.children!=null) {
147            trees.Push(t.subtrees[0]);
148            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
149          }
150        } else {
151          if(searchTree.children!=null) {
152            Debug.Assert(searchTree.children.Length == 1);
153            UpdateSearchTree(searchTree.children[0], trees, quality);
154          }
155        }
156      }
157    }
158
159    // same as in random search solver (could reuse random search)
160    private Tree SampleTreeRandom(int state) {
161      return SampleTreeRandom(state, 5);
162    }
163    private Tree SampleTreeRandom(int state, int maxDepth) {
164      Tree t = null;
165
166      // terminals
167      if(Grammar.subtreeCount[state] == 0) {
168        t = CreateTerminalNode(state, random, problem);
169      } else {
170        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
171        if(Grammar.subtreeCount[state] == 1) {
172          var targetStates = Grammar.transition[state];
173          var altIdx = SampleAlternative(random, state, maxDepth);
174          var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
175          t = new Tree(altIdx, new Tree[] { alternative });
176        } else {
177          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
178          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
179          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
180            subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
181          }
182          t = new Tree(-1, subtrees); // alternative index is ignored
183        }
184      }
185      return t;
186    }
187
188    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
189      switch(state) {
190        ?CREATETERMINALNODECODE?
191        default: { throw new ArgumentException(""Unknown state index"" + state); }
192      }
193    }
194
195    private int SampleAlternative(Random random, int state, int maxDepth) {
196      switch(state) {
197
198?SAMPLEALTERNATIVECODE?
199
200        default: throw new InvalidOperationException();
201      }
202    }
203
204
205    public static void Main(string[] args) {
206      // if(args.Length >= 1) ParseArguments(args);
207
208      var problem = new ?IDENT?Problem();
209      var solver = new ?IDENT?Solver(problem);
210      solver.Start();
211    }
212
213    public ?IDENT?Solver(?IDENT?Problem problem) {
214      this.problem = problem;
215      this.random = new Random();
216    }
217
218    private void Start() {
219      Console.ReadLine();
220      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
221      int n = 0;
222      long sumDepth = 0;
223      long sumSize = 0;
224      var sumF = 0.0;
225      var sw = new System.Diagnostics.Stopwatch();
226      sw.Start();
227      while (true) {
228
229        int steps, depth;
230        var _t = SampleTree();
231        //  _t.PrintTree(0); Console.WriteLine();
232
233        // inefficient but don't care for now
234        steps = _t.GetSize();
235        depth = _t.GetDepth();
236        var f = problem.Evaluate(_t);
237        if(?MAXIMIZATION?)
238          UpdateSearchTree(_t, f);
239        else
240          UpdateSearchTree(_t, -f);
241        n++;   
242        sumSize += steps;
243        sumDepth += depth;
244        sumF += f;
245        if (problem.IsBetter(f, bestF)) {
246          bestF = f;
247          _t.PrintTree(0); Console.WriteLine();
248          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
249        }
250        if (n % 1000 == 0) {
251          sw.Stop();
252          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
253          sumSize = 0;
254          sumDepth = 0;
255          sumF = 0.0;
256          sw.Restart();
257        }
258      }
259    }
260  }
261}";
262
/// <summary>
/// Instantiates the solver template for the given grammar and appends the result to
/// <paramref name="problemSourceCode"/>.
/// </summary>
/// <param name="grammar">Grammar used to generate the switch bodies of the solver.</param>
/// <param name="terminals">Terminal definitions used to generate the terminal-node factory code.</param>
/// <param name="maximization">Emitted into the template as the lower-case literal "true" or "false".</param>
/// <param name="problemSourceCode">Builder that receives the finished solver source.</param>
/// <remarks>
/// Only ?MAXIMIZATION?, ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE? are substituted here;
/// any other placeholder in the template is appended unchanged.
/// </remarks>
public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
  // compute all replacement texts up front, then apply them in one chain
  string maximizationLiteral = maximization.ToString().ToLowerInvariant();
  string sampleAlternativeCode = GenerateSampleAlternativeSource(grammar);
  string createTerminalCode = GenerateCreateTerminalCode(grammar, terminals);

  var solverSourceCode = new SourceBuilder();
  solverSourceCode
    .Append(solverTemplate)
    .Replace("?MAXIMIZATION?", maximizationLiteral)
    .Replace("?SAMPLEALTERNATIVECODE?", sampleAlternativeCode)
    .Replace("?CREATETERMINALNODECODE?", createTerminalCode);

  string solverSource = solverSourceCode.ToString();
  problemSourceCode.Append(solverSource);
}
273
274
275
/// <summary>
/// Generates the case sections that replace ?SAMPLEALTERNATIVECODE? in the solver template,
/// i.e. the body of the generated <c>SampleAlternative</c> switch.
/// </summary>
/// <param name="grammar">
/// Grammar whose symbols are numbered in enumeration order; the first symbol is asserted to be
/// the start symbol.
/// </param>
/// <returns>The generated source text, one case label per grammar symbol.</returns>
/// <remarks>
/// For a non-terminal that has both alternatives made up only of terminals and alternatives
/// containing a non-terminal, the generated code picks among the former when
/// <c>maxDepth &lt;= 1</c> and among the latter otherwise. In every other case an alternative
/// is picked among all of them.
/// </remarks>
private string GenerateSampleAlternativeSource(IGrammar grammar) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  int stateCount = 0;
  foreach (var s in grammar.Symbols) {
    sb.AppendFormat("case {0}: ", stateCount++);
    if (grammar.IsTerminal(s)) {
      // Terminals get a bare label only; in the generated switch the label is stacked on
      // whatever section follows it.
      continue;
    }

    // Materialize once: the index lists are inspected here and enumerated again by
    // GenerateReturnStatement, which would otherwise re-run the grammar queries each time.
    var indexedAlternatives = grammar.GetAlternatives(s)
      .Select((alt, idx) => new { alt, idx })
      .ToList();
    var terminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.All(symb => grammar.IsTerminal(symb)))
      .Select(p => p.idx)
      .ToList();
    var nonTerminalAltIndexes = indexedAlternatives
      .Where(p => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
      .Select(p => p.idx)
      .ToList();

    if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
      sb.Append("if(maxDepth <= 1) {").BeginBlock();
      GenerateReturnStatement(terminalAltIndexes, sb);
      sb.Append("} else {");
      GenerateReturnStatement(nonTerminalAltIndexes, sb);
      sb.Append("}").EndBlock();
    } else {
      GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
    }
  }
  return sb.ToString();
}
/// <summary>
/// Generates the case sections that replace ?CREATETERMINALNODECODE? in the solver template,
/// i.e. the body of the generated <c>CreateTerminalNode</c> switch.
/// </summary>
/// <param name="grammar">
/// Grammar whose symbols are numbered in enumeration order; the first symbol is asserted to be
/// the start symbol.
/// </param>
/// <param name="terminals">
/// Terminal definitions; exactly one must have an <c>Ident</c> equal to the name of each
/// terminal symbol (otherwise <c>Single</c> throws).
/// </param>
/// <returns>The generated source text, one case section per terminal symbol.</returns>
/// <exception cref="NotSupportedException">A terminal has a constraint that is not a set constraint.</exception>
private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
  Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
  var sb = new SourceBuilder();
  // Running position of the symbol in grammar.Symbols. This replaces a per-symbol IndexOf
  // lookup and matches the numbering used for the SampleAlternative switch (assumes symbols
  // are pairwise distinct, which the case labels require anyway).
  int nextState = 0;
  foreach (var s in grammar.Symbols) {
    int state = nextState++;
    if (!grammar.IsTerminal(s)) continue;

    sb.AppendFormat("case {0}: {{", state).BeginBlock();
    sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
    var terminal = terminals.Single(t => t.Ident == s.Name);
    foreach (var constr in terminal.Constraints) {
      if (constr.Type != ConstraintNodeType.Set) {
        throw new NotSupportedException("The MCTS solver does not support RANGE constraints.");
      }
      // each constrained field is drawn uniformly from the values allowed by the problem
      sb.Append("{").BeginBlock();
      sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
      sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
      sb.AppendLine("}");
    }
    sb.AppendLine("return t;").EndBlock();
    sb.Append("}");
  }
  return sb.ToString();
}
/// <summary>
/// Appends a statement that returns one of the given alternative indexes; when there is more
/// than one, the generated code picks among them with <c>random.Next</c>.
/// </summary>
/// <param name="idxs">Alternative indexes to choose from; enumerated exactly once.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
/// <remarks>
/// An empty sequence emits the same <c>throw new InvalidProgramException();</c> statement as the
/// overload taking a count, instead of code that would index into an empty array at runtime.
/// </remarks>
private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
  // materialize once; the sequence may be a deferred query
  var idxList = idxs.ToList();
  if (idxList.Count == 0) {
    sb.AppendLine("throw new InvalidProgramException();");
  } else if (idxList.Count == 1) {
    sb.AppendFormat("return {0};", idxList[0]).AppendLine();
  } else {
    // string.Join drops the trailing ", " the former concatenation left in the initializer
    var idxStr = string.Join(", ", idxList);
    sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, idxList.Count).AppendLine();
  }
}
341
/// <summary>
/// Appends a statement that returns an alternative index for a symbol with
/// <paramref name="nAlts"/> alternatives.
/// </summary>
/// <param name="nAlts">Number of alternatives of the symbol.</param>
/// <param name="sb">Builder that receives the generated statement.</param>
private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
  // no alternative at all: the generated code must never reach this case
  if (nAlts < 1) {
    sb.AppendLine("throw new InvalidProgramException();");
    return;
  }

  // a single alternative needs no random draw
  if (nAlts == 1) {
    sb.AppendLine("return 0; ");
    return;
  }

  sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
}
351  }
352}
Note: See TracBrowser for help on using the repository browser.