Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs @ 11747

Last change on this file since 11747 was 11747, checked in by gkronber, 9 years ago

#2283: implemented test problems for MCTS

File size: 12.5 KB
RevLine 
[11742]1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
[11745]7using HeuristicLab.Common;
[11742]8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Monte-Carlo tree search over the derivations of a grammar. Each iteration starts at the
  /// sentence symbol and repeatedly replaces the first non-terminal of the phrase with an
  /// alternative chosen eps-greedily over the children's values. This continues until the phrase
  /// is terminal, or until a node that is still in its random-rollout phase completes the phrase
  /// randomly. The sentence is evaluated, its quality (normalized by the best known quality) is
  /// back-propagated, and fully explored branches are marked as done.
  /// </summary>
  /// <remarks>
  /// Nodes are cached by the problem's canonical representation of the phrase prefix up to (and
  /// including) its first non-terminal. Equivalent prefixes therefore share one node, and a node
  /// can have several parents.
  /// </remarks>
  public class MctsContextualSampler {
    private class TreeNode {
      // canonical representation of the phrase prefix that identifies this node
      public string ident;
      // the alternative that was applied when this node was created ("$" for the root)
      public ReadonlySequence alt;
      // number of random completions that have been started from this node
      public int randomTries;
      // visit counter, incremented on every back-propagation through this node
      public int tries;
      // all nodes that list this node as one of their children
      public List<TreeNode> parents;
      // one child per feasible alternative; null until the node is expanded
      public TreeNode[] children;
      // set on the last node of a completed sentence, and on nodes whose children are all done
      public bool done = false;

      public TreeNode(string id, ReadonlySequence alt) {
        this.ident = id;
        this.alt = alt;
        this.parents = new List<TreeNode>();
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
      }
    }

    /// <summary>Raised whenever a sampled sentence improves on the best normalized quality of the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised for every sentence that was sampled and evaluated.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;

    // cache of all nodes, keyed by the canonical representation of the phrase prefix
    private readonly Dictionary<string, TreeNode> treeNodes;
    // value estimates and visit counts, keyed by canonical representation
    private readonly Dictionary<string, double> v;
    private readonly Dictionary<string, int> tries;

    // (node, parent) pairs visited while sampling the current sentence
    private List<Tuple<TreeNode, TreeNode>> updateChain;
    private TreeNode rootNode;

    public int treeDepth;
    public int treeSize;
    private double bestQuality;

    /// <param name="problem">problem providing grammar, evaluation and canonical representations</param>
    /// <param name="maxLen">maximum sentence length</param>
    /// <param name="random">random number generator used for rollouts and selection</param>
    /// <param name="randomTries">number of random completions started from a node before it is expanded</param>
    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.v = new Dictionary<string, double>(1000000);
      this.tries = new Dictionary<string, int>(1000000);
      this.treeNodes = new Dictionary<string, TreeNode>();
    }

    // Returns the cached node for the canonical form of id, creating (and caching) it if necessary.
    private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
      var canonicalId = problem.CanonicalRepresentation(id);
      TreeNode n;
      if (!treeNodes.TryGetValue(canonicalId, out n)) {
        n = new TreeNode(canonicalId, alt);
        // seed the node's visit counter from the statistics collected so far (0 if none)
        tries.TryGetValue(canonicalId, out n.tries);
        treeNodes[canonicalId] = n;
      }
      return n;
    }

    /// <summary>
    /// Samples and evaluates up to <paramref name="maxIterations"/> sentences, stopping early once
    /// the whole search space below the root has been explored. All search state is reset at the
    /// start and at the end of the run.
    /// </summary>
    public void Run(int maxIterations) {
      bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        bool success;
        var sentence = SampleSentence(problem.Grammar, out success).ToString();
        // no sentence was produced: the sampled branch turned out to be fully explored already
        if (!success) continue;

        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drop the tree and all statistics so the memory can actually be reclaimed
      InitPolicies(problem.Grammar); GC.Collect();
    }

    /// <summary>
    /// Prints summary statistics, then follows the most-tried non-done child from the root downwards.
    /// For each visited node it prints the alternatives, their values (x10) and their visit counts,
    /// colored by value relative to the row maximum.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
      while (n.children != null) {
        Console.WriteLine("{0,-30}", n.ident);
        double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max();
        if (maxVForRow == 0) maxVForRow = 1.0;

        WriteColoredRow(n.children, maxVForRow, "{0,5}", ch => ch.alt);
        Console.WriteLine();
        WriteColoredRow(n.children, maxVForRow, "{0,5:F2}", ch => Math.Min(1.0, V(ch)) * 10);
        Console.WriteLine();
        WriteColoredRow(n.children, maxVForRow, "{0,5}", ch => ch.done ? "X" : ch.tries.ToString());
        Console.ForegroundColor = ConsoleColor.White;
        Console.WriteLine();

        var next = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).FirstOrDefault();
        // every child is fully explored (or there are none): nothing left to follow
        if (next == null) break;
        n = next;
      }
    }

    // Writes one formatted cell per child, colored by the child's value relative to the row maximum.
    private void WriteColoredRow(TreeNode[] children, double maxVForRow, string format, Func<TreeNode, object> cell) {
      foreach (var ch in children) {
        Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
        Console.Write(format, cell(ch));
      }
    }

    // Resets all search state: value statistics, the node cache and the tree metrics.
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
      this.v.Clear();
      this.tries.Clear();
      // Without this, GetTreeNode below would return the stale root of a previous run,
      // possibly already marked done and still carrying old counters.
      this.treeNodes.Clear();

      rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

    // Samples one sentence starting from the grammar's sentence symbol (see CompleteSentence).
    private Sequence SampleSentence(IGrammar grammar, out bool success) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase, out success);
    }

    // Walks down the tree, expanding the first non-terminal of `phrase` at every step, and records
    // the visited (node, parent) pairs in updateChain. Returns the completed sentence with
    // success = true. If the current node turns out to be fully explored right after its
    // expansion, it returns the partial phrase with success = false.
    private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than maxLen.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within maxLen.", "phrase");
      TreeNode parent = null;
      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        updateChain.Add(Tuple.Create(n, parent));

        if (n.randomTries < randomTries) {
          // the node is still in its random-rollout phase
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          success = true;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        char nt = phrase.FirstNonTerminal;

        // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        int maxLenOfReplacement = maxLen - (phrase.Length - 1);
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the alternatives are counted, enumerated and indexed below
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        if (n.randomTries == randomTries && n.children == null) {
          // create a new node for each alternative
          n.children = new TreeNode[alts.Length];
          for (int i = 0; i < alts.Length; i++) {
            var newPhrase = new Sequence(phrase);
            newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alts[i]);
            // only the prefix up to (and including) the next non-terminal identifies the child
            if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
            var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alts[i]));
            treeNode.parents.Add(n);
            n.children[i] = treeNode;
          }
          treeSize += n.children.Length;
          UpdateDone(n);

          // all variations starting from this branch may already have been explored
          if (n.done) {
            success = false;
            return phrase;
          }
        }

        int selectedAltIdx = SelectEpsGreedy(random, n.children);
        // NOTE(review): children are created from the alternatives that were feasible on the first
        // visit. A shared node reached with a different phrase suffix (and hence a different
        // maxLenOfReplacement) may filter `alts` differently, so indices could then disagree — confirm.
        Sequence selectedAlt = alts[selectedAltIdx];

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
        curDepth++;

        // prepare for next iteration
        parent = n;
        n = n.children[selectedAltIdx];
      }

      updateChain.Add(Tuple.Create(n, parent));

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      success = true;
      return phrase;
    }

    // Propagates the done flag and the reward upwards from the last node visited for the current sentence.
    private void DistributeReward(double reward) {
      var last = updateChain.Last().Item1;
      UpdateDone(last);
      BackPropReward(last, reward);
    }

    // Counts a visit and updates the value of n, then recurses into every parent. A node that is
    // reachable from n along several parent paths receives the update once per path.
    private void BackPropReward(TreeNode n, double reward) {
      n.tries++;
      UpdateV(n, reward);
      foreach (var p in n.parents) BackPropReward(p, reward);
    }

    // Marks n as done once all of its children are done, and re-checks its parents if n is done.
    private void UpdateDone(TreeNode n) {
      if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
      if (n.done) foreach (var p in n.parents) UpdateDone(p);
    }

    // Updates the sample-average value estimate of n with one more observed reward.
    private void UpdateV(TreeNode n, double reward) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      double stateV;

      if (!v.TryGetValue(canonicalStr, out stateV)) {
        v.Add(canonicalStr, reward);
        tries.Add(canonicalStr, 1);
      } else {
        // incremental mean: Q_k = Q_(k-1) + (r - Q_(k-1)) / k, where k counts this observation
        int count = tries[canonicalStr] + 1;
        tries[canonicalStr] = count;
        v[canonicalStr] = stateV + (reward - stateV) / count;
      }
    }

    // Value estimate of n; +infinity for states that have never received a reward.
    private double V(TreeNode n) {
      var canonicalStr = problem.CanonicalRepresentation(n.ident);
      double stateV;
      if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
      if (!v.TryGetValue(canonicalStr, out stateV)) {
        return 0.0;
      } else {
        return stateV;
      }
    }

    // Index of a uniformly chosen child that is not done.
    private int SelectRandom(Random random, TreeNode[] children) {
      return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
    }

    // Eps-greedy selection (eps = 0.2) among the non-done children; ties of the best value are broken randomly.
    private int SelectEpsGreedy(Random random, TreeNode[] children) {
      if (random.NextDouble() < 0.2) {
        return SelectRandom(random, children);
      } else {
        var bestQ = double.NegativeInfinity;
        var bestChildIdx = new List<int>();
        for (int i = 0; i < children.Length; i++) {
          if (children[i].done) continue;
          var q = V(children[i]);
          if (q > bestQ) {
            bestQ = q;
            bestChildIdx.Clear();
            bestChildIdx.Add(i);
          } else if (q == bestQ) {
            bestChildIdx.Add(i);
          }
        }
        Debug.Assert(bestChildIdx.Any());
        return bestChildIdx.SelectRandom(random);
      }
    }

    // UCB1 selection among the non-done children; untried children are returned first.
    // Not referenced in this class at the moment (alternative to SelectEpsGreedy).
    private int SelectActionUCB1(Random random, TreeNode[] children) {
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      int totalTries = children.Sum(ch => ch.tries);

      for (int a = 0; a < children.Length; a++) {
        var ch = children[a];
        if (ch.done) continue;
        if (ch.tries == 0) return a;
        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
        if (q > bestQ) {
          bestQ = q;
          bestAction = a;
        }
      }
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.