using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Resources;
using System.Runtime.InteropServices;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Algorithms.Bandits.BanditPolicies;
using HeuristicLab.Algorithms.Bandits.GrammarPolicies;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  // a search procedure that uses a policy to generate sentences and updates the policy (online RL)
  // 1) Start with phrase = sentence symbol of grammar
  // 2) Repeat
  //    a) generate derived phrases using left-canonical derivation and grammar rules
  //    b) keep only the phrases which are allowed (sentence length limit)
  //    c) if the set of phrases is empty restart with 1)
  //    d) otherwise use policy to select one of the possible derived phrases as active phrase
  //       the policy has the option to fail (for instance if all derived phrases are terminal and should not be visited again),
  //       in this case we restart at 1
  //    ... until phrase is terminal
  // 3) Collect reward and update policy (feedback: state of visited rewards from step 2)
  public class SequentialSearch : SolverBase {
    // only for storing states so that it is not necessary to allocate new state strings
    // whenever we select a follow state using the policy
    private class TreeNode {
      // number of times a sentence was completed randomly starting from this node
      public int randomTries;
      // the (partially derived) phrase represented by this node
      public string phrase;
      // the grammar alternative that was substituted into the parent's phrase to obtain this node
      public Sequence alternative;
      // follow states; null until first created by GenerateFollowStates
      public TreeNode[] children;

      public TreeNode(string phrase, Sequence alternative) {
        this.alternative = alternative;
        this.phrase = phrase;
      }
    }

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly IGrammarPolicy behaviourPolicy;
    private TreeNode rootNode;

    private int tries;
    private int maxSearchDepth;

    private string bestPhrase;
    // phrases of the tree nodes visited while deriving the current sentence (root first);
    // passed to the policy as the feedback trajectory
    private readonly List<string> stateChain;

    /// <summary>
    /// Creates a sequential search over sentences of <paramref name="problem"/>'s grammar.
    /// </summary>
    /// <param name="problem">Problem providing the grammar and the evaluation function.</param>
    /// <param name="maxLen">Maximum allowed sentence length.</param>
    /// <param name="random">Random source shared by random completion and the policy.</param>
    /// <param name="randomTries">Number of random completions performed from a tree node before the
    /// behaviour policy is used to descend further from that node.</param>
    /// <param name="behaviourPolicy">Policy used to select follow states and updated with rewards.</param>
    public SequentialSearch(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy behaviourPolicy) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.behaviourPolicy = behaviourPolicy;
      this.stateChain = new List<string>();
    }

    /// <summary>When set to true, <see cref="Run"/> stops before the next iteration.</summary>
    public bool StopRequested { get; set; }

    /// <summary>
    /// Resets the search and repeatedly samples, evaluates and rewards sentences.
    /// Stops after <paramref name="maxIterations"/> iterations, when <see cref="StopRequested"/> is set,
    /// or when <see cref="Done"/> reports that the policy cannot select a follow state of the root.
    /// </summary>
    public override void Run(int maxIterations) {
      Reset();

      for (int i = 0; !StopRequested && !Done() && i < maxIterations; i++) {
        var phrase = SampleSentence(problem.Grammar);
        // can fail on the last sentence (sampling stops once Done() is true and returns an incomplete phrase)
        if (phrase.IsTerminal) {
          var sentence = phrase.ToString();
          tries++;
          // quality is normalized by the best known quality; NaN (e.g. 0/0) counts as zero reward
          var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
          if (double.IsNaN(quality)) quality = 0.0;
          Debug.Assert(quality >= 0 && quality <= 1.0);

          // NOTE(review): bestQuality is declared outside this file (presumably in SolverBase) and is
          // not updated here; presumably OnSolutionEvaluated updates it — confirm.
          if (quality > bestQuality) {
            bestPhrase = sentence;
          }
          OnSolutionEvaluated(sentence, quality);
          DistributeReward(quality);
        }
      }
    }

    // Restarts derivation from the root until a sentence is completed or the search is done.
    private Sequence SampleSentence(IGrammar grammar) {
      Sequence phrase;
      do {
        stateChain.Clear();
        phrase = new Sequence(rootNode.phrase);
      } while (!Done() && !TryCompleteSentence(grammar, ref phrase));
      return phrase;
    }

    // Descends the tree from the root, extending `phrase` in place. Returns false when the policy
    // refuses to select a follow state (the caller then restarts), true when `phrase` is complete.
    private bool TryCompleteSentence(IGrammar g, ref Sequence phrase) {
      if (phrase.Length > maxLen)
        throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen)
        throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");

      var curDepth = 0;
      var n = rootNode;
      stateChain.Add(n.phrase);

      while (!phrase.IsTerminal) {
        if (n.randomTries < randomTries) {
          // node has not been explored enough yet => finish the sentence randomly
          n.randomTries++;
          maxSearchDepth = Math.Max(maxSearchDepth, curDepth);
          g.CompleteSentenceRandomly(random, phrase, maxLen);
          return true;
        } else {
          // => select using bandit policy
          // failure means we simply restart
          GenerateFollowStates(n); // creates child nodes for node n

          int selectedChildIdx;
          if (!behaviourPolicy.TrySelect(random, n.phrase, n.children.Select(ch => ch.phrase), out selectedChildIdx)) {
            return false;
          }
          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, n.children[selectedChildIdx].alternative);

          // prepare for next iteration
          n = n.children[selectedChildIdx];
          stateChain.Add(n.phrase);
          curDepth++;
        }
      } // while

      maxSearchDepth = Math.Max(maxSearchDepth, curDepth);
      return true;
    }

    // Creates the children of `n` on first visit (one per alternative for the first non-terminal
    // whose minimal phrase length still fits into maxLen) and returns the children's phrases.
    private IEnumerable<string> GenerateFollowStates(TreeNode n) {
      // create children on the first visit
      if (n.children == null) {
        var g = problem.Grammar;
        // tree is only used for easily retrieving the follow-states of a state
        var phrase = new Sequence(n.phrase);
        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the filtered sequence is needed for both the count and the iteration
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        // the children only need their phrase as a string, so substitute directly on the string
        // form of the parent phrase instead of cloning and modifying a Sequence per alternative
        var phraseStr = phrase.ToString();
        int ntIdx = phrase.FirstNonTerminalIndex;

        var children = new TreeNode[alts.Length];
        for (int idx = 0; idx < alts.Length; idx++) {
          var alt = alts[idx];
          var childPhrase = new StringBuilder(phraseStr)
            .Remove(ntIdx, 1)
            .Insert(ntIdx, alt.ToString())
            .ToString();
          children[idx] = new TreeNode(childPhrase, alt);
        }
        n.children = children;
      }
      return n.children.Select(ch => ch.phrase);
    }

    private void DistributeReward(double reward) {
      behaviourPolicy.UpdateReward(stateChain, reward);
    }

    private void Reset() {
      StopRequested = false;
      behaviourPolicy.Reset();
      maxSearchDepth = 0;
      bestQuality = 0.0;
      tries = 0;
      rootNode = new TreeNode(problem.Grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    }

    /// <summary>
    /// Returns true when the behaviour policy fails to select any follow state of the root node.
    /// Requires that <see cref="Run"/> has already created the root node. The solver's Random is passed
    /// to the policy, which may advance it.
    /// </summary>
    public bool Done() {
      int selectedStateIdx;
      return !behaviourPolicy.TrySelect(random, rootNode.phrase, GenerateFollowStates(rootNode), out selectedStateIdx);
    }

    #region introspection

    /// <summary>
    /// Prints summary statistics and, level by level, the children (last 3 characters of their phrases),
    /// their policy values (x10) and tries along the path the behaviour policy currently selects.
    /// NOTE(review): selection uses the solver's Random, so calling this may affect subsequent sampling.
    /// </summary>
    public void PrintStats() {
      Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);

      // use behaviour strategy to generate the currently prefered sentence
      var policy = behaviourPolicy;

      var n = rootNode;

      while (n != null) {
        var phrase = n.phrase;
        Console.ForegroundColor = ConsoleColor.White;
        Console.WriteLine("{0,-30}", phrase);
        var children = n.children;
        if (children == null || children.Length == 0) break;

        // query each child's value once and reuse it for normalization and output
        var values = children.Select(ch => policy.GetValue(ch.phrase)).ToArray();
        var maxValue = values.Max();
        if (maxValue == 0) maxValue = 1.0; // avoid division by zero when normalizing colors
        if (double.IsPositiveInfinity(maxValue)) maxValue = double.MaxValue;

        // write phrases
        for (int i = 0; i < children.Length; i++) {
          var chPhrase = children[i].phrase;
          SetColorForValue(values[i] / maxValue);
          Console.Write(" {0,-4}", chPhrase.Substring(Math.Max(0, chPhrase.Length - 3), Math.Min(3, chPhrase.Length)));
        }
        Console.WriteLine();

        // write values
        for (int i = 0; i < children.Length; i++) {
          SetColorForValue(values[i] / maxValue);
          Console.Write(" {0:F2}", values[i] * 10.0);
        }
        Console.WriteLine();

        // write tries
        for (int i = 0; i < children.Length; i++) {
          SetColorForValue(values[i] / maxValue);
          Console.Write(" {0,4}", policy.GetTries(children[i].phrase));
        }
        Console.WriteLine();

        int selectedChildIdx;
        if (!policy.TrySelect(random, phrase, children.Select(ch => ch.phrase), out selectedChildIdx)) {
          break;
        }
        n = n.children[selectedChildIdx];
      }

      Console.ForegroundColor = ConsoleColor.White;
      Console.WriteLine("-------------------");
    }

    private void SetColorForValue(double v) {
      Console.ForegroundColor = ConsoleEx.ColorForValue(v);
    }

    #endregion
  }
}