Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs @ 11745

Last change on this file since 11745 was 11744, checked in by gkronber, 10 years ago

#2283 worked on TD, and models for MCTS

File size: 7.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using HeuristicLab.Algorithms.Bandits;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
10namespace HeuristicLab.Algorithms.GrammaticalOptimization {
/// <summary>
/// Tree-search sampler for grammatical optimization that learns action values with a
/// SARSA-style temporal-difference update (fig. 6.9 in Sutton &amp; Barto).
/// Each tree node corresponds to a sequence of chosen alternatives for the first non-terminal,
/// starting from the sentence symbol. A node first completes sentences randomly
/// <c>randomTries</c> times; after that it is expanded (one child per admissible alternative)
/// and children are chosen epsilon-greedily.
/// </summary>
/// <remarks>
/// NOTE(review): the <c>policy</c> passed to the constructor is stored but currently unused;
/// action selection is the hard-coded epsilon-greedy rule in <c>SelectAction</c>.
/// </remarks>
public class TemporalDifferenceTreeSearchSampler {
  // One node per partial derivation (path of chosen alternatives from the root).
  private class TreeNode {
    public string ident;        // text of the alternative leading to this node (root: the sentence symbol)
    public int randomTries;     // number of random completions started from this node so far
    public double q;            // learned action value
    public int tries;           // number of reward updates this node has received
    public TreeNode[] children; // one child per admissible alternative; null until expanded
    public bool done = false;   // leaf reached, or all children done => nothing left to explore below

    public TreeNode(string id) {
      this.ident = id;
    }

    public override string ToString() {
      return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
    }
  }


  public event Action<string, double> FoundNewBestSolution;
  public event Action<string, double> SolutionEvaluated;

  private readonly int maxLen;
  private readonly IProblem problem;
  private readonly Random random;
  private readonly int randomTries;
  private readonly IBanditPolicy policy; // NOTE(review): never read in this class

  // nodes visited while sampling the current sentence, in order from the root downwards
  private List<TreeNode> updateChain;
  private TreeNode rootNode;

  public int treeDepth; // deepest level reached while sampling
  public int treeSize;  // number of nodes created by expansion (root not counted)
  private double bestQuality; // best normalized quality seen so far (not reset by InitPolicies)


  /// <summary>Creates a sampler for <paramref name="problem"/>.</summary>
  /// <param name="problem">Problem providing the grammar and the evaluation function.</param>
  /// <param name="maxLen">Maximum sentence length.</param>
  /// <param name="random">Random number generator used for random completion and exploration.</param>
  /// <param name="randomTries">Number of random completions performed from a node before it is expanded.</param>
  /// <param name="policy">Stored but currently unused (see class remarks).</param>
  /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="random"/> is null.</exception>
  public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
    if (problem == null) throw new ArgumentNullException("problem");
    if (random == null) throw new ArgumentNullException("random");
    this.maxLen = maxLen;
    this.problem = problem;
    this.random = random;
    this.randomTries = randomTries;
    this.policy = policy;
  }

  /// <summary>
  /// Performs up to <paramref name="maxIterations"/> sample / evaluate / update iterations,
  /// stopping early once the root is marked done (every reachable sentence was enumerated).
  /// Raises <see cref="SolutionEvaluated"/> for every sentence and
  /// <see cref="FoundNewBestSolution"/> when the normalized quality improves.
  /// The search tree is discarded when the run finishes.
  /// </summary>
  public void Run(int maxIterations) {
    InitPolicies(problem.Grammar);
    for (int i = 0; !rootNode.done && i < maxIterations; i++) {
      var sentence = SampleSentence(problem.Grammar).ToString();
      // normalize relative to the best known quality for maxLen (expected to land in [0, 1])
      var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
      Debug.Assert(quality >= 0 && quality <= 1.0);
      DistributeReward(quality);

      RaiseSolutionEvaluated(sentence, quality);

      if (quality > bestQuality) {
        bestQuality = quality;
        RaiseFoundNewBestSolution(sentence, quality);
      }
    }

    // clean up: drop the (potentially large) tree and force a collection
    InitPolicies(problem.Grammar);
    GC.Collect();
  }

  /// <summary>
  /// Writes tree statistics to the console, followed by the children of every node on the
  /// greedy path (highest q among children that are not done).
  /// Because <see cref="Run"/> discards the tree when it finishes, this is mainly useful while
  /// a run is in progress (e.g. from an event handler). Does nothing before the first run.
  /// </summary>
  public void PrintStats() {
    var n = rootNode;
    if (n == null) return; // Run has never been called, there is no tree yet
    Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);
    while (n.children != null) {
      Console.WriteLine();
      Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
      Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.q * 10))));
      Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.tries.ToString()))));
      // follow the best child that is not done yet; if every child is done (or there are
      // no children) the path ends here instead of throwing from First()
      var next = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).FirstOrDefault();
      if (next == null) break;
      n = next;
    }
  }

  // Starts a fresh, empty search tree and resets the tree statistics.
  private void InitPolicies(IGrammar grammar) {
    this.updateChain = new List<TreeNode>();

    rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
    treeDepth = 0;
    treeSize = 0;
  }

  // Samples one complete sentence starting from the sentence symbol, recording the visited nodes.
  private Sequence SampleSentence(IGrammar grammar) {
    updateChain.Clear();
    var startPhrase = new Sequence(grammar.SentenceSymbol);
    return CompleteSentence(grammar, startPhrase);
  }

  // Walks down the tree, repeatedly replacing the first non-terminal of `phrase` with the
  // alternative chosen at the current node, until the phrase is terminal or the current node
  // still has random completions left (then the rest is completed randomly).
  // Every visited node is appended to updateChain for the subsequent reward update.
  private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
    if (phrase.Length > maxLen) throw new ArgumentException("phrase is longer than maxLen", "phrase");
    if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException("phrase cannot be completed within maxLen", "phrase");
    TreeNode n = rootNode;
    var curDepth = 0;
    while (!phrase.IsTerminal) {
      updateChain.Add(n);

      if (n.randomTries < randomTries) {
        n.randomTries++;
        treeDepth = Math.Max(treeDepth, curDepth);
        return g.CompleteSentenceRandomly(random, phrase, maxLen);
      } else {
        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);

        if (n.randomTries == randomTries && n.children == null) {
          n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
          treeSize += n.children.Length;
        }
        // children[k] was created from the k-th admissible alternative, so the same index selects both
        int selectedAltIdx = SelectAction(random, n.children);
        Sequence selectedAlt = alts.ElementAt(selectedAltIdx);

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        // prepare for next iteration
        n = n.children[selectedAltIdx];
      }
    } // while

    updateChain.Add(n);

    // the last node is a leaf node (sentence is done), so we never need to visit this node again
    n.done = true;

    treeDepth = Math.Max(treeDepth, curDepth);
    return phrase;
  }

  // Epsilon-greedy selection among children that are not done: with probability epsilon a
  // uniformly random non-done child; otherwise the first non-done child that was never tried,
  // or, if all were tried, the non-done child with the highest q (first one on ties).
  private int SelectAction(Random random, TreeNode[] children) {
    const double epsilon = 0.1;
    if (random.NextDouble() < epsilon) {
      return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
    } else {
      var bestQ = double.NegativeInfinity;
      var bestChildIdx = -1;
      for (int i = 0; i < children.Length; i++) {
        if (children[i].done) continue;
        if (children[i].tries == 0) return i;
        if (children[i].q > bestQ) {
          bestQ = children[i].q;
          bestChildIdx = i;
        }
      }
      Debug.Assert(bestChildIdx > -1);
      return bestChildIdx;
    }
  }

  // TD update along the sampled path, bottom-up (leaf first). The reward is received only for
  // the last action; every other node moves towards the (discounted) updated value of its
  // successor. Also propagates `done` to nodes whose children are all done.
  private void DistributeReward(double reward) {
    const double alpha = 0.1;
    const double gamma = 1;
    var leafIdx = updateChain.Count - 1;
    var nextQ = 0.0;
    for (int i = leafIdx; i >= 0; i--) {
      var node = updateChain[i];
      node.tries++;
      if (node.children != null && node.children.All(c => c.done)) {
        node.done = true;
      }
      // reward is received only for the last action
      var r = (i == leafIdx) ? reward : 0.0;
      node.q = node.q + alpha * (r + gamma * nextQ - node.q);
      nextQ = node.q;
    }
  }

  private void RaiseSolutionEvaluated(string sentence, double quality) {
    var handler = SolutionEvaluated;
    if (handler != null) handler(sentence, quality);
  }
  private void RaiseFoundNewBestSolution(string sentence, double quality) {
    var handler = FoundNewBestSolution;
    if (handler != null) handler(sentence, quality);
  }
}
206}
Note: See TracBrowser for help on using the repository browser.