using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Tree-search sampler for grammar sentences. Each tree node stands for a phrase; a node is
  /// first completed randomly <c>randomTries</c> times, afterwards its children (one per usable
  /// grammar alternative of the first non-terminal) are chosen through an <c>IGrammarPolicy</c>
  /// that receives the current phrase as context.
  /// </summary>
  public class MctsContextualSampler {
    /// <summary>Search tree node: one phrase plus visit counters and lazily created children.</summary>
    private class TreeNode {
      public int randomTries;     // number of random completions started from this node
      public int policyTries;     // number of times the policy selected a child of this node
      public TreeNode[] children; // null until the node is expanded
      // phrase represents the phrase of the state and alt represents how the phrase has been reached from the parent state
      public readonly ReadonlySequence phrase;
      public readonly ReadonlySequence alt;

      public TreeNode(ReadonlySequence phrase, ReadonlySequence alt) {
        this.phrase = phrase;
        this.alt = alt;
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1})", phrase, randomTries + policyTries);
      }
    }

    /// <summary>Raised with (sentence, quality) whenever a sampled sentence beats the best quality seen in the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;

    /// <summary>Raised with (sentence, quality) after every evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly IGrammarPolicy policy;

    // (state, action, newState) triples visited while sampling the current sentence
    private List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>> updateChain;
    private TreeNode rootNode;

    public int treeDepth;
    public int treeSize;

    /// <summary>Creates a sampler.</summary>
    /// <param name="problem">Problem supplying the grammar and the evaluation function.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="random">Random source passed to the grammar and to the policy.</param>
    /// <param name="randomTries">Number of random completions per node before the node is expanded.</param>
    /// <param name="policy">Policy used to select alternatives and to receive rewards.</param>
    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy policy) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.policy = policy;
    }

    /// <summary>
    /// Samples and evaluates sentences until <paramref name="maxIterations"/> is reached or the
    /// policy reports the root phrase as done. The search tree is reset before and after the run.
    /// </summary>
    /// <param name="maxIterations">Upper bound on the number of sampled sentences.</param>
    public void Run(int maxIterations) {
      double bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !policy.Done(rootNode.phrase) && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        // normalize by the best known quality; the assert documents the expected [0, 1] range
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up
      InitPolicies(problem.Grammar);
      GC.Collect();
    }

    /// <summary>
    /// Writes tree statistics to the console, following the child with the most policy tries at
    /// each level, then waits for a line on standard input.
    /// The tree only exists once <see cref="Run"/> has been called; before that rootNode is null.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
      while (n.children != null) {
        Console.WriteLine();
        Console.WriteLine("{0,5}->{1,-50}", n.alt, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.alt))));
        Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
        //n.policy.PrintStats();
        n = n.children.OrderByDescending(c => c.policyTries).First();
      }
      Console.ReadLine();
    }

    /// <summary>Discards the current tree and starts over with a root node for the grammar's sentence symbol.</summary>
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>>();

      rootNode = new TreeNode(new ReadonlySequence(grammar.SentenceSymbol), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

    /// <summary>Samples one sentence starting from the root phrase; the visited path is recorded in updateChain.</summary>
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      var startPhrase = new Sequence(rootNode.phrase);
      return CompleteSentence(grammar, startPhrase);
    }

    /// <summary>
    /// Walks down the tree from the root, replacing the first non-terminal of
    /// <paramref name="phrase"/> at every step, until either a node that still has random tries
    /// left is reached (the rest is completed randomly) or the phrase is terminal.
    /// </summary>
    /// <param name="g">Grammar providing alternatives and random completion.</param>
    /// <param name="phrase">Phrase to complete; it is modified in place.</param>
    /// <returns>The completed sentence.</returns>
    /// <exception cref="ArgumentException">The phrase cannot be completed within maxLen.</exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");

      TreeNode parent = null;
      TreeNode n = rootNode;
      bool done = false;
      var curDepth = 0;
      while (!done) {
        if (parent != null)
          updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));

        if (n.randomTries < randomTries) {
          // node has not used up its random tries yet => complete the remainder randomly
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        // All random tries are used up. Expanding whenever the children are missing (instead of
        // additionally requiring randomTries == this.randomTries) also covers a negative
        // randomTries setting, which otherwise left children null and failed below.
        if (n.children == null) ExpandNode(g, n, phrase);

        n.policyTries++;
        // => select using bandit policy
        ReadonlySequence selectedAlt = policy.SelectAction(random, n.phrase, n.children.Select(c => c.alt));

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;
        done = phrase.IsTerminal;

        // prepare for next iteration
        parent = n;
        n = n.children.Single(ch => ch.alt == selectedAlt); // TODO: perf
      } // while

      // the phrase became terminal: count the visit of the final node and record the last step
      n.policyTries++;
      updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));

      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    /// <summary>
    /// Creates the children of <paramref name="node"/>: one child per alternative of the first
    /// non-terminal of <paramref name="phrase"/> that still fits into maxLen.
    /// </summary>
    private void ExpandNode(IGrammar g, TreeNode node, Sequence phrase) {
      char nt = phrase.FirstNonTerminal;

      // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
      int maxLenOfReplacement = maxLen - (phrase.Length - 1);
      Debug.Assert(maxLenOfReplacement > 0);

      // materialize once so the filter is not re-evaluated for counting and for iterating
      var alts = g.GetAlternatives(nt)
        .Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement)
        .ToArray();

      int ntIdx = phrase.FirstNonTerminalIndex;
      var newChildren = new TreeNode[alts.Length];
      for (int i = 0; i < alts.Length; i++) {
        var newPhrase = new Sequence(phrase);
        newPhrase.ReplaceAt(ntIdx, 1, alts[i]);
        newChildren[i] = new TreeNode(new ReadonlySequence(newPhrase), new ReadonlySequence(alts[i]));
      }

      node.children = newChildren;
      treeSize += newChildren.Length;
    }

    /// <summary>Reports <paramref name="reward"/> to the policy for every recorded step, deepest step first.</summary>
    private void DistributeReward(double reward) {
      // iterate in reverse order (bottom up)
      updateChain.Reverse();

      foreach (var e in updateChain) {
        var state = e.Item1;
        var action = e.Item2;
        var newState = e.Item3;
        policy.UpdateReward(state, action, reward, newState);
        //policy.UpdateReward(action, reward / updateChain.Count);
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so the null check and the call see the same delegate
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      // copy to a local so the null check and the call see the same delegate
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}