source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs @ 11745

Last change on this file since 11745 was 11745, checked in by gkronber, 7 years ago

#2283: worked on contextual MCTS

File size: 10.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
10namespace HeuristicLab.Algorithms.GrammaticalOptimization {
11  public class MctsContextualSampler {
12    private class TreeNode {
13      public string ident;
14      public ReadonlySequence alt;
15      public int randomTries;
16      public int tries;
17      public TreeNode[] children;
18      public bool done = false;
19
20      public TreeNode(string id, ReadonlySequence alt) {
21        this.ident = id;
22        this.alt = alt;
23      }
24
25      public override string ToString() {
26        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
27      }
28    }
29
30
31    public event Action<string, double> FoundNewBestSolution;
32    public event Action<string, double> SolutionEvaluated;
33
34    private readonly int maxLen;
35    private readonly IProblem problem;
36    private readonly Random random;
37    private readonly int randomTries;
38
39    private List<Tuple<TreeNode, TreeNode>> updateChain;
40    private TreeNode rootNode;
41
42    public int treeDepth;
43    public int treeSize;
44    private double bestQuality;
45
46    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
47      this.maxLen = maxLen;
48      this.problem = problem;
49      this.random = random;
50      this.randomTries = randomTries;
51      this.v = new Dictionary<string, double>(1000000);
52      this.tries = new Dictionary<string, int>(1000000);
53    }
54
55    public void Run(int maxIterations) {
56      bestQuality = double.MinValue;
57      InitPolicies(problem.Grammar);
58      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
59        var sentence = SampleSentence(problem.Grammar).ToString();
60        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
61        Debug.Assert(quality >= 0 && quality <= 1.0);
62        DistributeReward(quality);
63
64        RaiseSolutionEvaluated(sentence, quality);
65
66        if (quality > bestQuality) {
67          bestQuality = quality;
68          RaiseFoundNewBestSolution(sentence, quality);
69        }
70      }
71
72      // clean up
73      InitPolicies(problem.Grammar); GC.Collect();
74    }
75
76    public void PrintStats() {
77      var n = rootNode;
78      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
79      while (n.children != null) {
80        Console.WriteLine("{0}", n.ident);
81        double maxVForRow = n.children.Select(ch => V(ch)).Max();
82        if (maxVForRow == 0) maxVForRow = 1.0;
83
84        for (int i = 0; i < n.children.Length; i++) {
85          var ch = n.children[i];
86          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
87          Console.Write("{0,5}", ch.alt);
88        }
89        Console.WriteLine();
90        for (int i = 0; i < n.children.Length; i++) {
91          var ch = n.children[i];
92          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
93          Console.Write("{0,5:F2}", V(ch) * 10);
94        }
95        Console.WriteLine();
96        for (int i = 0; i < n.children.Length; i++) {
97          var ch = n.children[i];
98          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
99          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
100        }
101        Console.ForegroundColor = ConsoleColor.White;
102        Console.WriteLine();
103        //n.policy.PrintStats();
104        n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();
105      }
106    }
107
108
109    private void InitPolicies(IGrammar grammar) {
110      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
111      this.v.Clear();
112      this.tries.Clear();
113
114      rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
115      treeDepth = 0;
116      treeSize = 0;
117    }
118
119    private Sequence SampleSentence(IGrammar grammar) {
120      updateChain.Clear();
121      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
122      var startPhrase = new Sequence(grammar.SentenceSymbol);
123      return CompleteSentence(grammar, startPhrase);
124    }
125
126    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
127      if (phrase.Length > maxLen) throw new ArgumentException();
128      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
129      TreeNode parent = null;
130      TreeNode n = rootNode;
131      var curDepth = 0;
132      while (!phrase.IsTerminal) {
133        updateChain.Add(Tuple.Create(n, parent));
134
135        if (n.randomTries < randomTries) {
136          n.randomTries++;
137          treeDepth = Math.Max(treeDepth, curDepth);
138          return g.CompleteSentenceRandomly(random, phrase, maxLen);
139        } else {
140          char nt = phrase.FirstNonTerminal;
141
142          int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
143          Debug.Assert(maxLenOfReplacement > 0);
144
145          var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
146
147          if (n.randomTries == randomTries && n.children == null) {
148            // create a new node for each alternative
149            n.children = new TreeNode[alts.Count()];
150            var i = 0;
151            foreach (var alt in alts) {
152              var newPhrase = new Sequence(phrase);
153              newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
154              if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
155              n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
156            }
157            treeSize += n.children.Length;
158          }
159          // => select using eps-greedy
160          int selectedAltIdx = SelectEpsGreedy(random, n.children);
161
162          //int selectedAltIdx = SelectActionUCB1(random, n.children);
163          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
164
165          // replace nt with alt
166          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
167
168          curDepth++;
169
170          // prepare for next iteration
171          parent = n;
172          n = n.children[selectedAltIdx];
173        }
174      } // while
175
176      updateChain.Add(Tuple.Create(n, parent));
177
178      // the last node is a leaf node (sentence is done), so we never need to visit this node again
179      n.done = true;
180
181
182      treeDepth = Math.Max(treeDepth, curDepth);
183      return phrase;
184    }
185
186    private void DistributeReward(double reward) {
187      // iterate in reverse order (bottom up)
188      updateChain.Reverse();
189
190      foreach (var e in updateChain) {
191        var node = e.Item1;
192        var parent = e.Item2;
193        node.tries++;
194        if (node.children != null && node.children.All(c => c.done)) {
195          node.done = true;
196        }
197        UpdateV(node, reward);
198
199        // the reward for the parent is either the just recieved reward or the value of the best action so far
200        double value = 0.0;
201        if (parent != null) {
202          var doneChilds = parent.children.Where(ch => ch.done);
203          if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
204        }
205        //if (value > reward) reward = value;
206      }
207    }
208
209    private Dictionary<string, double> v;
210    private Dictionary<string, int> tries;
211
212    private void UpdateV(TreeNode n, double reward) {
213      var canonicalStr = problem.CanonicalRepresentation(n.ident);
214      //var canonicalStr = n.ident;
215      double stateV;
216
217      if (!v.TryGetValue(canonicalStr, out  stateV)) {
218        v.Add(canonicalStr, reward);
219        tries.Add(canonicalStr, 1);
220      } else {
221        v[canonicalStr] = stateV + 0.005 * (reward - stateV);
222        //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
223        tries[canonicalStr]++;
224      }
225    }
226
227    private double V(TreeNode n) {
228      var canonicalStr = problem.CanonicalRepresentation(n.ident);
229      //var canonicalStr = n.ident;
230      double stateV;
231
232      if (!v.TryGetValue(canonicalStr, out  stateV)) {
233        return 0.0;
234      } else {
235        return stateV;
236      }
237    }
238
239    private int SelectEpsGreedy(Random random, TreeNode[] children) {
240      if (random.NextDouble() < 0.2) {
241
242        return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
243      } else {
244        var bestQ = double.NegativeInfinity;
245        var bestChildIdx = new List<int>();
246        for (int i = 0; i < children.Length; i++) {
247          if (children[i].done) continue;
248          // if (children[i].tries == 0) return i;
249          var q = V(children[i]);
250          if (q > bestQ) {
251            bestQ = q;
252            bestChildIdx.Clear();
253            bestChildIdx.Add(i);
254          } else if (q == bestQ) {
255            bestChildIdx.Add(i);
256          }
257        }
258        Debug.Assert(bestChildIdx.Any());
259        return bestChildIdx.SelectRandom(random);
260      }
261    }
262    private int SelectActionUCB1(Random random, TreeNode[] children) {
263      int bestAction = -1;
264      double bestQ = double.NegativeInfinity;
265      int totalTries = children.Sum(ch => ch.tries);
266
267      for (int a = 0; a < children.Length; a++) {
268        var ch = children[a];
269        if (ch.done) continue;
270        if (ch.tries == 0) return a;
271        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
272        if (q > bestQ) {
273          bestQ = q;
274          bestAction = a;
275        }
276      }
277      Debug.Assert(bestAction > -1);
278      return bestAction;
279    }
280
281
282
283    private void RaiseSolutionEvaluated(string sentence, double quality) {
284      var handler = SolutionEvaluated;
285      if (handler != null) handler(sentence, quality);
286    }
287    private void RaiseFoundNewBestSolution(string sentence, double quality) {
288      var handler = FoundNewBestSolution;
289      if (handler != null) handler(sentence, quality);
290    }
291
292
293  }
294}
Note: See TracBrowser for help on using the repository browser.