Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/Solvers/SequentialSearch.cs @ 13398

Last change on this file since 13398 was 12893, checked in by gkronber, 9 years ago

#2283: experiments on grammatical optimization algorithms (maxreward instead of avg reward, ...)

File size: 9.5 KB
RevLine 
[11770]1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Resources;
6using System.Runtime.InteropServices;
7using System.Text;
[12893]8using System.Windows.Markup;
[11770]9using HeuristicLab.Algorithms.Bandits;
10using HeuristicLab.Algorithms.Bandits.BanditPolicies;
11using HeuristicLab.Algorithms.Bandits.GrammarPolicies;
12using HeuristicLab.Common;
13using HeuristicLab.Problems.GrammaticalOptimization;
14
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// A search procedure that uses a policy to generate sentences and updates the policy online (RL).
  /// 1) Start with phrase = sentence symbol of the grammar.
  /// 2) Repeat until the phrase is terminal:
  ///    a) if the current search-tree node has been used for fewer than randomTries random completions,
  ///       complete the sentence randomly and stop;
  ///    b) otherwise generate the derived phrases (left-canonical derivation), keeping only those
  ///       whose minimal length still fits the sentence length limit (maxLen);
  ///    c) use the behaviour policy to select one of the derived phrases as the active phrase.
  ///       The policy has the option to fail (for instance if all derived phrases are terminal and
  ///       should not be visited again); in this case restart at 1).
  /// 3) Evaluate the sentence, then update the policy with the reward and the chain of states
  ///    visited in step 2.
  /// </summary>
  public class SequentialSearch : SolverBase {
    // Search-tree node. Only used for storing states so that it is not necessary to allocate new
    // state strings whenever a follow state is selected using the policy.
    private class TreeNode {
      // number of random completions that have been started from this node
      public int randomTries;
      // the state string of this node
      public string phrase;
      // the grammar alternative substituted into the parent's first non-terminal to obtain this phrase
      public Sequence alternative;
      // follow states; null until the node is expanded by GenerateFollowStates
      public TreeNode[] children;

      public TreeNode(string phrase, Sequence alternative) {
        this.alternative = alternative;
        this.phrase = phrase;
      }
    }


    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly System.Random random;
    private readonly int randomTries;
    private readonly IGrammarPolicy behaviourPolicy;
    private TreeNode rootNode;

    private int tries;
    private int maxSearchDepth;

    private string bestPhrase;
    // states visited while sampling the current sentence; all of them receive the sentence's reward
    private readonly List<string> stateChain;

    /// <summary>Creates a new sequential search.</summary>
    /// <param name="problem">The problem providing the grammar and the evaluation function.</param>
    /// <param name="maxLen">Maximum sentence length.</param>
    /// <param name="random">Random number generator shared by random completion and the policy.</param>
    /// <param name="randomTries">Number of random completions started from a tree node before the policy is used for it.</param>
    /// <param name="behaviourPolicy">Policy used to select follow states and to learn from rewards.</param>
    public SequentialSearch(IProblem problem, int maxLen, System.Random random, int randomTries, IGrammarPolicy behaviourPolicy) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.behaviourPolicy = behaviourPolicy;
      this.stateChain = new List<string>();
    }

    /// <summary>When set to true, <see cref="Run"/> stops before the next iteration. Cleared by <see cref="Run"/> on start.</summary>
    public bool StopRequested {
      get;
      set;
    }

    /// <summary>
    /// Resets the search and then samples and evaluates up to <paramref name="maxIterations"/> sentences.
    /// Stops early when <see cref="StopRequested"/> is set or <see cref="Done"/> returns true.
    /// </summary>
    /// <param name="maxIterations">Maximum number of sampling iterations.</param>
    public override void Run(int maxIterations) {
      Reset();

      for (int i = 0; !StopRequested && !Done() && i < maxIterations; i++) {
        var phrase = SampleSentence(problem.Grammar);
        // sampling can end with a non-terminal phrase when Done() becomes true while sampling
        if (phrase.IsTerminal) {
          var sentence = phrase.ToString();
          tries++;
          // quality is normalized by the best known quality for this length limit
          var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
          if (double.IsNaN(quality)) quality = 0.0;
          Debug.Assert(quality >= 0 && quality <= 1.0);

          if (quality > bestQuality) {
            bestPhrase = sentence;
          }

          OnSolutionEvaluated(sentence, quality);
          DistributeReward(quality);
        }
      }
    }


    // Repeatedly tries to complete a sentence starting from the root phrase until either a
    // sentence is completed or Done() reports that the policy cannot select from the root anymore.
    // In the latter case the returned phrase may still contain non-terminals.
    private Sequence SampleSentence(IGrammar grammar) {
      Sequence phrase;
      do {
        stateChain.Clear();
        phrase = new Sequence(rootNode.phrase);
      } while (!Done() && !TryCompleteSentence(grammar, ref phrase));
      return phrase;
    }

    // Walks down the search tree from the root and extends 'phrase' until it is terminal.
    // Each visited state is appended to stateChain.
    // Returns true when the phrase was completed (either randomly or via the policy) and false when
    // the policy failed to select a follow state (the caller then restarts).
    // Throws ArgumentException if the phrase is longer than maxLen or cannot be completed within maxLen.
    private bool TryCompleteSentence(IGrammar g, ref Sequence phrase) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");
      var curDepth = 0;
      var n = rootNode;
      stateChain.Add(n.phrase);

      while (!phrase.IsTerminal) {
        if (n.randomTries < randomTries) {
          // this node has not been explored often enough yet => complete the sentence randomly
          n.randomTries++;
          maxSearchDepth = Math.Max(maxSearchDepth, curDepth);
          g.CompleteSentenceRandomly(random, phrase, maxLen);
          return true;
        } else {
          // => select using bandit policy
          // failure means we simply restart
          GenerateFollowStates(n); // creates child nodes for node n

          int selectedChildIdx;
          if (!behaviourPolicy.TrySelect(random, n.phrase, n.children.Select(ch => ch.phrase), out selectedChildIdx)) {
            return false;
          }

          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, n.children[selectedChildIdx].alternative);

          // prepare for next iteration
          n = n.children[selectedChildIdx];
          stateChain.Add(n.phrase);
          curDepth++;
        }
      } // while

      maxSearchDepth = Math.Max(maxSearchDepth, curDepth);
      return true;
    }


    // Returns the state strings of the follow states of node n.
    // Children are created on the first call and cached in n.children.
    // A child is created for every alternative of the first non-terminal whose minimal phrase
    // length still fits into maxLen.
    private IEnumerable<string> GenerateFollowStates(TreeNode n) {
      if (n.children == null) {
        var g = problem.Grammar;
        // tree is only used for easily retrieving the follow-states of a state
        var phrase = new Sequence(n.phrase);
        char nt = phrase.FirstNonTerminal;

        // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        int maxLenOfReplacement = maxLen - (phrase.Length - 1);
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the filter calls MinPhraseLength for each alternative
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        // loop-invariant: the parent phrase as string and the position of the replaced non-terminal
        var phraseStr = phrase.ToString();
        int ntIdx = phrase.FirstNonTerminalIndex;

        var children = new TreeNode[alts.Length];
        for (int idx = 0; idx < alts.Length; idx++) {
          var alt = alts[idx];
          // build the child state string directly instead of cloning and modifying a Sequence
          children[idx] = new TreeNode(phraseStr.Remove(ntIdx, 1).Insert(ntIdx, alt.ToString()), alt);
        }
        n.children = children;
      }
      return n.children.Select(ch => ch.phrase);
    }


    // Passes the reward of the last evaluated sentence, together with all states visited
    // while sampling it, to the behaviour policy.
    private void DistributeReward(double reward) {
      behaviourPolicy.UpdateReward(stateChain, reward);
    }


    // Resets the policy, statistics, best solution and the search tree.
    private void Reset() {
      StopRequested = false;
      behaviourPolicy.Reset();
      maxSearchDepth = 0;
      bestQuality = 0.0;
      bestPhrase = null; // otherwise a previous run's best phrase would survive a restart
      tries = 0;
      rootNode = new TreeNode(problem.Grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    }

    /// <summary>
    /// Returns true when the behaviour policy fails to select any follow state of the root node.
    /// Expands the root node if necessary. The shared random generator is passed to the policy's
    /// TrySelect, so calling this method may consume random numbers.
    /// </summary>
    public bool Done() {
      int selectedStateIdx;
      return !behaviourPolicy.TrySelect(random, rootNode.phrase, GenerateFollowStates(rootNode), out selectedStateIdx);
    }

    #region introspection
    /// <summary>
    /// Prints search statistics and the currently preferred path through the search tree to the console.
    /// At every level it prints the state, the last (up to) three characters of each child state,
    /// the child values (scaled by 10) and the child tries. It then descends into the child with the
    /// highest value, breaking ties by the number of tries. At most 11 levels are printed.
    /// </summary>
    public void PrintStats() {
      Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);

      var policy = behaviourPolicy;

      var n = rootNode;
      int lvl = 0;
      while (n != null) {
        Console.ForegroundColor = ConsoleColor.White;

        // limit the output for deep trees (levels 0..10); break so that the footer is still printed
        if (lvl++ > 10) break;

        Console.WriteLine("{0,-30}", n.phrase);
        var children = n.children;
        if (children == null || children.Length == 0) break;

        // query the policy once per child and reuse the results for printing and selection
        var values = children.Select(ch => policy.GetValue(ch.phrase)).ToArray();
        var childTries = children.Select(ch => policy.GetTries(ch.phrase)).ToArray();

        // write phrases
        foreach (var ch in children) {
          Console.Write(" {0,-4}", ch.phrase.Substring(Math.Max(0, ch.phrase.Length - 3), Math.Min(3, ch.phrase.Length)));
        }
        Console.WriteLine();

        // write values
        foreach (var value in values) {
          if (!double.IsInfinity(value))
            Console.Write(" {0:F2}", value * 10.0);
          else
            Console.Write(" Inf ");
        }
        Console.WriteLine();

        // write tries
        foreach (var t in childTries) {
          Console.Write(" {0,4}", t);
        }
        Console.WriteLine();

        // follow the child with the highest value; ties are broken by the number of tries
        int selectedChildIdx = Enumerable.Range(0, children.Length)
          .OrderByDescending(i => values[i])
          .ThenByDescending(i => childTries[i])
          .First();
        n = children[selectedChildIdx];
      }

      Console.ForegroundColor = ConsoleColor.White;
      Console.WriteLine("-------------------");
    }

    // Sets the console foreground color for a value (delegates to ConsoleEx.ColorForValue).
    // Currently not used by PrintStats.
    private void SetColorForValue(double v) {
      Console.ForegroundColor = ConsoleEx.ColorForValue(v);
    }
    #endregion

  }
}
Note: See TracBrowser for help on using the repository browser.