Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs @ 11747

Last change on this file since 11747 was 11747, checked in by gkronber, 9 years ago

#2283: implemented test problems for MCTS

File size: 7.4 KB
RevLine 
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.GrammaticalOptimization {
  /// <summary>
  /// Samples sentences of a problem's grammar while growing a search tree.
  /// Each tree node stands for one grammar alternative that was chosen for the first
  /// non-terminal of the phrase; a bandit policy picks among the children of a node and
  /// the (normalized) quality of every sampled sentence is propagated back up the tree.
  /// </summary>
  public class MctsSampler {
    /// <summary>
    /// One node of the search tree. Fields are public because the node is a private
    /// implementation detail of <see cref="MctsSampler"/>.
    /// </summary>
    private class TreeNode {
      // textual form of the alternative (or of the sentence symbol for the root)
      public string ident;
      // number of random completions already started from this node
      public int randomTries;
      // per-node statistics object created by the bandit policy
      public IBanditPolicyActionInfo actionInfo;
      // null for the root node
      public TreeNode parent;
      // null until the node is expanded
      public TreeNode[] children;
      // true once the node must not be visited again
      public bool done = false;

      public TreeNode(string id, TreeNode parent) {
        this.ident = id;
        this.parent = parent;
      }

      public override string ToString() {
        // actionInfo is assigned after construction, so it can still be null here (e.g. in the debugger)
        if (actionInfo == null) {
          return string.Format("Node({0} done: {1})", ident, done);
        }
        return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, actionInfo.Tries, done, actionInfo);
      }
    }

    /// <summary>Raised by <see cref="Run"/> whenever a sentence has a strictly better quality than all sentences before.</summary>
    public event Action<string, double> FoundNewBestSolution;

    /// <summary>Raised by <see cref="Run"/> after each sentence has been evaluated and its reward distributed.</summary>
    public event Action<string, double> SolutionEvaluated;

    private readonly int maxLen;
    private readonly IProblem problem;
    private readonly Random random;
    private readonly int randomTries;
    private readonly IBanditPolicy policy;

    private TreeNode lastNode; // the bottom node in one episode
    private TreeNode rootNode;

    /// <summary>Largest depth reached in the current tree. Reset to 0 at the start and at the end of <see cref="Run"/>.</summary>
    public int treeDepth;

    /// <summary>Number of child nodes created in the current tree. Reset to 0 at the start and at the end of <see cref="Run"/>.</summary>
    public int treeSize;

    private double bestQuality;

    /// <summary>Creates a new sampler.</summary>
    /// <param name="problem">Supplies the grammar, the evaluation function and the best known quality.</param>
    /// <param name="maxLen">Maximum length of a sampled sentence.</param>
    /// <param name="random">Random number generator handed to the grammar and the policy.</param>
    /// <param name="randomTries">Number of random completions started from a node before the node is expanded.</param>
    /// <param name="policy">Bandit policy that creates the per-node statistics and selects among children.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="problem"/>, <paramref name="random"/> or <paramref name="policy"/> is null.
    /// </exception>
    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
      if (problem == null) throw new ArgumentNullException("problem");
      if (random == null) throw new ArgumentNullException("random");
      if (policy == null) throw new ArgumentNullException("policy");

      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
      this.policy = policy;
    }

    /// <summary>
    /// Samples and evaluates sentences until the root node is done or
    /// <paramref name="maxIterations"/> sentences have been sampled.
    /// The search tree (including <see cref="treeDepth"/> and <see cref="treeSize"/>) is
    /// discarded when the method returns.
    /// </summary>
    /// <param name="maxIterations">Upper bound on the number of sampled sentences.</param>
    public void Run(int maxIterations) {
      bestQuality = double.MinValue;
      InitPolicies(problem.Grammar);

      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
        var sentence = SampleSentence(problem.Grammar).ToString();
        // normalize by the best known quality so that the reward lies in [0, 1]
        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
        Debug.Assert(quality >= 0 && quality <= 1.0);
        DistributeReward(quality);

        RaiseSolutionEvaluated(sentence, quality);

        if (quality > bestQuality) {
          bestQuality = quality;
          RaiseFoundNewBestSolution(sentence, quality);
        }
      }

      // clean up: replacing the root drops the reference to the whole tree,
      // the garbage collector reclaims it on its own schedule
      InitPolicies(problem.Grammar);
    }

    /// <summary>
    /// Writes the tree statistics to the console: one header line, then for each level along
    /// the path of the best not-yet-done child a row with the child identifiers, a row with
    /// their values (times 10) and a row with their tries ("X" for done children).
    /// Does nothing when no tree exists yet. The console foreground color is restored afterwards.
    /// </summary>
    public void PrintStats() {
      var n = rootNode;
      if (n == null) return; // Run has not been called yet

      var previousColor = Console.ForegroundColor;
      try {
        Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality);

        while (n != null && n.children != null && n.children.Length > 0) {
          Console.WriteLine("{0,-30}", n.ident);

          // values are scaled by the row maximum to pick a color; avoid dividing by zero
          double maxVForRow = n.children.Select(ch => ch.actionInfo.Value).Max();
          if (maxVForRow == 0) maxVForRow = 1.0;

          PrintRow(n.children, maxVForRow, "{0,5}", ch => ch.ident);
          PrintRow(n.children, maxVForRow, "{0,5:F2}", ch => ch.actionInfo.Value * 10);
          PrintRow(n.children, maxVForRow, "{0,5}", ch => ch.done ? "X" : ch.actionInfo.Tries.ToString());
          Console.ForegroundColor = previousColor;

          // descend into the most valuable child that is still open;
          // null (all children done) ends the loop instead of throwing
          n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).FirstOrDefault();
        }
        Console.WriteLine("-----------------------");
      } finally {
        Console.ForegroundColor = previousColor;
      }
    }

    // Writes one colored cell per child using cellFormat and terminates the line.
    private void PrintRow(TreeNode[] children, double maxVForRow, string cellFormat, Func<TreeNode, object> cellValue) {
      foreach (var ch in children) {
        SetColorForChild(ch, maxVForRow);
        Console.Write(cellFormat, cellValue(ch));
      }
      Console.WriteLine();
    }

    // Sets the console foreground color from the child's value relative to the row maximum.
    private void SetColorForChild(TreeNode ch, double maxVForRow) {
      Console.ForegroundColor = ConsoleEx.ColorForValue(ch.actionInfo.Value / maxVForRow);
    }

    // Replaces the tree by a fresh root node for the sentence symbol and resets the tree statistics.
    private void InitPolicies(IGrammar grammar) {
      rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), null);
      rootNode.actionInfo = policy.CreateActionInfo();
      treeDepth = 0;
      treeSize = 0;
    }

    // Samples one sentence starting from the grammar's sentence symbol.
    private Sequence SampleSentence(IGrammar grammar) {
      lastNode = null;
      var startPhrase = new Sequence(grammar.SentenceSymbol);
      return CompleteSentence(grammar, startPhrase);
    }

    /// <summary>
    /// Walks down the tree from the root, replacing the first non-terminal of
    /// <paramref name="phrase"/> by the alternative selected by the policy at each node.
    /// A node that has not yet used up its random tries completes the phrase randomly instead.
    /// Sets <c>lastNode</c> to the node where the walk ended.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The phrase is longer than the maximum length or cannot be completed within it.
    /// </exception>
    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
      if (phrase.Length > maxLen) {
        throw new ArgumentException("The phrase is longer than the maximum sentence length.", "phrase");
      }
      if (g.MinPhraseLength(phrase) > maxLen) {
        throw new ArgumentException("The phrase cannot be completed within the maximum sentence length.", "phrase");
      }

      TreeNode n = rootNode;
      var curDepth = 0;
      while (!phrase.IsTerminal) {
        if (n.randomTries < randomTries) {
          // this node is not expanded yet: complete the phrase randomly and stop here
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
          lastNode = n;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        }

        char nt = phrase.FirstNonTerminal;

        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
        Debug.Assert(maxLenOfReplacement > 0);

        // materialize once: the alternatives are needed for expansion and for indexed access below
        var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

        if (n.children == null) {
          // random tries are used up => expand: create a new node for each alternative
          var newChildren = new TreeNode[alts.Length];
          for (int i = 0; i < alts.Length; i++) {
            newChildren[i] = new TreeNode(alts[i].ToString(), n);
            newChildren[i].actionInfo = policy.CreateActionInfo();
          }
          n.children = newChildren;
          treeSize += newChildren.Length;
        }

        // => select using bandit policy
        int selectedAltIdx = policy.SelectAction(random, n.children.Select(c => c.actionInfo));
        Sequence selectedAlt = alts[selectedAltIdx];

        // replace nt with alt
        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

        curDepth++;

        // prepare for next iteration
        n = n.children[selectedAltIdx];
      }

      lastNode = n;

      // the last node is a leaf node (sentence is done), so we never need to visit this node again
      n.done = true;

      treeDepth = Math.Max(treeDepth, curDepth);
      return phrase;
    }

    /// <summary>
    /// Propagates <paramref name="reward"/> bottom-up from <c>lastNode</c> to the root.
    /// Done nodes are disabled in the policy; a node whose children are all done becomes done
    /// itself and is disabled with the best value among its children; all other nodes on the
    /// path receive the reward as a regular update.
    /// </summary>
    private void DistributeReward(double reward) {
      // iterate in reverse order (bottom up)
      for (var node = lastNode; node != null; node = node.parent) {
        if (node.done) node.actionInfo.Disable(reward);

        if (node.children != null && node.children.All(c => c.done)) {
          node.done = true;
          var bestActionValue = node.children.Select(c => c.actionInfo.Value).Max();
          node.actionInfo.Disable(bestActionValue);
        }

        if (!node.done) {
          node.actionInfo.UpdateReward(reward);
        }
      }
    }

    private void RaiseSolutionEvaluated(string sentence, double quality) {
      // copy to a local so that a concurrent unsubscribe cannot null the delegate between check and call
      var handler = SolutionEvaluated;
      if (handler != null) handler(sentence, quality);
    }

    private void RaiseFoundNewBestSolution(string sentence, double quality) {
      // copy to a local so that a concurrent unsubscribe cannot null the delegate between check and call
      var handler = FoundNewBestSolution;
      if (handler != null) handler(sentence, quality);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.