source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs @ 11747

Last change on this file since 11747 was 11747, checked in by gkronber, 7 years ago

#2283: implemented test problems for MCTS

File size: 12.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences of the problem's grammar with a Monte-Carlo tree search.
  /// Tree nodes and state values are keyed by the canonical representation of a phrase
  /// (problem.CanonicalRepresentation), so phrases with the same canonical form share one node.
  /// </summary>
  public class MctsContextualSampler {
    // probability with which SelectEpsGreedy picks a random (not yet done) child instead of the best one
    private const double Epsilon = 0.2;

    private class TreeNode {
      public string ident;            // canonical representation of the phrase the node stands for
      public ReadonlySequence alt;    // alternative that was passed in when the node was created
      public int randomTries;         // number of random completions started from this node
      public int tries;               // number of rewards propagated through this node
      public List<TreeNode> parents;  // nodes that list this node among their children
      public TreeNode[] children;     // null until the node has been expanded
      public bool done = false;       // set for finished sentences and for nodes whose children are all done

      public TreeNode(string id, ReadonlySequence alt) {
        this.ident = id;
        this.alt = alt;
        this.parents = new List<TreeNode>();
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
      }
    }

    /// <summary>Raised when an evaluated sentence has a better quality than all sentences before.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised for every evaluated sentence with its (normalized) quality.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;

    // all nodes created so far, keyed by canonical phrase
    private readonly Dictionary<string, TreeNode> treeNodes;
    // running average of the rewards and number of rewards per canonical phrase
    private readonly Dictionary<string, double> v;
    private readonly Dictionary<string, int> tries;

    // (node, parent) pairs visited while sampling the current sentence
    private List<Tuple<TreeNode, TreeNode>> updateChain;
    private TreeNode rootNode;

    public int treeDepth;
    public int treeSize;
    private double bestQuality;

    /// <summary>Creates a new sampler.</summary>
    /// <param name="problem">Provides the grammar, the evaluation function and the canonical representation.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="random">Random number generator used for selection and random completion.</param>
    /// <param name="randomTries">Number of random completions started from a node before the node is expanded.</param>
    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.v = new Dictionary<string, double>(1000000);
      this.tries = new Dictionary<string, int>(1000000);
      this.treeNodes = new Dictionary<string, TreeNode>();
    }

    /// <summary>
    /// Samples and evaluates sentences until the root node is done or maxIterations is reached.
    /// The quality of a sentence is problem.Evaluate divided by problem.BestKnownQuality(maxLen).
    /// All search state is discarded again when the method returns.
    /// </summary>
    public void Run(int maxIterations) {
      bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        bool success;
        var sentence = SampleSentence(problem.Grammar, out success).ToString();
        if (!success) continue;

        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drops the tree and the value tables so that they can be reclaimed by the garbage collector
      InitPolicies(problem.Grammar);
    }

    /// <summary>
    /// Writes tree statistics to the console. Starting at the root it prints, for each level,
    /// the alternatives, values and tries of the children and then descends into the
    /// not-yet-done child with the most tries.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // nothing has been sampled yet

      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
      while (n != null && n.children != null && n.children.Length > 0) {
        var children = n.children;
        Console.WriteLine("{0,-30}", n.ident);

        // values of unvisited children are +infinity, therefore values are clipped to 1.0
        var rowV = children.Select(ch => Math.Min(1.0, V(ch))).ToArray();
        double maxVForRow = rowV.Select(x => Math.Max(0.0, x)).Max();
        if (maxVForRow == 0) maxVForRow = 1.0;

        WriteStatsRow(rowV, maxVForRow, i => string.Format("{0,5}", children[i].alt));
        WriteStatsRow(rowV, maxVForRow, i => string.Format("{0,5:F2}", rowV[i] * 10));
        WriteStatsRow(rowV, maxVForRow, i => string.Format("{0,5}", children[i].done ? "X" : children[i].tries.ToString()));

        // null (=> stop) when all children are done
        n = children.Where(ch => !ch.done).OrderByDescending(ch => ch.tries).FirstOrDefault();
      }
    }

    // writes one line with one colored cell per child; the color depends on the child's value relative to the row maximum
    private static void WriteStatsRow(double[] rowV, double maxVForRow, Func<int, string> formatCell) {
      for (int i = 0; i < rowV.Length; i++) {
        Console.ForegroundColor = ConsoleEx.ColorForValue(rowV[i] / maxVForRow);
        Console.Write(formatCell(i));
      }
      Console.ForegroundColor = ConsoleColor.White;
      Console.WriteLine();
    }

    // returns the node for the canonical representation of id; the node is created on first use
    private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
      TreeNode n;
      var canonicalId = problem.CanonicalRepresentation(id);
      if (!treeNodes.TryGetValue(canonicalId, out n)) {
        n = new TreeNode(canonicalId, alt);
        // take over the number of tries if the state has already been visited (stays 0 otherwise)
        tries.TryGetValue(canonicalId, out n.tries);
        treeNodes[canonicalId] = n;
      }
      return n;
    }

    // resets the complete search state (tree, value tables, statistics) and creates a fresh root node
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
      this.v.Clear();
      this.tries.Clear();
      // the nodes must be dropped together with the value tables, otherwise a later run
      // would continue with stale 'done' flags, children and randomTries counters
      this.treeNodes.Clear();

      rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

    // samples a single sentence starting from the sentence symbol of the grammar
    private Sequence SampleSentence(IGrammar grammar, out bool success) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase, out success);
    }

    // Descends from the root and replaces the first non-terminal of phrase in each step until the
    // phrase is terminal. A node that has not yet been used for randomTries random completions
    // is completed randomly instead. success is false when the descent ends in a branch for which
    // all variations have already been sampled; in this case the unfinished phrase is returned.
    private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");
      TreeNode parent = null;
      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        updateChain.Add(Tuple.Create(n, parent));

        if (n.randomTries < randomTries) {
          // not enough random samples for this node yet
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          success = true;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once; the alternatives are counted, iterated and indexed below
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        if (n.randomTries == randomTries && n.children == null) {
          // create a new node for each alternative
          n.children = new TreeNode[alts.Length];
          for (int i = 0; i < alts.Length; i++) {
            var newPhrase = new Sequence(phrase);
            newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alts[i]);
            // the state is identified by the phrase up to and including its first non-terminal
            if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
            var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alts[i]));
            treeNode.parents.Add(n);
            n.children[i] = treeNode;
          }
          treeSize += n.children.Length;
          UpdateDone(n);

          // it could happen that we have already finished all variations starting from this branch => stop
          if (n.done) {
            success = false;
            return phrase;
          }
        }

        // select using eps-greedy (SelectRandom and SelectActionUCB1 are alternative policies)
        // NOTE(review): the index refers to n.children but is also used for alts; this assumes that a
        // shared node is always reached with the same set of feasible alternatives - TODO confirm
        int selectedAltIdx = SelectEpsGreedy(random, n.children);
        Sequence selectedAlt = alts[selectedAltIdx];

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        // prepare for next iteration
        parent = n;
        n = n.children[selectedAltIdx];
      }

      updateChain.Add(Tuple.Create(n, parent));

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      success = true;
      return phrase;
    }

    // updates the done flags and propagates the reward, both starting from the last node of the sampled path
    private void DistributeReward(double reward) {
      var last = updateChain[updateChain.Count - 1].Item1;
      UpdateDone(last);
      BackPropReward(last, reward);
    }

    // updates tries and value of n and, recursively, of all of its parents
    private void BackPropReward(TreeNode n, double reward) {
      n.tries++;
      UpdateV(n, reward);
      foreach (var p in n.parents) BackPropReward(p, reward);
    }

    // marks n as done when all of its children are done and propagates the flag to the parents
    private void UpdateDone(TreeNode n) {
      if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
      if (n.done) foreach (var p in n.parents) UpdateDone(p);
    }

    // incorporates reward into the running average of the state of n
    private void UpdateV(TreeNode n, double reward) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      double stateV;

      if (!v.TryGetValue(canonicalStr, out stateV)) {
        v.Add(canonicalStr, reward);
        tries.Add(canonicalStr, 1);
      } else {
        // incremental mean: the count must include the current reward,
        // otherwise the second reward would completely overwrite the first one
        int count = tries[canonicalStr] + 1;
        tries[canonicalStr] = count;
        v[canonicalStr] = stateV + (reward - stateV) / count;
      }
    }

    // value of the state of n; +infinity for states without any reward so that they are preferred by greedy selection
    private double V(TreeNode n) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
      double stateV;
      return v.TryGetValue(canonicalStr, out stateV) ? stateV : 0.0;
    }

    // index of a uniformly selected child that is not done
    private int SelectRandom(Random random, TreeNode[] children) {
      return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
    }

    // with probability Epsilon a random child, otherwise one of the children with the best value
    // (ties are broken randomly); children that are done are never selected
    private int SelectEpsGreedy(Random random, TreeNode[] children) {
      if (random.NextDouble() < Epsilon) return SelectRandom(random, children);

      var bestQ = double.NegativeInfinity;
      var bestChildIdx = new List<int>();
      for (int i = 0; i < children.Length; i++) {
        if (children[i].done) continue;
        var q = V(children[i]);
        if (q > bestQ) {
          bestQ = q;
          bestChildIdx.Clear();
          bestChildIdx.Add(i);
        } else if (q == bestQ) {
          // exact comparison is intended: collects ties (e.g. several unvisited children with value +infinity)
          bestChildIdx.Add(i);
        }
      }
      Debug.Assert(bestChildIdx.Any());
      return bestChildIdx.SelectRandom(random);
    }

    // UCB1 selection: unvisited children first, afterwards the child with the largest upper confidence bound
    private int SelectActionUCB1(Random random, TreeNode[] children) {
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      int totalTries = children.Sum(ch => ch.tries);

      for (int a = 0; a < children.Length; a++) {
        var ch = children[a];
        if (ch.done) continue;
        if (ch.tries == 0) return a;
        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
        if (q > bestQ) {
          bestQ = q;
          bestAction = a;
        }
      }
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.