Ignore:
Timestamp:
01/07/15 09:21:46 (7 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

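Legend:

Unmodified (line carries both the old r11730 and the new r11732 line number)
Added (line carries only the new r11732 line number, in the second column)
Removed (line carries only the old r11730 line number, in the first column)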
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11730 r11732  
    13 13      public int randomTries;
    14 14      public int policyTries;
    15       public IPolicy policy;
     15      public IPolicyActionInfo actionInfo;
    16 16      public TreeNode[] children;
    17 17      public bool done = false;
     
    22 22
    23 23      public override string ToString() {
    24         return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, randomTries + policyTries, done, policy);
     24        return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, randomTries + policyTries, done, actionInfo);
    25 25      }
    26 26    }
     
    34 34    private readonly Random random;
    35 35    private readonly int randomTries;
    36     private readonly Func<Random, int, IPolicy> policyFactory;
     36    private readonly IPolicy policy;
    37 37
    38     private List<Tuple<TreeNode, int>> updateChain;
     38    private List<TreeNode> updateChain;
    39 39    private TreeNode rootNode;
    40 40
     
    42 42    public int treeSize;
    43 43
    44     public MctsSampler(IProblem problem, int maxLen, Random random) :
    45       this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
     44    // public MctsSampler(IProblem problem, int maxLen, Random random) :
     45    //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
     46    //
     47    // }
    46 48
    47     }
    48 
    49     public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<Random, int, IPolicy> policyFactory) {
     49    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IPolicy policy) {
    50 50      this.maxLen = maxLen;
    51 51      this.problem = problem;
    52 52      this.random = random;
    53 53      this.randomTries = randomTries;
    54       this.policyFactory = policyFactory;
     54      this.policy = policy;
    55 55    }
    56 56
     
    60 60      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
    61 61        var sentence = SampleSentence(problem.Grammar).ToString();
    62         var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
     62        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
    63 63        Debug.Assert(quality >= 0 && quality <= 1.0);
    64 64        DistributeReward(quality);
     
    79 79      var n = rootNode;
    80 80      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
    81       while (n.policy != null) {
     81      while (n.children != null) {
    82 82        Console.WriteLine();
    83 83        Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
     
    90 90
    91 91    private void InitPolicies(IGrammar grammar) {
    92       this.updateChain = new List<Tuple<TreeNode, int>>();
     92      this.updateChain = new List<TreeNode>();
    93 93
    94 94      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
     95      rootNode.actionInfo = policy.CreateActionInfo();
    95 96      treeDepth = 0;
    96 97      treeSize = 0;
     
    108 109      TreeNode n = rootNode;
    109 110      bool done = phrase.IsTerminal;
    110       int selectedAltIdx = -1;
    111 111      var curDepth = 0;
    112 112      while (!done) {
    113         char nt = phrase.FirstNonTerminal;
    114 
    115         int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
    116         Debug.Assert(maxLenOfReplacement > 0);
    117 
    118         var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
     113        updateChain.Add(n);
    119 114
    120 115        if (n.randomTries < randomTries) {
    121 116          n.randomTries++;
     117          treeDepth = Math.Max(treeDepth, curDepth);
     118          return g.CompleteSentenceRandomly(random, phrase, maxLen);
     119        } else {
     120          char nt = phrase.FirstNonTerminal;
    122 121
    123           treeDepth = Math.Max(treeDepth, curDepth);
     122          int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
     123          Debug.Assert(maxLenOfReplacement > 0);
    124 124
    125           return g.CompleteSentenceRandomly(random, phrase, maxLen);
    126         } else if (n.randomTries == randomTries && n.policy == null) {
    127           n.policy = policyFactory(random, alts.Count());
    128           //n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
    129           n.children = alts.Select(alt => new TreeNode(string.Empty)).ToArray(); // create a new node for each alternative
     125          var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement);
    130 126
    131           treeSize += n.children.Length;
    132         }
    133         n.policyTries++;
    134         // => select using bandit policy
    135         selectedAltIdx = n.policy.SelectAction();
    136         Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
     127          if (n.randomTries == randomTries && n.children == null) {
     128            n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
     129            //n.children = alts.Select(alt => new TreeNode(string.Empty)).ToArray(); // create a new node for each alternative
     130            foreach (var ch in n.children) ch.actionInfo = policy.CreateActionInfo();
     131            treeSize += n.children.Length;
     132          }
     133          n.policyTries++;
     134          // => select using bandit policy
     135          int selectedAltIdx = policy.SelectAction(random, n.children.Select(c => c.actionInfo));
     136          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    137 137
    138         // replace nt with alt
    139         phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
     138          // replace nt with alt
     139          phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
    140 140
    141         updateChain.Add(Tuple.Create(n, selectedAltIdx));
     141          curDepth++;
    142 142
    143         curDepth++;
     143          done = phrase.IsTerminal;
    144 144
    145         done = phrase.IsTerminal;
    146         if (!done) {
    147 145          // prepare for next iteration
    148 146          n = n.children[selectedAltIdx];
    149           Debug.Assert(!n.done);
    150 147        }
    151 148      } // while
    152 149
     150      updateChain.Add(n);
     151
     152
    153 153      // the last node is a leaf node (sentence is done), so we never need to visit this node again
    154       n.children[selectedAltIdx].done = true;
     154      n.done = true;
     155      n.actionInfo.Disable();
    155 156
    156 157      treeDepth = Math.Max(treeDepth, curDepth);
     
    163 164
    164 165      foreach (var e in updateChain) {
    165         var node = e.Item1;
    166         var policy = node.policy;
    167         var action = e.Item2;
    168         //policy.UpdateReward(action, reward / updateChain.Count);
    169         policy.UpdateReward(action, reward);
    170 
    171         if (node.children[action].done) node.policy.DisableAction(action);
    172         if (node.children.All(c => c.done)) node.done = true;
     166        var node = e;
     167        if (node.children != null && node.children.All(c => c.done)) {
     168          node.done = true;
     169          node.actionInfo.Disable();
     170        }
     171        if (!node.done) {
     172          node.actionInfo.UpdateReward(reward);
     173          //policy.UpdateReward(action, reward / updateChain.Count);
     174        }
    173 175      }
    174 176    }
Note: See TracChangeset for help on using the changeset viewer.