using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences from the grammar of a problem by repeatedly replacing the first
  /// non-terminal symbol of a phrase with one of its alternatives.
  /// When every alternative of the non-terminal fits into the remaining length budget the
  /// alternative is chosen by a bandit policy; the policy state is looked up by a context
  /// string taken from the phrase around the non-terminal. Otherwise one of the alternatives
  /// that still fit is chosen uniformly via <c>SelectRandom</c>.
  /// After a sentence has been evaluated its quality is reported as reward to every
  /// policy-made choice that led to it.
  /// </summary>
  public class AlternativesContextSampler {
    /// <summary>Raised with (sentence, quality) when a sentence beats the best quality seen so far in the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;

    /// <summary>Raised with (sentence, quality) after every evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int contextLen;
    private readonly IPolicy policy;

    // Per-context action statistics, one entry per alternative of the non-terminal in that context.
    // NOTE(review): the element type name (IPolicyActionInfo) was reconstructed; it must be the
    // return type of IPolicy.CreateActionInfo() - confirm against HeuristicLab.Algorithms.Bandits.
    private Dictionary<string, IPolicyActionInfo[]> contextActionInfos;

    // (context, selected alternative index) pairs of all policy-made choices for the current sentence.
    private List<Tuple<string, int>> updateChain;

    /// <summary>
    /// Creates a sampler for the given problem.
    /// </summary>
    /// <param name="problem">Problem that supplies the grammar, the evaluation function and the context hash.</param>
    /// <param name="random">Random number generator used for random selection and passed to the policy.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="contextLen">Number of symbols to the left of the non-terminal that are used as context.</param>
    /// <param name="policy">Bandit policy used to select alternatives.</param>
    public AlternativesContextSampler(IProblem problem, Random random, int maxLen, int contextLen, IPolicy policy) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.contextLen = contextLen;
      this.policy = policy;

      // Initialize the bookkeeping here as well so that CompleteSentence() can be called
      // before Run() without hitting a NullReferenceException.
      this.contextActionInfos = new Dictionary<string, IPolicyActionInfo[]>();
      this.updateChain = new List<Tuple<string, int>>();
    }

    /// <summary>
    /// Samples and evaluates <paramref name="maxIterations"/> sentences.
    /// The quality of a sentence is <c>problem.Evaluate(sentence)</c> divided by
    /// <c>problem.BestKnownQuality(maxLen)</c>. The policy state is reset at the start of each call.
    /// </summary>
    /// <param name="maxIterations">Number of sentences to sample.</param>
    public void Run(int maxIterations) {
      double bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }
    }

    // Discards all learned action statistics and the current update chain.
    private void InitPolicies(IGrammar grammar) {
      this.contextActionInfos = new Dictionary<string, IPolicyActionInfo[]>();
      this.updateChain = new List<Tuple<string, int>>();
    }

    // Starts a fresh update chain and derives a sentence from the grammar's sentence symbol.
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      return CompleteSentence(grammar, new Sequence(grammar.SentenceSymbol));
    }

    /// <summary>
    /// Replaces the first non-terminal of <paramref name="phrase"/> until the phrase is terminal.
    /// The phrase is modified through <c>ReplaceAt</c> and the same instance is returned.
    /// Every choice made by the policy is appended to the update chain.
    /// </summary>
    /// <param name="g">Grammar that supplies the alternatives and minimal phrase lengths.</param>
    /// <param name="phrase">Phrase to complete.</param>
    /// <returns>The completed (terminal) phrase.</returns>
    /// <exception cref="ArgumentException">
    /// Thrown when the phrase is longer than the maximum length or cannot be completed within it.
    /// </exception>
    public Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen)
        throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen)
        throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");

      bool done = phrase.IsTerminal; // terminal phrase means we are done
      while (!done) {
        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        var alts = g.GetAlternatives(nt);

        Sequence selectedAlt;
        // if the choice is restricted then one of the allowed alternatives is selected randomly
        if (alts.Any(alt => g.MinPhraseLength(alt) > maxLenOfReplacement)) {
          var allowedAlts = alts.Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
          Debug.Assert(allowedAlts.Any());

          // replace nt with random alternative
          selectedAlt = allowedAlts.SelectRandom(random);
        } else {
          // all alts are allowed => select using bandit policy
          var ntIdx = phrase.FirstNonTerminalIndex;
          var startIdx = Math.Max(0, ntIdx - contextLen);
          var endIdx = Math.Min(startIdx + contextLen, ntIdx);
          // NOTE(review): presumably Subsequence(start, length); with the "+ 1" the context then
          // includes the non-terminal itself - confirm against Sequence.Subsequence.
          var lft = phrase.Subsequence(startIdx, endIdx - startIdx + 1).ToString();
          lft = problem.Hash(lft);

          // single lookup; create the action infos lazily the first time a context is seen
          IPolicyActionInfo[] actionInfos;
          if (!contextActionInfos.TryGetValue(lft, out actionInfos)) {
            actionInfos = alts.Select(_ => policy.CreateActionInfo()).ToArray();
            contextActionInfos.Add(lft, actionInfos);
          }

          var selectedAltIdx = policy.SelectAction(random, actionInfos);
          selectedAlt = alts.ElementAt(selectedAltIdx);
          updateChain.Add(Tuple.Create(lft, selectedAltIdx));
        }

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        done = phrase.IsTerminal; // terminal phrase means we are done
      }
      return phrase;
    }

    // Reports the reward to the action info of every (context, action) pair on the update chain.
    private void DistributeReward(double reward) {
      foreach (var e in updateChain) {
        var lft = e.Item1;
        var action = e.Item2;
        contextActionInfos[lft][action].UpdateReward(reward);
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so the null check and the invocation see the same delegate
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      // copy to a local so the null check and the invocation see the same delegate
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}