using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  // SARSA (fig. 6.9 in Sutton & Barto)
  /// <summary>
  /// Samples sentences of a grammar by descending a search tree whose nodes correspond to
  /// the alternatives chosen for the first non-terminal symbol of the current phrase.
  /// Children are selected eps-greedily by their q value; q values are updated after each
  /// evaluated sentence with a temporal-difference rule using a 1/tries step size.
  /// </summary>
  public class TemporalDifferenceTreeSearchSampler {
    /// <summary>
    /// A node of the search tree. <c>children</c> stays null until the node has used up
    /// its random completions and is expanded for the first time.
    /// </summary>
    private class TreeNode {
      public string ident;
      public int randomTries; // number of random completions already started from this node
      public double q;        // current value estimate of this node
      public int tries;       // number of q updates applied to this node
      public TreeNode[] children;
      public bool done = false; // true when this node never needs to be visited again

      public TreeNode(string id) {
        this.ident = id;
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
      }
    }

    // probability of selecting a random (not yet done) child instead of the best one
    private const double Epsilon = 0.1;

    /// <summary>Raised with (sentence, quality) when a sentence beats the best quality seen so far.</summary>
    public event Action<string, double> FoundNewBestSolution;

    /// <summary>Raised with (sentence, quality) after every evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;

    // nodes visited while sampling the current sentence, ordered from the root downwards
    private List<TreeNode> updateChain;
    private TreeNode rootNode;

    public int treeDepth;
    public int treeSize;
    private double bestQuality;

    /// <summary>Creates a sampler for the given problem.</summary>
    /// <param name="problem">Problem that supplies the grammar and evaluates sentences.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="random">Random number generator used for all random choices.</param>
    /// <param name="randomTries">Number of random completions started from a node before it is expanded.</param>
    public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
    }

    /// <summary>
    /// Resets the tree and then samples, evaluates and rewards sentences until the root node
    /// is done or <paramref name="maxIterations"/> sentences have been evaluated.
    /// The tree is reset again before returning; the best quality is not reset.
    /// </summary>
    /// <param name="maxIterations">Upper bound on the number of evaluated sentences.</param>
    public void Run(int maxIterations) {
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        // quality is normalized by the best known quality and expected to lie in [0, 1]
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drop the tree and collect it right away
      InitPolicies(problem.Grammar);
      GC.Collect();
    }

    /// <summary>
    /// Writes tree statistics to the console: a summary line for the root and then, level by
    /// level, identifier, q value (times 10) and tries (or "X" when done) of all children,
    /// following the not yet done child with the highest q value.
    /// Does nothing when no tree has been created yet.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // no tree exists before the first call of Run

      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);

      // stop at unexpanded nodes, at nodes without alternatives and when every child is done
      while (n != null && n.children != null && n.children.Length > 0) {
        Console.WriteLine("{0,-30}", n.ident);

        double maxVForRow = n.children.Max(ch => ch.q);
        if (maxVForRow == 0) maxVForRow = 1.0; // avoid division by zero when coloring

        foreach (var ch in n.children) {
          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
          Console.Write("{0,5}", ch.ident);
        }
        Console.WriteLine();

        foreach (var ch in n.children) {
          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
          Console.Write("{0,5:F2}", ch.q * 10);
        }
        Console.WriteLine();

        foreach (var ch in n.children) {
          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
        }
        Console.ForegroundColor = ConsoleColor.White;
        Console.WriteLine();

        // null (and therefore loop exit) when all children are done
        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).FirstOrDefault();
      }
    }

    /// <summary>Replaces the tree by a fresh root node and resets the update chain and the tree counters.</summary>
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<TreeNode>();

      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
      treeDepth = 0;
      treeSize = 0;
    }

    /// <summary>Clears the update chain and samples one sentence starting from the sentence symbol.</summary>
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase);
    }

    /// <summary>
    /// Completes <paramref name="phrase"/> by walking down the tree from the root and records
    /// every visited node in the update chain. A node that still has random tries left
    /// completes the phrase randomly; otherwise the node is expanded (once) and one child is
    /// selected eps-greedily.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The phrase, or its minimal completion, is longer than the maximum sentence length.
    /// </exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than the maximum sentence length.");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.");

      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        updateChain.Add(n);

        if (n.randomTries < randomTries) {
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        char nt = phrase.FirstNonTerminal;

        // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        int maxLenOfReplacement = maxLen - (phrase.Length - 1);
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once; the array is indexed in parallel to n.children
        var alts = g.GetAlternatives(nt)
                    .Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement)
                    .ToArray();

        if (n.children == null) {
          // create a new node for each alternative
          n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray();
          treeSize += n.children.Length;
        }
        Debug.Assert(n.children.Length == alts.Length);

        int selectedAltIdx = SelectEpsGreedy(random, n.children);
        Sequence selectedAlt = alts[selectedAltIdx];

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        // prepare for next iteration
        n = n.children[selectedAltIdx];
      }

      updateChain.Add(n);

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    /// <summary>
    /// eps-greedy selection: with probability <see cref="Epsilon"/> a random child that is
    /// not done, otherwise a child with maximal q among those not done (ties broken randomly).
    /// </summary>
    /// <returns>Index of the selected child within <paramref name="children"/>.</returns>
    private int SelectEpsGreedy(Random random, TreeNode[] children) {
      if (random.NextDouble() < Epsilon) {
        return Enumerable.Range(0, children.Length)
                         .Where(i => !children[i].done)
                         .SelectRandom(random);
      }

      var bestQ = double.NegativeInfinity;
      var bestChildIdx = new List<int>();
      for (int i = 0; i < children.Length; i++) {
        if (children[i].done) continue;

        var q = children[i].q;
        if (q > bestQ) {
          bestQ = q;
          bestChildIdx.Clear();
          bestChildIdx.Add(i);
        } else if (q == bestQ) {
          bestChildIdx.Add(i);
        }
      }
      Debug.Assert(bestChildIdx.Any());
      return bestChildIdx.SelectRandom(random);
    }

    /// <summary>
    /// Propagates the done flag and the reward along the update chain. Every node but the
    /// last moves its q towards the q of its successor in the chain; the last node moves its
    /// q towards <paramref name="reward"/>. The step size is 1/tries in both cases.
    /// </summary>
    private void DistributeReward(double reward) {
      // walk from the deepest node to the root so that a node which just became done
      // can mark its parent as done within the same pass
      for (int i = updateChain.Count - 1; i >= 0; i--) {
        var node = updateChain[i];
        if (node.children != null && node.children.All(c => c.done)) {
          node.done = true;
        }
      }

      const double gamma = 1;
      double alpha;

      // intermediate steps yield no reward; each parent uses the q of its successor
      // as it was before this episode's update of the successor
      for (int i = 0; i + 1 < updateChain.Count; i++) {
        var parent = updateChain[i];
        var child = updateChain[i + 1];
        parent.tries++;
        alpha = 1.0 / parent.tries;
        parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q);
      }

      // reward is received only for the last action; there is no successor value
      var n = updateChain[updateChain.Count - 1];
      n.tries++;
      alpha = 1.0 / n.tries;
      n.q = n.q + alpha * (reward - n.q);
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}