Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/ContextualMctsSampler.cs @ 11795

Last change on this file since 11795 was 11793, checked in by gkronber, 9 years ago

#2283 fixed compile errors and refactoring

File size: 7.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences of a grammar with a tree search: each tree node first performs a fixed
  /// number of random completions, and is afterwards expanded so that alternatives are chosen by
  /// a bandit policy. Policies are contextual: all tree nodes whose (hashed) terminal prefix is
  /// equal share the same <see cref="IPolicy"/> instance.
  /// </summary>
  public class ContextualMctsSampler {
    /// <summary>A node of the search tree; a node is reached by a specific sequence of selected alternatives.</summary>
    private class TreeNode {
      public string ident;        // textual form of the alternative (or the sentence symbol for the root)
      public int randomTries;     // number of random completions started from this node
      public int policyTries;     // number of times the bandit policy was consulted at this node
      public TreeNode[] children; // one child per admissible alternative; null until the node is expanded
      public bool done = false;   // true when no further sampling below this node is needed

      public TreeNode(string id) {
        this.ident = id;
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1}, done: {2})", ident, randomTries + policyTries, done);
      }
    }


    /// <summary>Raised with (sentence, normalized quality) whenever a sentence improves on the best quality of the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised with (sentence, normalized quality) for every evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly Func<Random, int, IPolicy> policyFactory;
    private readonly Dictionary<string, IPolicy> policyForState;

    // Decisions of the current sample, root first: (node, selected action, policy that selected it).
    private List<Tuple<TreeNode, int, IPolicy>> updateChain;
    private TreeNode rootNode;

    /// <summary>Deepest tree level reached in the current run (reset by <see cref="Run"/>).</summary>
    public int treeDepth;
    /// <summary>Number of tree nodes created by expansion in the current run (reset by <see cref="Run"/>).</summary>
    public int treeSize;

    /// <summary>
    /// Creates a sampler with 10 random tries per node and an epsilon-greedy policy (eps = 0.1).
    /// </summary>
    public ContextualMctsSampler(IProblem problem, int maxLen, Random random) :
      this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {

    }

    /// <summary>Creates a sampler.</summary>
    /// <param name="problem">Problem providing the grammar, the evaluation function and context hashing.</param>
    /// <param name="maxLen">Maximum sentence length.</param>
    /// <param name="random">Random number generator used for rollouts and passed to new policies.</param>
    /// <param name="randomTries">Number of random completions started from a node before it is expanded.</param>
    /// <param name="policyFactory">Creates a policy given the random generator and the number of actions.</param>
    /// <exception cref="ArgumentNullException">problem, random or policyFactory is null.</exception>
    public ContextualMctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<Random, int, IPolicy> policyFactory) {
      if (problem == null) throw new ArgumentNullException("problem");
      if (random == null) throw new ArgumentNullException("random");
      if (policyFactory == null) throw new ArgumentNullException("policyFactory");

      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.policyFactory = policyFactory;
      this.policyForState = new Dictionary<string, IPolicy>();
    }

    /// <summary>
    /// Performs up to <paramref name="maxIterations"/> sampling iterations, stopping early when the
    /// root node is done. Each sentence is evaluated, normalized by the best known quality for
    /// maxLen, fed back along the decision chain and reported through the events.
    /// The tree, all policies, treeDepth and treeSize are reset before and after the run.
    /// </summary>
    public void Run(int maxIterations) {
      double bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      // loop-invariant normalization constant (assumes GetBestKnownQuality has no side effects)
      var bestKnownQuality = problem.GetBestKnownQuality(maxLen);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        var quality = problem.Evaluate(sentence) / bestKnownQuality;
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drop the tree and all policies, then force a collection to release their memory
      InitPolicies(problem.Grammar); GC.Collect();
    }

    /// <summary>
    /// Writes tree statistics and the statistics of every contextual policy to the console,
    /// then waits for a line on standard input.
    /// </summary>
    public void PrintStats() {
      // rootNode is only created by Run; report zero tries instead of crashing before that
      int rootTries = rootNode == null ? 0 : rootNode.policyTries + rootNode.randomTries;
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootTries);
      foreach (var p in policyForState.OrderBy(e => e.Key.Length).ThenBy(e => e.Key)) {
        Console.Write("{0,10} {1,20}", p.Key, p.Value);
        p.Value.PrintStats();
      }
      // NOTE(review): blocks until a line is read from stdin; kept for compatibility with interactive callers
      Console.ReadLine();
    }

    /// <summary>Resets the search tree, the policies and the tree statistics.</summary>
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<TreeNode, int, IPolicy>>();
      policyForState.Clear();
      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
      treeDepth = 0;
      treeSize = 0;
    }

    /// <summary>Samples one sentence starting from the sentence symbol.</summary>
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase);
    }

    /// <summary>
    /// Descends the tree, replacing the first non-terminal of <paramref name="phrase"/> in each step.
    /// A node that has not yet used up its random tries finishes the sentence randomly; otherwise the
    /// node is expanded (once) and the contextual bandit policy selects the alternative.
    /// Every policy decision is recorded in updateChain for <see cref="DistributeReward"/>.
    /// </summary>
    /// <exception cref="ArgumentException">The phrase is longer than maxLen or cannot be completed within maxLen.</exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than maxLen.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within maxLen symbols.", "phrase");
      TreeNode n = rootNode;
      bool done = phrase.IsTerminal;
      var curDepth = 0;
      while (!done) {
        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        if (n.randomTries < randomTries) {
          n.randomTries++;

          treeDepth = Math.Max(treeDepth, curDepth);

          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        // Materialize once: the admissible alternatives are used for expansion, counting and indexing.
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        // Expand the node on its first policy visit. Checking only for null children (instead of also
        // requiring randomTries to match exactly) keeps this working for a negative randomTries setting.
        if (n.children == null) {
          n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // one node per alternative
          treeSize += n.children.Length;
        }

        // the context of the policy is the terminal prefix generated so far
        var terminalSequence = phrase.Subsequence(0, phrase.FirstNonTerminalIndex).ToString();
        IPolicy policy = GetPolicyForTerminalSequence(terminalSequence, alts.Length);
        n.policyTries++;
        int selectedAltIdx = policy.SelectAction();
        if (selectedAltIdx == -1) {
          // presumably every action of this context has been disabled: give up on the node and finish randomly
          n.done = true;
          Console.Write("*");
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }
        // NOTE(review): policies are shared by contexts with equal hashed terminal prefixes; if such contexts
        // have different numbers of alternatives, the selected index may not fit this node's alternatives.
        Debug.Assert(selectedAltIdx < alts.Length);
        Sequence selectedAlt = alts[selectedAltIdx];
        updateChain.Add(Tuple.Create(n, selectedAltIdx, policy));

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        done = phrase.IsTerminal;
        // prepare for next iteration
        n = n.children[selectedAltIdx];
        Debug.Assert(!n.done);
      }

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    /// <summary>
    /// Passes <paramref name="reward"/> to every policy decision of the last sample, bottom-up, so that
    /// "done" flags propagate from the leaf towards the root: finished children are disabled in the
    /// policy and a node whose children are all done is marked done itself.
    /// </summary>
    private void DistributeReward(double reward) {
      for (int i = updateChain.Count - 1; i >= 0; i--) {
        var decision = updateChain[i];
        var node = decision.Item1;
        var action = decision.Item2;
        var policy = decision.Item3;
        policy.UpdateReward(action, reward);

        if (node.children[action].done) policy.DisableAction(action);
        if (node.children.All(c => c.done)) node.done = true;
      }
    }

    /// <summary>
    /// Returns the policy for the context identified by problem.Hash(sequence), creating it with
    /// <paramref name="numAlts"/> actions if it does not exist yet.
    /// </summary>
    private IPolicy GetPolicyForTerminalSequence(string sequence, int numAlts) {
      var key = problem.Hash(sequence);
      IPolicy p;
      if (!policyForState.TryGetValue(key, out p)) {
        p = policyFactory(random, numAlts);
        policyForState.Add(key, p);
      }
      return p;
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }
    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.