using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Grammars;

namespace CodeGenerator {
  /// <summary>
  /// Produces the C# source text of a Monte-Carlo tree search solver by filling in a
  /// source template and appending the result to the source of the generated problem.
  /// </summary>
  public class MonteCarloTreeSearchCodeGen {
    // Source template of the generated solver.
    // Placeholders occurring in the template: ?PROBLEMNAME?, ?IDENT? and ?MAXIMIZATION?.
    // Only ?MAXIMIZATION? is substituted by Generate(); ?PROBLEMNAME? and ?IDENT? are left
    // in the emitted text (presumably substituted later by the caller - TODO confirm).
    //
    // The template must keep one statement/comment per line: it contains '//' line comments,
    // which would comment out all following code if the text were joined onto a single line.
    //
    // NOTE(review): the generic type arguments (Stack<Tuple<Tree, int, int, int>>, Stack<Tree>),
    // the regex group name (?<d>...), the "<d>" in the usage text and the header/skip condition
    // of the 'select best' loop in SelectAlternative were reconstructed from how the values are
    // used in the surrounding template code - verify against the upstream file.
    private string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public class SearchTreeNode {
    public int tries;
    public double sumQuality = 0.0;
    public double bestQuality = double.NegativeInfinity;
    public bool done;
    public SearchTreeNode[] children;
    // only for debugging
    public double[] Ucb {
      get {
        return (from c in children
                select ?IDENT?MonteCarloTreeSearchSolver.UCB(this, c)
               ).ToArray();
      }
    }
    public SearchTreeNode() {
    }
  }

  public sealed class ?IDENT?MonteCarloTreeSearchSolver {
    private int maxDepth = 20;

    private readonly ?IDENT?Problem problem;
    private readonly Random random;
    private readonly ?IDENT?RandomSearchSolver randomSearch;

    private SearchTreeNode searchTree = new SearchTreeNode();

    private Tree SampleTree(int maxDepth) {
      var extensionsStack = new Stack<Tuple<Tree, int, int, int>>(); // the unfinished tree, the state, the index of the extension point and the maximal depth of a tree inserted at that point
      var t = new Tree(-1, new Tree[1]);
      extensionsStack.Push(Tuple.Create(t, 0, 0, maxDepth));
      SampleTree(searchTree, extensionsStack);
      return t.subtrees[0];
    }

    private const int RANDOM_TRIES = 100;

    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) {
      var extensionPoint = extensionPoints.Pop();
      Tree parent = extensionPoint.Item1;
      int state = extensionPoint.Item2;
      int subtreeIdx = extensionPoint.Item3;
      int maxDepth = extensionPoint.Item4;
      Debug.Assert(maxDepth >= 1);
      Debug.Assert(Grammar.minDepth[state] <= maxDepth);
      Tree t = null;
      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
        int steps = 0;
        int curDepth = this.maxDepth - maxDepth;
        int depth = this.maxDepth - maxDepth;
        t = randomSearch.SampleTree(state, maxDepth, ref steps, ref curDepth, ref depth);
        if(Grammar.subtreeCount[state] == 0) {
          // when we produced a terminal continue filling up all other empty points
          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
          if(extensionPoints.Count == 0) {
            searchTree.done = true;
          } else {
            if(searchTree.children == null)
              searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
            SampleTree(searchTree.children[0], extensionPoints);
            if(searchTree.children[0].done) searchTree.done = true;
          }
        } else {
          // fill up all remaining slots randomly
          foreach(var p in extensionPoints) {
            var pParent = p.Item1;
            var pState = p.Item2;
            var pIdx = p.Item3;
            var pMaxDepth = p.Item4;
            curDepth = this.maxDepth - pMaxDepth;
            depth = curDepth;
            pParent.subtrees[pIdx] = randomSearch.SampleTree(pState, pMaxDepth, ref steps, ref curDepth, ref depth);
          }
        }
      } else if(Grammar.subtreeCount[state] == 1) {
        if(searchTree.children == null) {
          int nChildren = Grammar.transition[state].Length;
          searchTree.children = new SearchTreeNode[nChildren];
        }
        Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
        Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
        var altIdx = SelectAlternative(searchTree, state, maxDepth);
        t = new Tree(altIdx, new Tree[1]);
        extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
        SampleTree(searchTree.children[altIdx], extensionPoints);
        searchTree.done = (from idx in Enumerable.Range(0, searchTree.children.Length)
                           where Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth - 1
                           select searchTree.children[idx]).All(c=>c != null && c.done);
      } else {
        // multiple subtrees
        var subtrees = new Tree[Grammar.subtreeCount[state]];
        t = new Tree(-1, subtrees);
        for(int i = subtrees.Length - 1; i >= 0; i--) {
          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
        }
        SampleTree(searchTree, extensionPoints);
      }
      Debug.Assert(parent.subtrees[subtreeIdx] == null);
      parent.subtrees[subtreeIdx] = t;
    }

    private int SelectAlternative(SearchTreeNode searchTree, int state, int maxDepth) {
      // any alternative not yet explored?
      var altIndexes = searchTree.children
             .Select((e,i) => new {Elem = e, Idx = i})
             .Where(p => p.Elem == null && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
             .Select(p => p.Idx);
      int altIdx = altIndexes.Any()?altIndexes.First() : -1;
      if(altIdx >= 0) {
        searchTree.children[altIdx] = new SearchTreeNode();
        return altIdx;
      } else {
        altIndexes = searchTree.children
             .Select((e,i) => new {Elem = e, Idx = i})
             .Where(p => p.Elem != null && !p.Elem.done && p.Elem.tries < RANDOM_TRIES && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
             .Select(p => p.Idx);
        altIdx = altIndexes.Any()?altIndexes.First() : -1;
        if(altIdx >= 0) return altIdx;
        // select the least sampled alternative
        //altIdx = -1;
        //int minSamples = int.MaxValue;
        //for(int idx = 0; idx < searchTree.children.Length; idx++) {
        //  if(searchTree.children[idx] == null) continue;
        //  if(!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && searchTree.children[idx].tries < minSamples) {
        //    minSamples = searchTree.children[idx].tries;
        //    altIdx = idx;
        //  }
        //}
        // select the alternative with the largest average
        // altIdx = -1;
        // double best = double.NegativeInfinity;
        // for(int idx = 0; idx < searchTree.children.Length; idx++) {
        //   if(searchTree.children[idx] == null) continue;
        //   if (!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && UCB(searchTree, searchTree.children[idx]) > best) {
        //     altIdx = idx;
        //     best = UCB(searchTree, searchTree.children[idx]);
        //   }
        // }
        // Softmax selection
        // double temperature = 1;
        // var ms = searchTree.children.Select((c,i) => c == null || c.done || Grammar.minDepth[Grammar.transition[state][i]] > maxDepth ? 0.0 : Math.Exp((c.sumQuality / c.tries) / temperature)).ToArray();
        // var msSum = ms.Sum();
        // if(msSum == 0.0) {
        //   // uniform distribution
        //   ms = searchTree.children.Select((c,i) => c == null || c.done || Grammar.minDepth[Grammar.transition[state][i]] > maxDepth ? 0.0 : 1.0).ToArray();
        //   msSum = ms.Sum();
        // }
        // Debug.Assert(msSum > 0.0);
        // var r = random.NextDouble() * msSum;
        //
        // altIdx = 0;
        // var aggSum = 0.0;
        // while(altIdx < searchTree.children.Length && aggSum <= r) {
        //   var c = searchTree.children[altIdx];
        //   aggSum += ms[altIdx++];
        // }
        // altIdx--;
        // epsilon-greedy selection
        double eps = 0.1;
        if(random.NextDouble() >= eps) {
          // select best
          altIdx = 0;
          while(searchTree.children[altIdx]==null) altIdx++;
          for(int idx=0;idx<searchTree.children.Length;idx++) {
            if(searchTree.children[idx] == null || searchTree.children[idx].done || Grammar.minDepth[Grammar.transition[state][idx]] > maxDepth) continue;
            if(searchTree.children[idx].bestQuality > searchTree.children[altIdx].bestQuality) {
              altIdx = idx;
            }
          }
        } else {
          // select random
          var allowedIndexes = (searchTree.children
             .Select((e,i) => new {Elem = e, Idx = i})
             .Where(p => p.Elem != null && !p.Elem.done && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
             .Select(p => p.Idx)).ToArray();
          altIdx = allowedIndexes[random.Next(allowedIndexes.Length)];
        }
        Debug.Assert(altIdx > -1);
        return altIdx;
      }
    }

    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
      Debug.Assert(parent.tries >= n.tries);
      Debug.Assert(n.tries > 0);
      return n.sumQuality / n.tries + Math.Sqrt((400 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
    }

    private void UpdateSearchTree(Tree t, double quality) {
      var trees = new Stack<Tree>();
      trees.Push(t);
      UpdateSearchTree(searchTree, trees, quality);
    }

    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
      if(trees.Count == 0 || searchTree == null) return;
      var t = trees.Pop();
      if(t.altIdx == -1) {
        // for trees with multiple sub-trees
        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
          trees.Push(t.subtrees[idx]);
        }
        UpdateSearchTree(searchTree, trees, quality);
      } else {
        searchTree.sumQuality += quality;
        searchTree.tries++;
        if(quality > searchTree.bestQuality)
          searchTree.bestQuality = quality;
        if(t.subtrees != null) {
          Debug.Assert(t.subtrees.Length == 1);
          if(searchTree.children != null) {
            trees.Push(t.subtrees[0]);
            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
          }
        } else {
          if(searchTree.children != null) {
            Debug.Assert(searchTree.children.Length == 1);
            UpdateSearchTree(searchTree.children[0], trees, quality);
          }
        }
      }
    }

    public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) {
      this.randomSearch = new ?IDENT?RandomSearchSolver(problem, args);
      if(args.Length > 0 ) ParseArguments(args);
      this.problem = problem;
      this.random = new Random();
    }

    private void ParseArguments(string[] args) {
      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");
      var helpRegex = new Regex(@""--help|/\?"");

      foreach(var arg in args) {
        var maxDepthMatch = maxDepthRegex.Match(arg);
        var helpMatch = helpRegex.Match(arg);
        if(helpMatch.Success) {
          PrintUsage(); Environment.Exit(0);
        } else if(maxDepthMatch.Success) {
          maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
          if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
        } else {
          Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
        }
      }
    }

    private void PrintUsage() {
      Console.WriteLine(""Find a solution using Monte-Carlo tree search."");
      Console.WriteLine();
      Console.WriteLine(""Parameters:"");
      Console.WriteLine(""\t--maxDepth=<d>\tSets the maximal depth of sampled trees [Default: 20]"");
    }

    public void Start() {
      Console.ReadLine();
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (!searchTree.done) {
        int steps, depth;
        var _t = SampleTree(maxDepth);
        //  _t.PrintTree(0); Console.WriteLine();

        // inefficient but don't care for now
        steps = _t.GetSize();
        depth = _t.GetDepth();
        Debug.Assert(depth <= maxDepth);
        var f = problem.Evaluate(_t);
        if(?MAXIMIZATION?)
          UpdateSearchTree(_t, f);
        else
          UpdateSearchTree(_t, -f);
        n++;
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          _t.PrintTree(0); Console.WriteLine();
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";

    /// <summary>
    /// Fills in the solver template and appends the resulting source text to
    /// <paramref name="problemSourceCode"/>.
    /// </summary>
    /// <param name="grammar">Grammar whose symbols and alternatives drive the generated snippets.</param>
    /// <param name="terminals">Terminal definitions; each terminal symbol of the grammar must match exactly one entry by name.</param>
    /// <param name="maximization">Emitted into the template as the literal "true" or "false".</param>
    /// <param name="problemSourceCode">Builder that receives the generated solver source.</param>
    /// <exception cref="NotImplementedException">A terminal declares a SET constraint.</exception>
    /// <exception cref="NotSupportedException">A terminal declares any other constraint.</exception>
    // NOTE(review): the element type of 'terminals' is not visible in this file; TerminalNode is
    // inferred from the members used (Ident, Constraints) - confirm against the callers.
    public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
      var solverSourceCode = new SourceBuilder();
      // The last two placeholders do not occur in the template of this file, so their replacement
      // text is currently unused; the calls still validate the terminals (they throw when a
      // terminal declares constraints).
      solverSourceCode.Append(solverTemplate)
        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals))
      ;

      problemSourceCode.Append(solverSourceCode.ToString());
    }

    /// <summary>
    /// Builds one "case N:" section per grammar symbol (N is the position of the symbol in
    /// grammar.Symbols) that returns the index of a randomly chosen alternative.
    /// </summary>
    /// <param name="grammar">Grammar to generate the sections for; its first symbol must be the start symbol.</param>
    /// <returns>The generated case sections as source text.</returns>
    private string GenerateSampleAlternativeSource(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      int stateCount = 0;
      foreach (var s in grammar.Symbols) {
        sb.AppendFormat("case {0}: ", stateCount++);
        if (grammar.IsTerminal(s)) {
          // terminals only contribute the (empty) case label
          continue;
        }

        // Materialise the index lists once: they are tested for emptiness, counted and
        // formatted below, and each enumeration would otherwise re-query the grammar.
        var terminalAltIndexes = grammar.GetAlternatives(s)
          .Select((alt, idx) => new { alt, idx })
          .Where((p) => p.alt.All(symb => grammar.IsTerminal(symb)))
          .Select(p => p.idx)
          .ToList();
        var nonTerminalAltIndexes = grammar.GetAlternatives(s)
          .Select((alt, idx) => new { alt, idx })
          .Where((p) => p.alt.Any(symb => grammar.IsNonTerminal(symb)))
          .Select(p => p.idx)
          .ToList();

        if (terminalAltIndexes.Count > 0 && nonTerminalAltIndexes.Count > 0) {
          // with no depth budget left only the purely terminal alternatives may be chosen
          sb.Append("if(maxDepth <= 1) {").BeginBlock();
          GenerateReturnStatement(terminalAltIndexes, sb);
          sb.Append("} else {");
          GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
          sb.Append("}").EndBlock();
        } else {
          GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
        }
      }
      return sb.ToString();
    }

    /// <summary>
    /// Builds one "case N: { ... }" section per terminal symbol (N is the index of the symbol
    /// in grammar.Symbols) that creates and returns the tree node of that terminal.
    /// </summary>
    /// <param name="grammar">Grammar to generate the sections for; its first symbol must be the start symbol.</param>
    /// <param name="terminals">Terminal definitions, matched to the grammar symbols by name.</param>
    /// <returns>The generated case sections as source text.</returns>
    /// <exception cref="NotImplementedException">A terminal declares a SET constraint.</exception>
    /// <exception cref="NotSupportedException">A terminal declares any other constraint.</exception>
    private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      var allSymbols = grammar.Symbols.ToList();
      foreach (var s in allSymbols) {
        if (!grammar.IsTerminal(s)) continue;

        sb.AppendFormat("case {0}: {{", allSymbols.IndexOf(s)).BeginBlock();
        sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
        // Single() throws InvalidOperationException unless exactly one definition matches
        var terminal = terminals.Single(t => t.Ident == s.Name);
        // any declared constraint aborts the generation: neither kind is supported
        foreach (var constr in terminal.Constraints) {
          if (constr.Type == ConstraintNodeType.Set) {
            throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
            // sb.Append("{").BeginBlock();
            // sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
            // sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
            // sb.AppendLine("}");
          } else {
            throw new NotSupportedException("The MTCS solver does not support RANGE constraints.");
          }
        }
        sb.AppendLine("return t;").EndBlock();
        sb.Append("}");
      }
      return sb.ToString();
    }

    /// <summary>
    /// Appends a statement that returns one of the given alternative indexes: the index itself
    /// when there is exactly one, otherwise a uniformly random pick from an inline array.
    /// For an empty sequence the same throw statement as in the count-based overload is emitted.
    /// </summary>
    /// <param name="idxs">Alternative indexes to choose from.</param>
    /// <param name="sb">Builder that receives the statement.</param>
    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
      // enumerate the (possibly deferred) sequence exactly once
      var indexes = idxs.ToList();
      if (indexes.Count == 0) {
        sb.AppendLine("throw new InvalidProgramException();");
      } else if (indexes.Count == 1) {
        sb.AppendFormat("return {0};", indexes[0]).AppendLine();
      } else {
        // every element is followed by ", "; the trailing comma is legal in a C# array initializer
        var idxStr = string.Concat(indexes.Select(idx => idx + ", "));
        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", idxStr, indexes.Count).AppendLine();
      }
    }

    /// <summary>
    /// Appends a statement that returns a random alternative index in [0, nAlts), the constant 0
    /// for a single alternative, or a throw statement when there is no alternative.
    /// </summary>
    /// <param name="nAlts">Number of alternatives of the symbol.</param>
    /// <param name="sb">Builder that receives the statement.</param>
    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
      if (nAlts > 1) {
        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
      } else if (nAlts == 1) {
        sb.AppendLine("return 0; ");
      } else {
        sb.AppendLine("throw new InvalidProgramException();");
      }
    }
  }
}