source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/AlternativesContextSampler.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 4.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences from a problem's grammar by repeatedly expanding the leftmost nonterminal.
  /// When every alternative of that nonterminal fits into the remaining length budget, the alternative
  /// is chosen by a bandit policy keyed on the (hashed) left context of the nonterminal; otherwise one
  /// of the alternatives that still fit is chosen randomly. After each sentence is evaluated, its
  /// normalized quality is fed back as reward to every policy decision taken for that sentence.
  /// </summary>
  public class AlternativesContextSampler {
    /// <summary>Raised with (sentence, quality) when a sentence beats the best quality seen so far in the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised with (sentence, quality) for every sampled and evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int contextLen;
    private readonly Func<Random, int, IPolicy> policyFactory;

    // One bandit policy per (hashed) context key. Initialized here (and reset by InitPolicies at the
    // start of every Run) so that the public CompleteSentence does not hit a null dictionary when it
    // is called before Run.
    private Dictionary<string, IPolicy> ntPolicy = new Dictionary<string, IPolicy>();
    // (context key, selected alternative index) for each policy-driven choice made for the current
    // sentence; cleared in SampleSentence and consumed by DistributeReward.
    private List<Tuple<string, int>> updateChain = new List<Tuple<string, int>>();

    /// <summary>Creates a sampler for <paramref name="problem"/>.</summary>
    /// <param name="problem">Supplies the grammar, the evaluation function, the normalizer and the context hash.</param>
    /// <param name="random">Random source passed to the policy factory and used for restricted (random) choices.</param>
    /// <param name="maxLen">Maximum length of a generated sentence.</param>
    /// <param name="contextLen">Number of symbols left of the expanded nonterminal that form the policy key; must be &gt;= 0.</param>
    /// <param name="policyFactory">Creates a policy given the random source and the number of alternatives (arms).</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="policyFactory"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="contextLen"/> is negative.</exception>
    public AlternativesContextSampler(IProblem problem, Random random, int maxLen, int contextLen, Func<Random, int, IPolicy> policyFactory) {
      if (problem == null) throw new ArgumentNullException("problem");
      if (policyFactory == null) throw new ArgumentNullException("policyFactory");
      // a negative context length would produce empty/invalid subsequences and thus degenerate policy keys
      if (contextLen < 0) throw new ArgumentOutOfRangeException("contextLen", "The context length must not be negative.");
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.contextLen = contextLen;
      this.policyFactory = policyFactory;
    }

    /// <summary>
    /// Resets all policies, then samples and evaluates <paramref name="maxIterations"/> sentences.
    /// Each quality is problem.Evaluate(sentence) divided by problem.GetBestKnownQuality(maxLen).
    /// The quality is distributed as reward to the policy decisions of that sentence, then
    /// <see cref="SolutionEvaluated"/> is raised, and <see cref="FoundNewBestSolution"/> is raised
    /// whenever the quality exceeds the best seen so far in this run.
    /// </summary>
    /// <param name="maxIterations">Number of sentences to sample; non-positive values sample nothing.</param>
    public void Run(int maxIterations) {
      // -infinity (rather than double.MinValue) so that any finite quality of the first sentence counts as an improvement
      double bestQuality = double.NegativeInfinity;
      InitPolicies(problem.Grammar);
      for (int i = 0; i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        // NOTE(review): normalization assumes GetBestKnownQuality(maxLen) is non-zero — TODO confirm
        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }
    }

    // Discards all learned policies and any pending decisions; called at the start of every Run.
    // The grammar argument is currently unused.
    private void InitPolicies(IGrammar grammar) {
      this.ntPolicy = new Dictionary<string, IPolicy>();
      this.updateChain = new List<Tuple<string, int>>();
    }

    // Starts a fresh sentence from the grammar's sentence symbol; the decision log is cleared first so
    // that only the choices made for this sentence are rewarded afterwards.
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      return CompleteSentence(grammar, new Sequence(grammar.SentenceSymbol));
    }

    /// <summary>
    /// Expands the leftmost nonterminal of <paramref name="phrase"/> repeatedly until the phrase is terminal.
    /// If some alternative of the nonterminal cannot fit into the remaining length budget, one of the
    /// alternatives that fit is chosen randomly (no policy is involved). Otherwise the policy for the
    /// nonterminal's context key is asked for an alternative (the policy is created on first use), and the
    /// decision is recorded so that it can be rewarded later.
    /// </summary>
    /// <param name="g">Grammar providing alternatives and minimum phrase lengths.</param>
    /// <param name="phrase">Phrase to complete; it is modified in place via ReplaceAt.</param>
    /// <returns>The same <paramref name="phrase"/> instance, now terminal.</returns>
    /// <exception cref="ArgumentException">The phrase is already longer than maxLen, or cannot be completed within maxLen.</exception>
    public Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen)
        throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen)
        throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");

      while (!phrase.IsTerminal) {
        char nt = phrase.FirstNonTerminal;
        int ntIdx = phrase.FirstNonTerminalIndex;

        // the nonterminal itself is replaced, so its slot is available to the replacement:
        // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        int maxLenOfReplacement = maxLen - (phrase.Length - 1);
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the alternatives are inspected several times below
        var alts = g.GetAlternatives(nt).ToArray();
        var allowedAlts = alts.Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        Sequence selectedAlt;
        if (allowedAlts.Length < alts.Length) {
          // the choice is restricted => pick one of the allowed alternatives randomly (not learned)
          Debug.Assert(allowedAlts.Length > 0);
          selectedAlt = allowedAlts.SelectRandom(random);
        } else {
          // all alternatives are allowed => select using the bandit policy of this context
          var key = GetContextKey(phrase, ntIdx);
          IPolicy policy;
          if (!ntPolicy.TryGetValue(key, out policy)) {
            policy = policyFactory(random, alts.Length);
            ntPolicy.Add(key, policy);
          }
          var selectedAltIdx = policy.SelectAction();
          selectedAlt = alts[selectedAltIdx];
          updateChain.Add(Tuple.Create(key, selectedAltIdx));
        }

        // replace nt with the selected alternative
        phrase.ReplaceAt(ntIdx, 1, selectedAlt);
      }
      return phrase;
    }

    // Policy key for the nonterminal at ntIdx: up to contextLen symbols to its left plus the nonterminal
    // itself, passed through problem.Hash (presumably to merge equivalent contexts — depends on the problem).
    private string GetContextKey(Sequence phrase, int ntIdx) {
      int startIdx = Math.Max(0, ntIdx - contextLen);
      var context = phrase.Subsequence(startIdx, ntIdx - startIdx + 1).ToString();
      return problem.Hash(context);
    }

    // Credits every policy decision recorded for the current sentence with the same reward.
    private void DistributeReward(double reward) {
      foreach (var e in updateChain) {
        var key = e.Item1;
        var action = e.Item2;
        ntPolicy[key].UpdateReward(action, reward);
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }
    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.