source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs @ 11747

Last change on this file since 11747 was 11747, checked in by gkronber, 7 years ago

#2283: implemented test problems for MCTS

File size: 7.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
10namespace HeuristicLab.Algorithms.GrammaticalOptimization {
11  // SARSA (fig. 6.9 in Sutton & Barto)
12  public class TemporalDifferenceTreeSearchSampler {
13    private class TreeNode {
14      public string ident;
15      public int randomTries;
16      public double q;
17      public int tries;
18      public TreeNode[] children;
19      public bool done = false;
20
21      public TreeNode(string id) {
22        this.ident = id;
23      }
24
25      public override string ToString() {
26        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
27      }
28    }
29
30
31    public event Action<string, double> FoundNewBestSolution;
32    public event Action<string, double> SolutionEvaluated;
33
34    private readonly int maxLen;
35    private readonly IProblem problem;
36    private readonly Random random;
37    private readonly int randomTries;
38
39    private List<TreeNode> updateChain;
40    private TreeNode rootNode;
41
42    public int treeDepth;
43    public int treeSize;
44    private double bestQuality;
45
46
47    public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) {
48      this.maxLen = maxLen;
49      this.problem = problem;
50      this.random = random;
51      this.randomTries = randomTries;
52    }
53
54    public void Run(int maxIterations) {
55      InitPolicies(problem.Grammar);
56      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
57        var sentence = SampleSentence(problem.Grammar).ToString();
58        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
59        Debug.Assert(quality >= 0 && quality <= 1.0);
60        DistributeReward(quality);
61
62        RaiseSolutionEvaluated(sentence, quality);
63
64        if (quality > bestQuality) {
65          bestQuality = quality;
66          RaiseFoundNewBestSolution(sentence, quality);
67        }
68      }
69
70      // clean up
71      InitPolicies(problem.Grammar); GC.Collect();
72    }
73
74    public void PrintStats() {
75      var n = rootNode;
76      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);
77      while (n.children != null) {
78        Console.WriteLine("{0,-30}", n.ident);
79        double maxVForRow = n.children.Select(ch => ch.q).Max();
80        if (maxVForRow == 0) maxVForRow = 1.0;
81
82        for (int i = 0; i < n.children.Length; i++) {
83          var ch = n.children[i];
84          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
85          Console.Write("{0,5}", ch.ident);
86        }
87        Console.WriteLine();
88        for (int i = 0; i < n.children.Length; i++) {
89          var ch = n.children[i];
90          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
91          Console.Write("{0,5:F2}", ch.q * 10);
92        }
93        Console.WriteLine();
94        for (int i = 0; i < n.children.Length; i++) {
95          var ch = n.children[i];
96          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
97          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
98        }
99        Console.ForegroundColor = ConsoleColor.White;
100        Console.WriteLine();
101        //n.policy.PrintStats();
102        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).First();
103      }
104    }
105
106    private void InitPolicies(IGrammar grammar) {
107      this.updateChain = new List<TreeNode>();
108
109      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
110      treeDepth = 0;
111      treeSize = 0;
112    }
113
114    private Sequence SampleSentence(IGrammar grammar) {
115      updateChain.Clear();
116      var startPhrase = new Sequence(grammar.SentenceSymbol);
117      return CompleteSentence(grammar, startPhrase);
118    }
119
120    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
121      if (phrase.Length > maxLen) throw new ArgumentException();
122      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
123      TreeNode n = rootNode;
124      var curDepth = 0;
125      while (!phrase.IsTerminal) {
126        updateChain.Add(n);
127
128        if (n.randomTries < randomTries) {
129          n.randomTries++;
130          treeDepth = Math.Max(treeDepth, curDepth);
131          return g.CompleteSentenceRandomly(random, phrase, maxLen);
132        } else {
133          char nt = phrase.FirstNonTerminal;
134
135          int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
136          Debug.Assert(maxLenOfReplacement > 0);
137
138          var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
139
140          if (n.randomTries == randomTries && n.children == null) {
141            n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
142            treeSize += n.children.Length;
143          }
144          // => select using bandit policy
145          int selectedAltIdx = SelectEpsGreedy(random, n.children);
146          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
147
148          // replace nt with alt
149          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
150
151          curDepth++;
152
153          // prepare for next iteration
154          n = n.children[selectedAltIdx];
155        }
156      } // while
157
158      updateChain.Add(n);
159
160
161      // the last node is a leaf node (sentence is done), so we never need to visit this node again
162      n.done = true;
163
164      treeDepth = Math.Max(treeDepth, curDepth);
165      return phrase;
166    }
167
168
169    // eps-greedy
170    private int SelectEpsGreedy(Random random, TreeNode[] children) {
171      if (random.NextDouble() < 0.1) {
172
173        return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
174      } else {
175        var bestQ = double.NegativeInfinity;
176        var bestChildIdx = new List<int>();
177        for (int i = 0; i < children.Length; i++) {
178          if (children[i].done) continue;
179          // if (children[i].tries == 0) return i;
180          var q = children[i].q;
181          if (q > bestQ) {
182            bestQ = q;
183            bestChildIdx.Clear();
184            bestChildIdx.Add(i);
185          } else if (q == bestQ) {
186            bestChildIdx.Add(i);
187          }
188        }
189        Debug.Assert(bestChildIdx.Any());
190        return bestChildIdx.SelectRandom(random);
191      }
192    }
193
194    private void DistributeReward(double reward) {
195      updateChain.Reverse();
196      foreach (var node in updateChain) {
197        if (node.children != null && node.children.All(c => c.done)) {
198          node.done = true;
199        }
200      }
201      updateChain.Reverse();
202
203      //const double alpha = 0.1;
204      const double gamma = 1;
205      double alpha;
206      foreach (var p in updateChain.Zip(updateChain.Skip(1), Tuple.Create)) {
207        var parent = p.Item1;
208        var child = p.Item2;
209        parent.tries++;
210        alpha = 1.0 / parent.tries;
211        //alpha = 0.01;
212        parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q);
213      }
214      // reward is recieved only for the last action
215      var n = updateChain.Last();
216      n.tries++;
217      alpha = 1.0 / n.tries;
218      //alpha = 0.1;
219      n.q = n.q + alpha * reward;
220    }
221
222    private void RaiseSolutionEvaluated(string sentence, double quality) {
223      var handler = SolutionEvaluated;
224      if (handler != null) handler(sentence, quality);
225    }
226    private void RaiseFoundNewBestSolution(string sentence, double quality) {
227      var handler = FoundNewBestSolution;
228      if (handler != null) handler(sentence, quality);
229    }
230  }
231}
Note: See TracBrowser for help on using the repository browser.