Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs @ 16371

Last change on this file since 16371 was 10437, checked in by gkronber, 11 years ago

#2026 implemented epsilon-greedy search policy

File size: 17.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Grammars;
7
8namespace CodeGenerator {
9  public class MonteCarloTreeSearchCodeGen {
10
11    private string solverTemplate = @"
12namespace ?PROBLEMNAME? {
13  public class SearchTreeNode {
14    public int tries;
15    public double sumQuality = 0.0;
16    public double bestQuality = double.NegativeInfinity;
17    public bool done;
18    public SearchTreeNode[] children;
19    // only for debugging
20    public double[] Ucb {
21      get {
22        return (from c in children
23                select ?IDENT?MonteCarloTreeSearchSolver.UCB(this, c)
24               ).ToArray();
25      }
26    }
27    public SearchTreeNode() {
28    }
29  }
30
31  public sealed class ?IDENT?MonteCarloTreeSearchSolver {
32    private int maxDepth = 20;
33
34    private readonly ?IDENT?Problem problem;
35    private readonly Random random;
36    private readonly ?IDENT?RandomSearchSolver randomSearch;
37
38    private SearchTreeNode searchTree = new SearchTreeNode();
39   
40    private Tree SampleTree(int maxDepth) {
41      var extensionsStack = new Stack<Tuple<Tree, int, int, int>>(); // the unfinished tree, the state, the index of the extension point and the maximal depth of a tree inserted at that point
42      var t = new Tree(-1, new Tree[1]);
43      extensionsStack.Push(Tuple.Create(t, 0, 0, maxDepth));
44      SampleTree(searchTree, extensionsStack);
45      return t.subtrees[0];
46    }
47
48    private  const int RANDOM_TRIES = 100;
49
50    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) {
51      var extensionPoint = extensionPoints.Pop();
52      Tree parent = extensionPoint.Item1;
53      int state = extensionPoint.Item2;
54      int subtreeIdx = extensionPoint.Item3;
55      int maxDepth = extensionPoint.Item4;
56      Debug.Assert(maxDepth >= 1);
57      Debug.Assert(Grammar.minDepth[state] <= maxDepth);
58      Tree t = null;
59      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
60        int steps = 0; int curDepth = this.maxDepth - maxDepth; int depth = this.maxDepth - maxDepth;
61        t = randomSearch.SampleTree(state, maxDepth, ref steps, ref curDepth, ref depth);
62        if(Grammar.subtreeCount[state] == 0) {
63          // when we produced a terminal continue filling up all other empty points
64          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
65          if(extensionPoints.Count == 0) {
66            searchTree.done = true;
67          } else {
68            if(searchTree.children == null)
69              searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
70            SampleTree(searchTree.children[0], extensionPoints);
71            if(searchTree.children[0].done) searchTree.done = true;
72          }
73        } else {
74          // fill up all remaining slots randomly
75          foreach(var p in extensionPoints) {
76            var pParent = p.Item1;
77            var pState = p.Item2;
78            var pIdx = p.Item3;
79            var pMaxDepth = p.Item4;
80            curDepth = this.maxDepth - pMaxDepth;
81            depth = curDepth;
82            pParent.subtrees[pIdx] = randomSearch.SampleTree(pState, pMaxDepth, ref steps, ref curDepth, ref depth);
83          }
84        }
85      } else if(Grammar.subtreeCount[state] == 1) {
86         if(searchTree.children == null) {
87           int nChildren = Grammar.transition[state].Length;
88           searchTree.children = new SearchTreeNode[nChildren];
89         }
90         Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
91         Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
92         var altIdx = SelectAlternative(searchTree, state, maxDepth);
93         t = new Tree(altIdx, new Tree[1]);
94         extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
95         SampleTree(searchTree.children[altIdx], extensionPoints);
96         searchTree.done = (from idx in Enumerable.Range(0, searchTree.children.Length)
97                            where Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth - 1
98                            select searchTree.children[idx]).All(c=>c != null && c.done);
99      } else {
100        // multiple subtrees
101        var subtrees = new Tree[Grammar.subtreeCount[state]];
102        t = new Tree(-1, subtrees);
103        for(int i = subtrees.Length - 1; i >= 0; i--) {
104          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
105        }
106        SampleTree(searchTree, extensionPoints);   
107      }     
108      Debug.Assert(parent.subtrees[subtreeIdx] == null);
109      parent.subtrees[subtreeIdx] = t;
110    }
111
112    private int SelectAlternative(SearchTreeNode searchTree, int state, int maxDepth) {
113      // any alternative not yet explored?
114      var altIndexes = searchTree.children
115                     .Select((e,i) => new {Elem = e, Idx = i})
116                     .Where(p => p.Elem == null && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
117                     .Select(p => p.Idx);
118      int altIdx = altIndexes.Any()?altIndexes.First() : -1;
119      if(altIdx >= 0) {
120        searchTree.children[altIdx] = new SearchTreeNode();
121        return altIdx;
122      } else {
123        altIndexes = searchTree.children
124                       .Select((e,i) => new {Elem = e, Idx = i})
125                       .Where(p => p.Elem != null && !p.Elem.done && p.Elem.tries < RANDOM_TRIES && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
126                       .Select(p => p.Idx);
127        altIdx = altIndexes.Any()?altIndexes.First() : -1;
128        if(altIdx >= 0) return altIdx;
129        // select the least sampled alternative
130        //altIdx = -1;
131        //int minSamples = int.MaxValue;
132        //for(int idx = 0; idx < searchTree.children.Length; idx++) {
133        //  if(searchTree.children[idx] == null) continue;
134        //  if(!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && searchTree.children[idx].tries < minSamples) {
135        //    minSamples = searchTree.children[idx].tries;
136        //    altIdx = idx;
137        //  }
138        //}
139        // select the alternative with the largest average
140        // altIdx = -1;
141        // double best = double.NegativeInfinity;
142        // for(int idx = 0; idx < searchTree.children.Length; idx++) {
143        //   if(searchTree.children[idx] == null) continue;
144        //   if (!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth  && UCB(searchTree, searchTree.children[idx]) > best) {
145        //     altIdx = idx;
146        //     best = UCB(searchTree, searchTree.children[idx]);
147        //   }
148        // }
149        // Softmax selection
150        // double temperature = 1;
151        // var ms = searchTree.children.Select((c,i) => c == null || c.done || Grammar.minDepth[Grammar.transition[state][i]] > maxDepth ? 0.0 : Math.Exp((c.sumQuality / c.tries) / temperature)).ToArray();
152        // var msSum = ms.Sum();
153        // if(msSum == 0.0) {
154        //   // uniform distribution
155        //   ms = searchTree.children.Select((c,i) => c == null || c.done || Grammar.minDepth[Grammar.transition[state][i]] > maxDepth ? 0.0 : 1.0).ToArray(); 
156        //   msSum = ms.Sum();
157        // }
158        // Debug.Assert(msSum > 0.0);
159        // var r = random.NextDouble() * msSum;
160        // 
161        // altIdx = 0;
162        // var aggSum = 0.0;
163        // while(altIdx < searchTree.children.Length && aggSum <= r) {
164        //   var c = searchTree.children[altIdx];
165        //   aggSum += ms[altIdx++];
166        // }
167        // altIdx--;
168
169        // epsilon-greedy selection
170        double eps = 0.1;
171        if(random.NextDouble() >= eps) {
172          // select best
173          altIdx = 0; while(searchTree.children[altIdx]==null) altIdx++;
174          for(int idx=0;idx<searchTree.children.Length;idx++) {
175            var c = searchTree.children[idx];
176            if(c==null || c.done || Grammar.minDepth[Grammar.transition[state][idx]] > maxDepth) continue;
177            if(searchTree.children[idx].bestQuality > searchTree.children[altIdx].bestQuality) {
178              altIdx = idx;
179            }
180          }
181        } else {
182          // select random
183          var allowedIndexes = (searchTree.children
184                         .Select((e,i) => new {Elem = e, Idx = i})
185                         .Where(p => p.Elem != null && !p.Elem.done && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
186                         .Select(p => p.Idx)).ToArray();
187          altIdx = allowedIndexes[random.Next(allowedIndexes.Length)];
188        }
189        Debug.Assert(altIdx > -1);
190        return altIdx;
191      }
192    }
193 
194    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
195      Debug.Assert(parent.tries >= n.tries);
196      Debug.Assert(n.tries > 0);
197      return n.sumQuality / n.tries + Math.Sqrt((400 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
198    }
199
200    private void UpdateSearchTree(Tree t, double quality) {
201      var trees = new Stack<Tree>();
202      trees.Push(t);
203      UpdateSearchTree(searchTree, trees, quality);
204    }
205
206    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
207      if(trees.Count == 0 || searchTree == null) return;
208      var t = trees.Pop();
209      if(t.altIdx == -1) {
210        // for trees with multiple sub-trees
211        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
212          trees.Push(t.subtrees[idx]);
213        }
214        UpdateSearchTree(searchTree, trees, quality);
215      } else {
216        searchTree.sumQuality += quality;
217        searchTree.tries++;
218        if(quality > searchTree.bestQuality)
219          searchTree.bestQuality = quality;
220        if(t.subtrees != null) {
221          Debug.Assert(t.subtrees.Length == 1);
222          if(searchTree.children != null) {
223            trees.Push(t.subtrees[0]);
224            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
225          }
226        } else {
227          if(searchTree.children != null) {
228            Debug.Assert(searchTree.children.Length == 1);
229            UpdateSearchTree(searchTree.children[0], trees, quality);
230          }
231        }
232      }
233    }
234
235    public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) {
236      this.randomSearch = new ?IDENT?RandomSearchSolver(problem, args);
237      if(args.Length > 0 ) ParseArguments(args);
238      this.problem = problem;
239      this.random = new Random();
240    }
241    private void ParseArguments(string[] args) {
242      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");
243
244      var helpRegex = new Regex(@""--help|/\?"");
245     
246      foreach(var arg in args) {
247        var maxDepthMatch = maxDepthRegex.Match(arg);
248        var helpMatch = helpRegex.Match(arg);
249        if(helpMatch.Success) {
250          PrintUsage(); Environment.Exit(0);
251        } else if(maxDepthMatch.Success) {
252           maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
253           if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
254        } else {
255           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
256        }
257      }
258    }
259    private void PrintUsage() {
260      Console.WriteLine(""Find a solution using Monte-Carlo tree search."");
261      Console.WriteLine();
262      Console.WriteLine(""Parameters:"");
263      Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]"");
264    }
265
266
267
268    public void Start() {
269      Console.ReadLine();
270      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
271      int n = 0;
272      long sumDepth = 0;
273      long sumSize = 0;
274      var sumF = 0.0;
275      var sw = new System.Diagnostics.Stopwatch();
276      sw.Start();
277      while (!searchTree.done) {
278
279        int steps, depth;
280        var _t = SampleTree(maxDepth);
281        //  _t.PrintTree(0); Console.WriteLine();
282
283        // inefficient but don't care for now
284        steps = _t.GetSize();
285        depth = _t.GetDepth();
286        Debug.Assert(depth <= maxDepth);
287        var f = problem.Evaluate(_t);
288        if(?MAXIMIZATION?)
289          UpdateSearchTree(_t, f);
290        else
291          UpdateSearchTree(_t, -f);
292        n++;   
293        sumSize += steps;
294        sumDepth += depth;
295        sumF += f;
296        if (problem.IsBetter(f, bestF)) {
297          bestF = f;
298          _t.PrintTree(0); Console.WriteLine();
299          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
300        }
301        if (n % 1000 == 0) {
302          sw.Stop();
303          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
304          sumSize = 0;
305          sumDepth = 0;
306          sumF = 0.0;
307          sw.Restart();
308        }
309      }
310    }
311  }
312}";
313
314    public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
315      var solverSourceCode = new SourceBuilder();
316      solverSourceCode.Append(solverTemplate)
317        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
318        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
319        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals))
320      ;
321
322      problemSourceCode.Append(solverSourceCode.ToString());
323    }
324
325
326
327    private string GenerateSampleAlternativeSource(IGrammar grammar) {
328      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
329      var sb = new SourceBuilder();
330      int stateCount = 0;
331      foreach (var s in grammar.Symbols) {
332        sb.AppendFormat("case {0}: ", stateCount++);
333        if (grammar.IsTerminal(s)) {
334          // ignore
335        } else {
336          var terminalAltIndexes = grammar.GetAlternatives(s)
337            .Select((alt, idx) => new { alt, idx })
338            .Where((p) => p.alt.All(symb => grammar.IsTerminal(symb)))
339            .Select(p => p.idx);
340          var nonTerminalAltIndexes = grammar.GetAlternatives(s)
341            .Select((alt, idx) => new { alt, idx })
342            .Where((p) => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
343            .Select(p => p.idx);
344          var hasTerminalAlts = terminalAltIndexes.Any();
345          var hasNonTerminalAlts = nonTerminalAltIndexes.Any();
346          if (hasTerminalAlts && hasNonTerminalAlts) {
347            sb.Append("if(maxDepth <= 1) {").BeginBlock();
348            GenerateReturnStatement(terminalAltIndexes, sb);
349            sb.Append("} else {");
350            GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
351            sb.Append("}").EndBlock();
352          } else {
353            GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
354          }
355        }
356      }
357      return sb.ToString();
358    }
359    private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
360      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
361      var sb = new SourceBuilder();
362      var allSymbols = grammar.Symbols.ToList();
363      foreach (var s in grammar.Symbols) {
364        if (grammar.IsTerminal(s)) {
365          sb.AppendFormat("case {0}: {{", allSymbols.IndexOf(s)).BeginBlock();
366          sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
367          var terminal = terminals.Single(t => t.Ident == s.Name);
368          foreach (var constr in terminal.Constraints) {
369            if (constr.Type == ConstraintNodeType.Set) {
370              throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
371              // sb.Append("{").BeginBlock();
372              // sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
373              // sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
374              // sb.AppendLine("}");
375            } else {
376              throw new NotSupportedException("The MTCS solver does not support RANGE constraints.");
377            }
378          }
379          sb.AppendLine("return t;").EndBlock();
380          sb.Append("}");
381        }
382      }
383      return sb.ToString();
384    }
385    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
386      if (idxs.Count() == 1) {
387        sb.AppendFormat("return {0};", idxs.Single()).AppendLine();
388      } else {
389        var idxStr = idxs.Aggregate(string.Empty, (str, idx) => str + idx + ", ");
390        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, idxs.Count()).AppendLine();
391      }
392    }
393
394    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
395      if (nAlts > 1) {
396        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
397      } else if (nAlts == 1) {
398        sb.AppendLine("return 0; ");
399      } else {
400        sb.AppendLine("throw new InvalidProgramException();");
401      }
402    }
403  }
404}
Note: See TracBrowser for help on using the repository browser.