Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs @ 11728

Last change on this file since 11728 was 11727, checked in by gkronber, 9 years ago

#2283: worked on grammatical optimization problem solvers (simple MCTS done)

File size: 5.2 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
9namespace HeuristicLab.Algorithms.GrammaticalOptimization {
10  public class MctsSampler {
11    private class TreeNode {
12      public int randomTries;
13      public IPolicy policy;
14      public TreeNode[] children;
15      public bool done = false;
16
17      public override string ToString() {
18        return string.Format("Node(random-tries: {0}, done: {1}, policy: {2})", randomTries, done, policy);
19      }
20    }
21
22    public event Action<string, double> FoundNewBestSolution;
23    public event Action<string, double> SolutionEvaluated;
24
25    private readonly int maxLen;
26    private readonly IProblem problem;
27    private readonly Random random;
28    private readonly int randomTries;
29    private readonly Func<int, IPolicy> policyFactory;
30
31    private List<Tuple<TreeNode, int>> updateChain;
32    private TreeNode rootNode;
33
34    public MctsSampler(IProblem problem, int maxLen, Random random) :
35      this(problem, maxLen, random, 10, (numActions) => new EpsGreedyPolicy(random, numActions, 0.1)) {
36
37    }
38
39    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<int, IPolicy> policyFactory) {
40      this.maxLen = maxLen;
41      this.problem = problem;
42      this.random = random;
43      this.randomTries = randomTries;
44      this.policyFactory = policyFactory;
45    }
46
47    public void Run(int maxIterations) {
48      double bestQuality = double.MinValue;
49      InitPolicies();
50      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
51        var sentence = SampleSentence(problem.Grammar);
52        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
53        Debug.Assert(quality >= 0 && quality <= 1.0);
54        DistributeReward(quality);
55
56        RaiseSolutionEvaluated(sentence, quality);
57
58        if (quality > bestQuality) {
59          bestQuality = quality;
60          RaiseFoundNewBestSolution(sentence, quality);
61        }
62      }
63    }
64
65    private void InitPolicies() {
66      this.updateChain = new List<Tuple<TreeNode, int>>();
67      rootNode = new TreeNode();
68    }
69
70    private string SampleSentence(IGrammar grammar) {
71      updateChain.Clear();
72      return CompleteSentence(grammar, grammar.SentenceSymbol.ToString());
73    }
74
75    public string CompleteSentence(IGrammar g, string phrase) {
76      if (phrase.Length > maxLen) throw new ArgumentException();
77      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
78      TreeNode n = rootNode;
79      bool done = phrase.All(g.IsTerminal); // terminal phrase means we are done
80      int selectedAltIdx = -1;
81      while (!done) {
82        int ntIdx; char nt;
83        Grammar.FindFirstNonTerminal(g, phrase, out nt, out ntIdx);
84
85        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
86        Debug.Assert(maxLenOfReplacement > 0);
87
88        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
89
90        if (n.randomTries < randomTries) {
91          n.randomTries++;
92          return g.CompleteSentenceRandomly(random, phrase, maxLen);
93        } else if (n.randomTries == randomTries && n.policy == null) {
94          n.policy = policyFactory(alts.Count());
95          n.children = alts.Select(_ => new TreeNode()).ToArray(); // create a new node for each alternative
96        }
97
98        // => select using bandit policy
99        selectedAltIdx = n.policy.SelectAction();
100        string selectedAlt = alts.ElementAt(selectedAltIdx);
101        // replace nt with alt
102        phrase = phrase.Remove(ntIdx, 1);
103        phrase = phrase.Insert(ntIdx, selectedAlt);
104
105        updateChain.Add(Tuple.Create(n, selectedAltIdx));
106
107        done = phrase.All(g.IsTerminal); // terminal phrase means we are done
108        if (!done) {
109          // prepare for next iteration
110          n = n.children[selectedAltIdx];
111          Debug.Assert(!n.done);
112        }
113      } // while
114
115      // the last node is a leaf node (sentence is done), so we never need to visit this node again
116      n.children[selectedAltIdx].done = true;
117
118      return phrase;
119    }
120
121    private void DistributeReward(double reward) {
122      // iterate in reverse order (bottom up)
123      updateChain.Reverse();
124
125      foreach (var e in updateChain) {
126        var node = e.Item1;
127        var policy = node.policy;
128        var action = e.Item2;
129        policy.UpdateReward(action, reward);
130
131        if (node.children[action].done) node.policy.DisableAction(action);
132        if (node.children.All(c => c.done)) node.done = true;
133      }
134    }
135
136    private void RaiseSolutionEvaluated(string sentence, double quality) {
137      var handler = SolutionEvaluated;
138      if (handler != null) handler(sentence, quality);
139    }
140    private void RaiseFoundNewBestSolution(string sentence, double quality) {
141      var handler = FoundNewBestSolution;
142      if (handler != null) handler(sentence, quality);
143    }
144  }
145}
Note: See TracBrowser for help on using the repository browser.