Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/10/15 14:06:29 (9 years ago)
Author:
gkronber
Message:

#2283: worked on contextual MCTS

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs

    r11742 r11745  
    55using System.Text;
    66using HeuristicLab.Algorithms.Bandits;
     7using HeuristicLab.Common;
    78using HeuristicLab.Problems.GrammaticalOptimization;
    89
     
    1011  public class MctsContextualSampler {
    1112    private class TreeNode {
     13      public string ident;
     14      public ReadonlySequence alt;
    1215      public int randomTries;
    13       public int policyTries;
     16      public int tries;
    1417      public TreeNode[] children;
    15       public readonly ReadonlySequence phrase;
    16       public readonly ReadonlySequence alt;
    17 
    18       // phrase represents the phrase of the state and alt represents how the phrase has been reached from the parent state
    19       public TreeNode(ReadonlySequence phrase, ReadonlySequence alt) {
    20         this.phrase = phrase;
     18      public bool done = false;
     19
     20      public TreeNode(string id, ReadonlySequence alt) {
     21        this.ident = id;
    2122        this.alt = alt;
    2223      }
    2324
    2425      public override string ToString() {
    25         return string.Format("Node({0} tries: {1})", phrase, randomTries + policyTries);
     26        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
    2627      }
    2728    }
     
    3536    private readonly Random random;
    3637    private readonly int randomTries;
    37     private readonly IGrammarPolicy policy;
    38 
    39     private List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>> updateChain;
     38
     39    private List<Tuple<TreeNode, TreeNode>> updateChain;
    4040    private TreeNode rootNode;
    4141
    4242    public int treeDepth;
    4343    public int treeSize;
    44 
    45     // public MctsSampler(IProblem problem, int maxLen, Random random) :
    46     //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    47     //
    48     // }
    49 
    50     public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy policy) {
     44    private double bestQuality;
     45
     46    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    5147      this.maxLen = maxLen;
    5248      this.problem = problem;
    5349      this.random = random;
    5450      this.randomTries = randomTries;
    55       this.policy = policy;
     51      this.v = new Dictionary<string, double>(1000000);
     52      this.tries = new Dictionary<string, int>(1000000);
    5653    }
    5754
    5855    public void Run(int maxIterations) {
    59       double bestQuality = double.MinValue;
     56      bestQuality = double.MinValue;
    6057      InitPolicies(problem.Grammar);
    61       for (int i = 0; !policy.Done(rootNode.phrase) && i < maxIterations; i++) {
     58      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
    6259        var sentence = SampleSentence(problem.Grammar).ToString();
    6360        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
     
    7976    public void PrintStats() {
    8077      var n = rootNode;
    81       Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
     78      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
    8279      while (n.children != null) {
     80        Console.WriteLine("{0}", n.ident);
     81        double maxVForRow = n.children.Select(ch => V(ch)).Max();
     82        if (maxVForRow == 0) maxVForRow = 1.0;
     83
     84        for (int i = 0; i < n.children.Length; i++) {
     85          var ch = n.children[i];
     86          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     87          Console.Write("{0,5}", ch.alt);
     88        }
    8389        Console.WriteLine();
    84         Console.WriteLine("{0,5}->{1,-50}", n.alt, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.alt))));
    85         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
     90        for (int i = 0; i < n.children.Length; i++) {
     91          var ch = n.children[i];
     92          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     93          Console.Write("{0,5:F2}", V(ch) * 10);
     94        }
     95        Console.WriteLine();
     96        for (int i = 0; i < n.children.Length; i++) {
     97          var ch = n.children[i];
     98          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     99          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
     100        }
     101        Console.ForegroundColor = ConsoleColor.White;
     102        Console.WriteLine();
    86103        //n.policy.PrintStats();
    87         n = n.children.OrderByDescending(c => c.policyTries).First();
    88       }
    89       Console.ReadLine();
    90     }
     104        n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();
     105      }
     106    }
     107
    91108
    92109    private void InitPolicies(IGrammar grammar) {
    93       this.updateChain = new List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>>();
    94 
    95       rootNode = new TreeNode(new ReadonlySequence(grammar.SentenceSymbol), new ReadonlySequence("$"));
     110      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
     111      this.v.Clear();
     112      this.tries.Clear();
     113
     114      rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    96115      treeDepth = 0;
    97116      treeSize = 0;
     
    100119    private Sequence SampleSentence(IGrammar grammar) {
    101120      updateChain.Clear();
    102       var startPhrase = new Sequence(rootNode.phrase);
     121      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
     122      var startPhrase = new Sequence(grammar.SentenceSymbol);
    103123      return CompleteSentence(grammar, startPhrase);
    104124    }
     
    109129      TreeNode parent = null;
    110130      TreeNode n = rootNode;
    111       bool done = false;
    112131      var curDepth = 0;
    113       while (!done) {
    114         if (parent != null)
    115           updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));
     132      while (!phrase.IsTerminal) {
     133        updateChain.Add(Tuple.Create(n, parent));
    116134
    117135        if (n.randomTries < randomTries) {
     
    128146
    129147          if (n.randomTries == randomTries && n.children == null) {
     148            // create a new node for each alternative
    130149            n.children = new TreeNode[alts.Count()];
    131             int cIdx = 0;
     150            var i = 0;
    132151            foreach (var alt in alts) {
    133152              var newPhrase = new Sequence(phrase);
    134               newPhrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, alt);
    135               n.children[cIdx++] = new TreeNode(new ReadonlySequence(newPhrase), new ReadonlySequence(alt));
     153              newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
     154              if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
     155              n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
    136156            }
    137157            treeSize += n.children.Length;
    138158          }
    139 
    140           n.policyTries++;
    141           // => select using bandit policy
    142           ReadonlySequence selectedAlt = policy.SelectAction(random, n.phrase, n.children.Select(c => c.alt));
     159          // => select using eps-greedy
     160          int selectedAltIdx = SelectEpsGreedy(random, n.children);
     161
     162          //int selectedAltIdx = SelectActionUCB1(random, n.children);
     163          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    143164
    144165          // replace nt with alt
     
    147168          curDepth++;
    148169
    149           done = phrase.IsTerminal;
    150 
    151170          // prepare for next iteration
    152171          parent = n;
    153           n = n.children.Single(ch => ch.alt == selectedAlt); // TODO: perf
     172          n = n.children[selectedAltIdx];
    154173        }
    155174      } // while
    156175
    157       n.policyTries++;
    158       updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));
     176      updateChain.Add(Tuple.Create(n, parent));
     177
     178      // the last node is a leaf node (sentence is done), so we never need to visit this node again
     179      n.done = true;
    159180
    160181
     
    168189
    169190      foreach (var e in updateChain) {
    170         var state = e.Item1;
    171         var action = e.Item2;
    172         var newState = e.Item3;
    173         policy.UpdateReward(state, action, reward, newState);
    174         //policy.UpdateReward(action, reward / updateChain.Count);
    175       }
    176     }
     191        var node = e.Item1;
     192        var parent = e.Item2;
     193        node.tries++;
     194        if (node.children != null && node.children.All(c => c.done)) {
     195          node.done = true;
     196        }
     197        UpdateV(node, reward);
     198
     199        // the reward for the parent is either the just received reward or the value of the best action so far
     200        double value = 0.0;
     201        if (parent != null) {
     202          var doneChilds = parent.children.Where(ch => ch.done);
     203          if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
     204        }
     205        //if (value > reward) reward = value;
     206      }
     207    }
     208
     209    private Dictionary<string, double> v;
     210    private Dictionary<string, int> tries;
     211
     212    private void UpdateV(TreeNode n, double reward) {
     213      var canonicalStr = problem.CanonicalRepresentation(n.ident);
     214      //var canonicalStr = n.ident;
     215      double stateV;
     216
     217      if (!v.TryGetValue(canonicalStr, out  stateV)) {
     218        v.Add(canonicalStr, reward);
     219        tries.Add(canonicalStr, 1);
     220      } else {
     221        v[canonicalStr] = stateV + 0.005 * (reward - stateV);
     222        //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
     223        tries[canonicalStr]++;
     224      }
     225    }
     226
     227    private double V(TreeNode n) {
     228      var canonicalStr = problem.CanonicalRepresentation(n.ident);
     229      //var canonicalStr = n.ident;
     230      double stateV;
     231
     232      if (!v.TryGetValue(canonicalStr, out  stateV)) {
     233        return 0.0;
     234      } else {
     235        return stateV;
     236      }
     237    }
     238
     239    private int SelectEpsGreedy(Random random, TreeNode[] children) {
     240      if (random.NextDouble() < 0.2) {
     241
     242        return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
     243      } else {
     244        var bestQ = double.NegativeInfinity;
     245        var bestChildIdx = new List<int>();
     246        for (int i = 0; i < children.Length; i++) {
     247          if (children[i].done) continue;
     248          // if (children[i].tries == 0) return i;
     249          var q = V(children[i]);
     250          if (q > bestQ) {
     251            bestQ = q;
     252            bestChildIdx.Clear();
     253            bestChildIdx.Add(i);
     254          } else if (q == bestQ) {
     255            bestChildIdx.Add(i);
     256          }
     257        }
     258        Debug.Assert(bestChildIdx.Any());
     259        return bestChildIdx.SelectRandom(random);
     260      }
     261    }
     262    private int SelectActionUCB1(Random random, TreeNode[] children) {
     263      int bestAction = -1;
     264      double bestQ = double.NegativeInfinity;
     265      int totalTries = children.Sum(ch => ch.tries);
     266
     267      for (int a = 0; a < children.Length; a++) {
     268        var ch = children[a];
     269        if (ch.done) continue;
     270        if (ch.tries == 0) return a;
     271        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
     272        if (q > bestQ) {
     273          bestQ = q;
     274          bestAction = a;
     275        }
     276      }
     277      Debug.Assert(bestAction > -1);
     278      return bestAction;
     279    }
     280
     281
    177282
    178283    private void RaiseSolutionEvaluated(string sentence, double quality) {
     
    184289      if (handler != null) handler(sentence, quality);
    185290    }
     291
     292
    186293  }
    187294}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11744 r11745  
    4141    public int treeSize;
    4242    private double bestQuality;
    43 
    44     // public MctsSampler(IProblem problem, int maxLen, Random random) :
    45     //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    46     //
    47     // }
    4843
    4944    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
Note: See TracChangeset for help on using the changeset viewer.