using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Monte-Carlo tree search sampler for grammatical optimization.
  /// Tree nodes are keyed by the problem's canonical representation of the phrase prefix up to and
  /// including its first non-terminal, so equivalent prefixes share one node. The structure is
  /// therefore a DAG: a node may have several parents.
  /// </summary>
  public class MctsContextualSampler {
    private class TreeNode {
      // canonical phrase prefix identifying this node
      public string ident;
      // the alternative whose application led to this node ("$" for the root)
      public ReadonlySequence alt;
      // number of random completions started from this node (bounded by the sampler's randomTries)
      public int randomTries;
      // number of rewards back-propagated through this node
      public int tries;
      // every node that lists this node among its children
      public List<TreeNode> parents;
      // one child per admissible alternative; null until the node has been expanded
      public TreeNode[] children;
      // set on a terminal leaf once sampled, and on an inner node once all of its children are done
      public bool done = false;

      public TreeNode(string id, ReadonlySequence alt) {
        this.ident = id;
        this.alt = alt;
        this.parents = new List<TreeNode>();
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
      }
    }

    private const double GreedyEpsilon = 0.2; // probability of a random (exploratory) choice in SelectEpsGreedy

    private Dictionary<string, TreeNode> treeNodes;

    /// <summary>
    /// Returns the node for the canonical form of <paramref name="id"/>, creating it on first use.
    /// A newly created node starts with the visit count already recorded for its canonical id.
    /// </summary>
    private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
      TreeNode n;
      var canonicalId = problem.CanonicalRepresentation(id);
      if (!treeNodes.TryGetValue(canonicalId, out n)) {
        n = new TreeNode(canonicalId, alt);
        tries.TryGetValue(canonicalId, out n.tries);
        treeNodes[canonicalId] = n;
      }
      return n;
    }

    /// <summary>Raised with (sentence, normalized quality) whenever a sentence beats the best quality of the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised with (sentence, normalized quality) after every successfully sampled and evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    // (node, parent) pairs visited while sampling the current sentence
    private List<Tuple<TreeNode, TreeNode>> updateChain;
    private TreeNode rootNode;

    public int treeDepth;
    public int treeSize;
    private double bestQuality;

    /// <param name="problem">problem providing the grammar, evaluation and canonical representation</param>
    /// <param name="maxLen">maximum sentence length</param>
    /// <param name="random">random number source used for all random decisions</param>
    /// <param name="randomTries">number of random completions performed from a node before it is expanded</param>
    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.v = new Dictionary<string, double>(1000000);
      this.tries = new Dictionary<string, int>(1000000);
      treeNodes = new Dictionary<string, TreeNode>();
    }

    /// <summary>
    /// Samples and evaluates sentences until <paramref name="maxIterations"/> iterations have run or the
    /// whole search space reachable from the root has been exhausted.
    /// </summary>
    public void Run(int maxIterations) {
      bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        bool success;
        var sentence = SampleSentence(problem.Grammar, out success).ToString();
        if (success) {
          // quality is normalized by the best known quality; the assert expects the result in [0, 1]
          var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
          Debug.Assert(quality >= 0 && quality <= 1.0);
          DistributeReward(quality);

          RaiseSolutionEvaluated(sentence, quality);

          if (quality > bestQuality) {
            bestQuality = quality;
            RaiseFoundNewBestSolution(sentence, quality);
          }
        }
      }

      // clean up: reset the tree and value tables so their memory can be reclaimed
      InitPolicies(problem.Grammar);
      GC.Collect();
    }

    /// <summary>
    /// Prints tree statistics to the console, then walks from the root along the most-tried open
    /// child, printing each child's alternative, value (scaled by 10) and try count (or "X" when done).
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // Run has not initialized the tree yet
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
      while (n != null && n.children != null) {
        Console.WriteLine("{0,-30}", n.ident);
        // colors are scaled relative to the best clamped value in this row; an all-zero (or empty) row uses scale 1
        double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).DefaultIfEmpty(0.0).Max();
        if (maxVForRow == 0) maxVForRow = 1.0;

        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
          Console.Write("{0,5}", ch.alt);
        }
        Console.WriteLine();
        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
          Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10);
        }
        Console.WriteLine();
        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
        }
        Console.ForegroundColor = ConsoleColor.White;
        Console.WriteLine();

        // descend into the most-tried child that is not done; stop (null) when every child is done
        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).FirstOrDefault();
      }
    }

    /// <summary>
    /// Resets the search: value tables, tree nodes, root and tree statistics.
    /// </summary>
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();

      this.v.Clear();
      this.tries.Clear();
      // the tree must be reset together with the value tables; otherwise GetTreeNode would hand back the
      // cached root with stale children, done flags and visit counts
      treeNodes.Clear();

      rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

    private Sequence SampleSentence(IGrammar grammar, out bool success) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase, out success);
    }

    /// <summary>
    /// Descends the tree from the root, expanding <paramref name="phrase"/> one non-terminal at a time.
    /// A node visited fewer than randomTries times has the rest of the phrase completed randomly instead.
    /// Visited (node, parent) pairs are recorded in updateChain.
    /// </summary>
    /// <param name="success">false when the descent hit a node whose alternatives were all exhausted; the
    /// returned phrase is then not terminal.</param>
    /// <exception cref="ArgumentException">phrase already exceeds maxLen, or cannot be completed within maxLen</exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
      if (phrase.Length > maxLen) throw new ArgumentException();
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
      TreeNode parent = null;
      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        updateChain.Add(Tuple.Create(n, parent));

        if (n.randomTries < randomTries) {
          // the first randomTries visits of a node complete the phrase randomly
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          success = true;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        } else {
          char nt = phrase.FirstNonTerminal;

          // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
          int maxLenOfReplacement = maxLen - (phrase.Length - 1);
          Debug.Assert(maxLenOfReplacement > 0);

          // materialized once: the list is used for sizing, expansion and selection below
          var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

          if (n.randomTries == randomTries && n.children == null) {
            // create a new node for each alternative
            // NOTE(review): nodes are shared between phrases with the same canonical prefix, so a later
            // phrase might admit a different alternative list than the one used here; confirm that
            // alts and n.children always line up by index.
            n.children = new TreeNode[alts.Length];
            var i = 0;
            foreach (var alt in alts) {
              var newPhrase = new Sequence(phrase);
              newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
              // the node id is the prefix up to and including the first remaining non-terminal
              if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
              var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
              treeNode.parents.Add(n);
              n.children[i++] = treeNode;
            }
            treeSize += n.children.Length;
            UpdateDone(n);
            // shared children may already be done, so all variations below this node may be finished
            if (n.done) {
              success = false;
              return phrase;
            }
          }

          // alternatives: SelectRandom(random, n.children) or SelectActionUCB1(random, n.children)
          int selectedAltIdx = SelectEpsGreedy(random, n.children);
          Sequence selectedAlt = alts[selectedAltIdx];

          // replace nt with alt
          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

          curDepth++;

          // prepare for next iteration
          parent = n;
          n = n.children[selectedAltIdx];
        }
      }

      updateChain.Add(Tuple.Create(n, parent));

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      success = true;
      return phrase;
    }

    /// <summary>
    /// Propagates the done flag and the reward from the last node of the sampled path upwards.
    /// </summary>
    private void DistributeReward(double reward) {
      var leaf = updateChain.Last().Item1;
      UpdateDone(leaf);
      BackPropReward(leaf, reward);
    }

    /// <summary>
    /// Increments the visit count and updates the value of <paramref name="n"/> and, recursively, of all
    /// its ancestors. An ancestor reachable over several parent paths is updated once per path.
    /// </summary>
    private void BackPropReward(TreeNode n, double reward) {
      n.tries++;
      UpdateV(n, reward);
      foreach (var p in n.parents) BackPropReward(p, reward);
    }

    /// <summary>
    /// Marks <paramref name="n"/> done when all of its children are done and, if it is done, re-checks every parent.
    /// </summary>
    private void UpdateDone(TreeNode n) {
      if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
      if (n.done) foreach (var p in n.parents) UpdateDone(p);
    }

    // value (mean reward) and number of rewards per canonical node id
    private Dictionary<string, double> v;
    private Dictionary<string, int> tries;

    /// <summary>
    /// Folds <paramref name="reward"/> into the running mean reward of the node's canonical id.
    /// </summary>
    private void UpdateV(TreeNode n, double reward) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      double stateV;
      if (!v.TryGetValue(canonicalStr, out stateV)) {
        v.Add(canonicalStr, reward);
        tries.Add(canonicalStr, 1);
      } else {
        // incremental mean: with k previous rewards, mean_{k+1} = mean_k + (reward - mean_k) / (k + 1)
        int count = tries[canonicalStr] + 1;
        tries[canonicalStr] = count;
        v[canonicalStr] = stateV + (reward - stateV) / count;
      }
    }

    /// <summary>
    /// Returns the mean reward of the node's canonical id, or +infinity when it has never received a reward.
    /// </summary>
    private double V(TreeNode n) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      double stateV;

      if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
      if (!v.TryGetValue(canonicalStr, out stateV)) {
        return 0.0;
      } else {
        return stateV;
      }
    }

    // index of a uniformly chosen child that is not done
    private int SelectRandom(Random random, TreeNode[] children) {
      return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
    }

    /// <summary>
    /// With probability GreedyEpsilon returns a random open child; otherwise returns one of the open
    /// children with the highest value, breaking ties randomly.
    /// </summary>
    private int SelectEpsGreedy(Random random, TreeNode[] children) {
      if (random.NextDouble() < GreedyEpsilon) {
        return SelectRandom(random, children);
      } else {
        var bestQ = double.NegativeInfinity;
        var bestChildIdx = new List<int>();
        for (int i = 0; i < children.Length; i++) {
          if (children[i].done) continue;
          var q = V(children[i]);
          if (q > bestQ) {
            bestQ = q;
            bestChildIdx.Clear();
            bestChildIdx.Add(i);
          } else if (q == bestQ) {
            bestChildIdx.Add(i);
          }
        }
        Debug.Assert(bestChildIdx.Any());
        return bestChildIdx.SelectRandom(random);
      }
    }

    /// <summary>
    /// UCB1 selection over the open children; an untried open child is returned immediately.
    /// </summary>
    private int SelectActionUCB1(Random random, TreeNode[] children) {
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      int totalTries = children.Sum(ch => ch.tries);

      for (int a = 0; a < children.Length; a++) {
        var ch = children[a];
        if (ch.done) continue;
        if (ch.tries == 0) return a;
        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
        if (q > bestQ) {
          bestQ = q;
          bestAction = a;
        }
      }
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}