source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs @ 11745

Last change on this file since 11745 was 11745, checked in by gkronber, 7 years ago

#2283: worked on contextual MCTS

File size: 6.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences of a grammar by growing a search tree over the alternatives of the
  /// first non-terminal symbol of a phrase. Each tree node first receives a fixed number of
  /// random completions; afterwards its children are selected by a bandit policy. The
  /// (normalized) quality of each sampled sentence is propagated back along the visited nodes.
  /// </summary>
  public class MctsSampler {
    /// <summary>
    /// A node of the search tree. The root stands for the sentence symbol, every other node
    /// for one alternative that replaced the first non-terminal of its parent's phrase.
    /// </summary>
    private class TreeNode {
      // textual representation of the symbol / alternative this node stands for
      public string ident;
      // number of random completions that have been started from this node so far
      public int randomTries;
      // statistics object created by (and handed back to) the bandit policy
      public IBanditPolicyActionInfo actionInfo;
      // one child per feasible alternative; null until the node is expanded
      public TreeNode[] children;
      // true when the sub-tree below this node is exhausted and must not be visited again
      public bool done = false;

      public TreeNode(string id) {
        this.ident = id;
      }

      /// <summary>
      /// Returns a short description of the node. Safe to call before
      /// <see cref="actionInfo"/> has been assigned (e.g. from a debugger).
      /// </summary>
      public override string ToString() {
        var tries = actionInfo == null ? string.Empty : actionInfo.Tries.ToString();
        return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, tries, done, actionInfo);
      }
    }


    /// <summary>Raised with the sentence and its normalized quality whenever a sentence is strictly better than all previous ones of the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised with the sentence and its normalized quality after each evaluated sentence.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly IBanditPolicy policy;

    // nodes visited while sampling the current sentence (root first)
    private List<TreeNode> updateChain;
    private TreeNode rootNode;

    // NOTE: both counters are reset to 0 at the end of Run() because the tree is discarded there.
    public int treeDepth;
    public int treeSize;
    private double bestQuality;

    /// <summary>
    /// Creates a new sampler.
    /// </summary>
    /// <param name="problem">Problem that supplies the grammar and evaluates sentences.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="random">Random number generator used for random completions and passed to the policy.</param>
    /// <param name="randomTries">Number of random completions started from a node before it is expanded.</param>
    /// <param name="policy">Bandit policy used to select among the children of an expanded node.</param>
    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.policy = policy;
    }

    /// <summary>
    /// Samples and evaluates sentences until either <paramref name="maxIterations"/> sentences
    /// have been evaluated or the root node is marked as done. The search tree is discarded
    /// when the method returns.
    /// </summary>
    /// <param name="maxIterations">Upper bound for the number of evaluated sentences.</param>
    public void Run(int maxIterations) {
      bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        // normalize by the best known quality; the assertion below expects the result to lie in [0, 1]
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drop the reference to the (potentially very large) tree and release its memory
      InitPolicies(problem.Grammar);
      GC.Collect();
    }

    /// <summary>
    /// Writes tree statistics to the console and then follows, level by level, the not yet
    /// exhausted child with the highest value. Does nothing if no tree has been created yet.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // Run() has not been called yet
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality);
      while (n.children != null) {
        Console.WriteLine();
        Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
        Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.actionInfo.Value * 10))));
        Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.actionInfo.Tries.ToString()))));
        //n.policy.PrintStats();

        // descend into the best child that is still open; stop when every child is exhausted
        // (First() would throw InvalidOperationException in that case)
        var next = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).FirstOrDefault();
        if (next == null) break;
        n = next;
      }
    }

    /// <summary>
    /// Discards the current tree and creates a fresh root node for the sentence symbol of
    /// <paramref name="grammar"/>; also resets the tree statistics.
    /// </summary>
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<TreeNode>();

      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
      rootNode.actionInfo = policy.CreateActionInfo();
      treeDepth = 0;
      treeSize = 0;
    }

    /// <summary>
    /// Samples one sentence starting from the sentence symbol of <paramref name="grammar"/>.
    /// The visited nodes are recorded in <see cref="updateChain"/>.
    /// </summary>
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase);
    }

    /// <summary>
    /// Completes <paramref name="phrase"/> to a sentence by walking down the tree from the
    /// root. A node that has not yet used up its random tries completes the phrase randomly;
    /// otherwise the node is expanded (once) and one child is selected by the bandit policy.
    /// </summary>
    /// <param name="g">Grammar that supplies alternatives and phrase lengths.</param>
    /// <param name="phrase">Phrase to complete; it is modified in place.</param>
    /// <returns>The completed sentence.</returns>
    /// <exception cref="ArgumentException">
    /// The phrase is longer than the maximum length or cannot be completed within it.
    /// </exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen) throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");
      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        updateChain.Add(n);

        if (n.randomTries < randomTries) {
          // this node has random tries left => complete the phrase randomly
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the alternatives are needed for expansion and for the indexed lookup below
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        if (n.children == null) {
          // first visit after the random tries are used up => create a new node for each alternative
          n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray();
          foreach (var ch in n.children) ch.actionInfo = policy.CreateActionInfo();
          treeSize += n.children.Length;
        }

        // => select using bandit policy
        int selectedAltIdx = policy.SelectAction(random, n.children.Select(c => c.actionInfo));
        Sequence selectedAlt = alts[selectedAltIdx];

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        // prepare for next iteration
        n = n.children[selectedAltIdx];
      } // while

      updateChain.Add(n);

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    /// <summary>
    /// Propagates <paramref name="reward"/> bottom-up along the nodes visited for the last
    /// sentence. Nodes that are done (or whose children are all done) are disabled in the
    /// policy and receive no reward; all other nodes are updated with the reward.
    /// </summary>
    private void DistributeReward(double reward) {
      // iterate in reverse order (bottom up) so that a parent sees the final state of its children
      for (int i = updateChain.Count - 1; i >= 0; i--) {
        var node = updateChain[i];
        if (node.done) node.actionInfo.Disable();
        if (node.children != null && node.children.All(c => c.done)) {
          node.done = true;
          node.actionInfo.Disable();
        }
        if (!node.done) {
          node.actionInfo.UpdateReward(reward);
        }
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so that the null check and the invocation use the same delegate
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }
    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      // copy to a local so that the null check and the invocation use the same delegate
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.