source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 6.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
9namespace HeuristicLab.Algorithms.GrammaticalOptimization {
10  public class MctsSampler {
11    private class TreeNode {
12      public string ident;
13      public int randomTries;
14      public int policyTries;
15      public IPolicy policy;
16      public TreeNode[] children;
17      public bool done = false;
18
19      public TreeNode(string id) {
20        this.ident = id;
21      }
22
23      public override string ToString() {
24        return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, randomTries + policyTries, done, policy);
25      }
26    }
27
28
29    public event Action<string, double> FoundNewBestSolution;
30    public event Action<string, double> SolutionEvaluated;
31
32    private readonly int maxLen;
33    private readonly IProblem problem;
34    private readonly Random random;
35    private readonly int randomTries;
36    private readonly Func<Random, int, IPolicy> policyFactory;
37
38    private List<Tuple<TreeNode, int>> updateChain;
39    private TreeNode rootNode;
40
41    public int treeDepth;
42    public int treeSize;
43
44    public MctsSampler(IProblem problem, int maxLen, Random random) :
45      this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
46
47    }
48
49    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<Random, int, IPolicy> policyFactory) {
50      this.maxLen = maxLen;
51      this.problem = problem;
52      this.random = random;
53      this.randomTries = randomTries;
54      this.policyFactory = policyFactory;
55    }
56
57    public void Run(int maxIterations) {
58      double bestQuality = double.MinValue;
59      InitPolicies(problem.Grammar);
60      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
61        var sentence = SampleSentence(problem.Grammar).ToString();
62        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
63        Debug.Assert(quality >= 0 && quality <= 1.0);
64        DistributeReward(quality);
65
66        RaiseSolutionEvaluated(sentence, quality);
67
68        if (quality > bestQuality) {
69          bestQuality = quality;
70          RaiseFoundNewBestSolution(sentence, quality);
71        }
72      }
73
74      // clean up
75      InitPolicies(problem.Grammar); GC.Collect();
76    }
77
78    public void PrintStats() {
79      var n = rootNode;
80      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
81      while (n.policy != null) {
82        Console.WriteLine();
83        Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
84        Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
85        //n.policy.PrintStats();
86        n = n.children.OrderByDescending(c => c.policyTries).First();
87      }
88      Console.ReadLine();
89    }
90
91    private void InitPolicies(IGrammar grammar) {
92      this.updateChain = new List<Tuple<TreeNode, int>>();
93
94      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
95      treeDepth = 0;
96      treeSize = 0;
97    }
98
99    private Sequence SampleSentence(IGrammar grammar) {
100      updateChain.Clear();
101      var startPhrase = new Sequence(grammar.SentenceSymbol);
102      return CompleteSentence(grammar, startPhrase);
103    }
104
105    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
106      if (phrase.Length > maxLen) throw new ArgumentException();
107      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
108      TreeNode n = rootNode;
109      bool done = phrase.IsTerminal;
110      int selectedAltIdx = -1;
111      var curDepth = 0;
112      while (!done) {
113        char nt = phrase.FirstNonTerminal;
114
115        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
116        Debug.Assert(maxLenOfReplacement > 0);
117
118        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
119
120        if (n.randomTries < randomTries) {
121          n.randomTries++;
122
123          treeDepth = Math.Max(treeDepth, curDepth);
124
125          return g.CompleteSentenceRandomly(random, phrase, maxLen);
126        } else if (n.randomTries == randomTries && n.policy == null) {
127          n.policy = policyFactory(random, alts.Count());
128          //n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
129          n.children = alts.Select(alt => new TreeNode(string.Empty)).ToArray(); // create a new node for each alternative
130
131          treeSize += n.children.Length;
132        }
133        n.policyTries++;
134        // => select using bandit policy
135        selectedAltIdx = n.policy.SelectAction();
136        Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
137
138        // replace nt with alt
139        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
140
141        updateChain.Add(Tuple.Create(n, selectedAltIdx));
142
143        curDepth++;
144
145        done = phrase.IsTerminal;
146        if (!done) {
147          // prepare for next iteration
148          n = n.children[selectedAltIdx];
149          Debug.Assert(!n.done);
150        }
151      } // while
152
153      // the last node is a leaf node (sentence is done), so we never need to visit this node again
154      n.children[selectedAltIdx].done = true;
155
156      treeDepth = Math.Max(treeDepth, curDepth);
157      return phrase;
158    }
159
160    private void DistributeReward(double reward) {
161      // iterate in reverse order (bottom up)
162      updateChain.Reverse();
163
164      foreach (var e in updateChain) {
165        var node = e.Item1;
166        var policy = node.policy;
167        var action = e.Item2;
168        //policy.UpdateReward(action, reward / updateChain.Count);
169        policy.UpdateReward(action, reward);
170
171        if (node.children[action].done) node.policy.DisableAction(action);
172        if (node.children.All(c => c.done)) node.done = true;
173      }
174    }
175
176    private void RaiseSolutionEvaluated(string sentence, double quality) {
177      var handler = SolutionEvaluated;
178      if (handler != null) handler(sentence, quality);
179    }
180    private void RaiseFoundNewBestSolution(string sentence, double quality) {
181      var handler = FoundNewBestSolution;
182      if (handler != null) handler(sentence, quality);
183    }
184  }
185}
Note: See TracBrowser for help on using the repository browser.