using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences of the problem's grammar by repeatedly replacing the first non-terminal
  /// of the phrase with one of its alternatives, never exceeding <c>maxLen</c> symbols.
  /// When every alternative of a non-terminal still fits into the remaining length budget, the
  /// alternative is chosen by the bandit policy and that choice is later rewarded with the
  /// sentence's quality; otherwise a fitting alternative is chosen uniformly at random and the
  /// choice is not rewarded.
  /// </summary>
  public class AlternativesSampler {
    /// <summary>
    /// Raised with (sentence, quality) when a sample's quality exceeds the best quality seen so far
    /// in the current <see cref="Run"/>.
    /// </summary>
    public event Action<string, double> FoundNewBestSolution;

    /// <summary>Raised with (sentence, quality) after every evaluated sample.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly Random random;
    private readonly IProblem problem;
    private readonly IBanditPolicy policy;

    // One action info per alternative of each non-terminal, in the order returned by
    // IGrammar.GetAlternatives; (re)created by InitPolicies at the start of every Run.
    private Dictionary<char, IBanditPolicyActionInfo[]> ntActionInfos;

    // (non-terminal, alternative index) pairs chosen by the policy while building the current
    // sentence; every entry receives the sentence's reward in DistributeReward.
    private List<Tuple<char, int>> updateChain;

    /// <param name="problem">Problem providing the grammar and the evaluation function.</param>
    /// <param name="policy">Bandit policy used to choose among alternatives.</param>
    /// <param name="maxLen">Maximum number of symbols in a sampled sentence.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="policy"/> is null.</exception>
    public AlternativesSampler(IProblem problem, IBanditPolicy policy, int maxLen) {
      if (problem == null) throw new ArgumentNullException("problem");
      if (policy == null) throw new ArgumentNullException("policy");
      this.problem = problem;
      this.maxLen = maxLen;
      this.random = new Random();
      this.policy = policy;
    }

    /// <summary>
    /// Resets the bandit statistics and then samples, evaluates and rewards
    /// <paramref name="maxIterations"/> sentences.
    /// The quality of a sentence is <c>problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen)</c>.
    /// </summary>
    /// <param name="maxIterations">Number of sentences to sample.</param>
    public void Run(int maxIterations) {
      double bestQuality = double.MinValue;
      var grammar = problem.Grammar;
      InitPolicies(grammar);
      // Hoisted out of the loop: assumed to be constant for a fixed maxLen (TODO confirm for all IProblem implementations).
      double bestKnownQuality = problem.BestKnownQuality(maxLen);

      for (int i = 0; i < maxIterations; i++) {
        var sentence = SampleSentence(grammar).ToString();
        var quality = problem.Evaluate(sentence) / bestKnownQuality;
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }
    }

    /// <summary>Creates fresh action infos for every alternative of every non-terminal and an empty update chain.</summary>
    private void InitPolicies(IGrammar grammar) {
      this.ntActionInfos = new Dictionary<char, IBanditPolicyActionInfo[]>();
      this.updateChain = new List<Tuple<char, int>>();
      foreach (var nt in grammar.NonTerminalSymbols) {
        ntActionInfos.Add(nt, grammar.GetAlternatives(nt).Select(_ => policy.CreateActionInfo()).ToArray());
      }
    }

    /// <summary>Samples one complete sentence starting from the grammar's sentence symbol.</summary>
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      return CompleteSentence(grammar, new Sequence(grammar.SentenceSymbol));
    }

    /// <summary>
    /// Expands <paramref name="phrase"/> in place, first non-terminal first, until it is terminal.
    /// Policy-driven choices are appended to the update chain; random (length-restricted)
    /// choices are not.
    /// </summary>
    /// <param name="g">Grammar providing the alternatives and minimal phrase lengths.</param>
    /// <param name="phrase">Phrase to complete; it is modified and returned.</param>
    /// <returns>The same <paramref name="phrase"/> instance, now terminal.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="g"/> or <paramref name="phrase"/> is null.</exception>
    /// <exception cref="ArgumentException">The phrase is too long or cannot be completed within maxLen.</exception>
    /// <exception cref="InvalidOperationException">A policy choice is needed but <see cref="Run"/> has not initialized the policies yet.</exception>
    public Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (g == null) throw new ArgumentNullException("g");
      if (phrase == null) throw new ArgumentNullException("phrase");
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than maxLen.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within maxLen symbols.", "phrase");

      // a terminal phrase means we are done
      while (!phrase.IsTerminal) {
        char nt = phrase.FirstNonTerminal;
        // The non-terminal itself is replaced, so it frees one position: replacing aAb with
        // maxLen 4 means only alternatives with a minPhraseLen <= 2 can be used.
        int maxLenOfReplacement = maxLen - (phrase.Length - 1);
        Debug.Assert(maxLenOfReplacement > 0);

        // Materialize once; both arrays are used for counting and indexing below.
        var alts = g.GetAlternatives(nt).ToArray();
        var allowedAlts = alts.Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        Sequence selectedAlt;
        if (allowedAlts.Length < alts.Length) {
          // The choice is restricted: pick one of the allowed alternatives at random (not rewarded).
          Debug.Assert(allowedAlts.Length > 0);
          selectedAlt = allowedAlts.SelectRandom(random);
        } else {
          // All alternatives are allowed => select using the bandit policy and remember the choice.
          if (ntActionInfos == null || updateChain == null)
            throw new InvalidOperationException("The bandit policies are not initialized. Call Run before CompleteSentence.");
          var selectedAltIdx = policy.SelectAction(random, ntActionInfos[nt]);
          selectedAlt = alts[selectedAltIdx];
          updateChain.Add(Tuple.Create(nt, selectedAltIdx));
        }

        // replace nt with the selected alternative
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
      }
      return phrase;
    }

    /// <summary>Gives <paramref name="reward"/> to every policy choice made for the current sentence.</summary>
    private void DistributeReward(double reward) {
      foreach (var e in updateChain) {
        var nt = e.Item1;
        var action = e.Item2;
        ntActionInfos[nt][action].UpdateReward(reward);
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so a concurrent unsubscribe cannot null the delegate between check and call
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}