source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs @ 11742

Last change on this file since 11742 was 11742, checked in by gkronber, 7 years ago

#2283 refactoring

File size: 6.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Tree-search sampler for sentences of a grammar. A search tree over phrases is grown
  /// incrementally: every node first receives a fixed number of random completions, afterwards
  /// the node is expanded and the choice between the alternatives of the first non-terminal
  /// is delegated to an <see cref="IGrammarPolicy"/>, which is given the phrase of the node as state.
  /// </summary>
  public class MctsContextualSampler {
    // node of the search tree
    private class TreeNode {
      // number of random completions that have been started from this node
      public int randomTries;
      // number of times the node was passed (or reached as the final node) via policy selection
      public int policyTries;
      // null until the node has been expanded
      public TreeNode[] children;
      public readonly ReadonlySequence phrase;
      public readonly ReadonlySequence alt;

      // phrase represents the phrase of the state and alt represents how the phrase has been reached from the parent state
      public TreeNode(ReadonlySequence phrase, ReadonlySequence alt) {
        this.phrase = phrase;
        this.alt = alt;
      }

      public override string ToString() {
        return string.Format("Node({0} tries: {1})", phrase, randomTries + policyTries);
      }
    }


    /// <summary>Raised with the sentence and its normalized quality whenever a sentence is better than all previous ones of the current run.</summary>
    public event Action<string, double> FoundNewBestSolution;
    /// <summary>Raised with the sentence and its normalized quality after every evaluation.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly IGrammarPolicy policy;

    // (state, action, new state) triples of the policy decisions taken for the sentence sampled last
    private List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>> updateChain;
    private TreeNode rootNode;

    /// <summary>Largest number of policy decisions taken for a single sample since the tree was initialized.</summary>
    public int treeDepth;
    /// <summary>Number of child nodes created since the tree was initialized (the root node is not counted).</summary>
    public int treeSize;

    // public MctsSampler(IProblem problem, int maxLen, Random random) :
    //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    //
    // }

    /// <summary>Creates a new sampler.</summary>
    /// <param name="problem">Supplies the grammar, the evaluation function and the best known quality.</param>
    /// <param name="maxLen">Maximum length of sampled sentences.</param>
    /// <param name="random">Random number generator handed to the grammar and to the policy.</param>
    /// <param name="randomTries">Number of random completions a node receives before it is expanded.</param>
    /// <param name="policy">Policy that selects among the alternatives of an expanded node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="policy"/> is null.</exception>
    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy policy) {
      // both are dereferenced unconditionally at the start of Run(), so fail early with a clear message
      if (problem == null) throw new ArgumentNullException("problem");
      if (policy == null) throw new ArgumentNullException("policy");

      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.policy = policy;
    }

    /// <summary>
    /// Samples and evaluates sentences until <paramref name="maxIterations"/> is reached or the policy
    /// reports that the root phrase is done. The quality handed to the policy and to the events is the
    /// evaluation result divided by the best known quality for maxLen.
    /// The search tree is discarded again when the method returns.
    /// </summary>
    /// <param name="maxIterations">Maximum number of sentences to sample.</param>
    public void Run(int maxIterations) {
      double bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);
      for (int i = 0; !policy.Done(rootNode.phrase) && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: drop the tree and force a collection (kept from the original implementation)
      InitPolicies(problem.Grammar);
      GC.Collect();
    }

    /// <summary>
    /// Writes tree statistics to the console, following in each level the child with the most
    /// policy tries, and then waits for a line on standard input.
    /// Does nothing when no tree has been initialized yet (i.e. before the first call of Run).
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // rootNode is only set by InitPolicies, which is called from Run

      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, n.policyTries + n.randomTries);
      // an expanded node without children ends the walk as well; First() would throw on an empty array
      while (n.children != null && n.children.Length > 0) {
        Console.WriteLine();
        Console.WriteLine("{0,5}->{1,-50}", n.alt, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.alt))));
        Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
        //n.policy.PrintStats();
        n = n.children.OrderByDescending(c => c.policyTries).First();
      }
      Console.ReadLine();
    }

    // resets the update chain, the search tree (a single root node for the sentence symbol) and the statistics
    private void InitPolicies(IGrammar grammar) {
      this.updateChain = new List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>>();

      rootNode = new TreeNode(new ReadonlySequence(grammar.SentenceSymbol), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

    // samples one sentence starting from the root phrase; the decisions taken are recorded in updateChain
    private Sequence SampleSentence(IGrammar grammar) {
      updateChain.Clear();
      var startPhrase = new Sequence(rootNode.phrase);
      return CompleteSentence(grammar, startPhrase);
    }

    // Descends the tree from the root and rewrites phrase in place.
    // At the first node that has not yet used up its random tries the phrase is completed randomly;
    // otherwise the policy selects an alternative for the first non-terminal until the phrase is terminal.
    // Throws ArgumentException when phrase is longer than maxLen or cannot be completed within maxLen.
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen)
        throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      if (g.MinPhraseLength(phrase) > maxLen)
        throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");
      TreeNode parent = null;
      TreeNode n = rootNode;
      bool done = false;
      var curDepth = 0;
      while (!done) {
        if (parent != null)
          updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));

        if (n.randomTries < randomTries) {
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        // random tries are used up => expand the node on the first policy visit
        if (n.children == null) ExpandNode(g, n, phrase, maxLenOfReplacement);

        n.policyTries++;
        // => select using bandit policy
        ReadonlySequence selectedAlt = policy.SelectAction(random, n.phrase, n.children.Select(c => c.alt));

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        done = phrase.IsTerminal;

        // prepare for next iteration
        parent = n;
        n = n.children.Single(ch => ch.alt == selectedAlt); // TODO: perf
      } // while

      // the final node is not visited by the loop anymore, so account for it here
      n.policyTries++;
      updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));


      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    // creates one child of n for every alternative of the first non-terminal of phrase
    // whose minimal phrase length does not exceed maxLenOfReplacement
    private void ExpandNode(IGrammar g, TreeNode n, Sequence phrase, int maxLenOfReplacement) {
      char nt = phrase.FirstNonTerminal;
      var ntIdx = phrase.FirstNonTerminalIndex;

      // materialize once; the filter would otherwise be evaluated for counting and again for creating the children
      var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

      var children = new TreeNode[alts.Length];
      for (int cIdx = 0; cIdx < alts.Length; cIdx++) {
        var newPhrase = new Sequence(phrase);
        newPhrase.ReplaceAt(ntIdx, 1, alts[cIdx]);
        children[cIdx] = new TreeNode(new ReadonlySequence(newPhrase), new ReadonlySequence(alts[cIdx]));
      }
      n.children = children;
      treeSize += children.Length;
    }

    // reports the reward to the policy for every recorded decision
    private void DistributeReward(double reward) {
      // iterate in reverse order (bottom up) without modifying the chain
      for (int i = updateChain.Count - 1; i >= 0; i--) {
        var e = updateChain[i];
        var state = e.Item1;
        var action = e.Item2;
        var newState = e.Item3;
        policy.UpdateReward(state, action, reward, newState);
        //policy.UpdateReward(action, reward / updateChain.Count);
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so that the null check and the call see the same delegate
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }
    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      // copy to a local so that the null check and the call see the same delegate
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.