Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/12/15 21:23:01 (10 years ago)
Author:
gkronber
Message:

#2283: implemented test problems for MCTS

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs

    r11744 r11747  
    3636    private readonly Random random;
    3737    private readonly int randomTries;
    38     private readonly IBanditPolicy policy;
    3938
    4039    private List<TreeNode> updateChain;
     
    4645
    4746
    48     public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
     47    public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    4948      this.maxLen = maxLen;
    5049      this.problem = problem;
    5150      this.random = random;
    5251      this.randomTries = randomTries;
    53       this.policy = policy;
    5452    }
    5553
     
    7876      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);
    7977      while (n.children != null) {
     78        Console.WriteLine("{0,-30}", n.ident);
     79        double maxVForRow = n.children.Select(ch => ch.q).Max();
     80        if (maxVForRow == 0) maxVForRow = 1.0;
     81
     82        for (int i = 0; i < n.children.Length; i++) {
     83          var ch = n.children[i];
     84          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     85          Console.Write("{0,5}", ch.ident);
     86        }
    8087        Console.WriteLine();
    81         Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
    82         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.q * 10))));
    83         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.tries.ToString()))));
     88        for (int i = 0; i < n.children.Length; i++) {
     89          var ch = n.children[i];
     90          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     91          Console.Write("{0,5:F2}", ch.q * 10);
     92        }
     93        Console.WriteLine();
     94        for (int i = 0; i < n.children.Length; i++) {
     95          var ch = n.children[i];
     96          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     97          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
     98        }
     99        Console.ForegroundColor = ConsoleColor.White;
     100        Console.WriteLine();
    84101        //n.policy.PrintStats();
    85102        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).First();
    86103      }
    87       //Console.ReadLine();
    88104    }
    89105
     
    127143          }
    128144          // => select using bandit policy
    129           int selectedAltIdx = SelectAction(random, n.children);
     145          int selectedAltIdx = SelectEpsGreedy(random, n.children);
    130146          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    131147
     
    152168
    153169    // eps-greedy
    154     private int SelectAction(Random random, TreeNode[] children) {
     170    private int SelectEpsGreedy(Random random, TreeNode[] children) {
    155171      if (random.NextDouble() < 0.1) {
    156172
     
    158174      } else {
    159175        var bestQ = double.NegativeInfinity;
    160         var bestChildIdx = -1;
     176        var bestChildIdx = new List<int>();
    161177        for (int i = 0; i < children.Length; i++) {
    162178          if (children[i].done) continue;
    163           if (children[i].tries == 0) return i;
    164           if (children[i].q > bestQ) {
    165             bestQ = children[i].q;
    166             bestChildIdx = i;
     179          // if (children[i].tries == 0) return i;
     180          var q = children[i].q;
     181          if (q > bestQ) {
     182            bestQ = q;
     183            bestChildIdx.Clear();
     184            bestChildIdx.Add(i);
     185          } else if (q == bestQ) {
     186            bestChildIdx.Add(i);
    167187          }
    168188        }
    169         Debug.Assert(bestChildIdx > -1);
    170         return bestChildIdx;
     189        Debug.Assert(bestChildIdx.Any());
     190        return bestChildIdx.SelectRandom(random);
    171191      }
    172192    }
    173193
    174194    private void DistributeReward(double reward) {
    175       const double alpha = 0.1;
    176       const double gamma = 1;
    177       // iterate in reverse order (bottom up)
    178195      updateChain.Reverse();
    179       var nextQ = 0.0;
    180       foreach (var e in updateChain) {
    181         var node = e;
    182         node.tries++;
     196      foreach (var node in updateChain) {
    183197        if (node.children != null && node.children.All(c => c.done)) {
    184198          node.done = true;
    185199        }
    186         // reward is recieved only for the last action
    187         if (e == updateChain.First()) {
    188           node.q = node.q + alpha * (reward + gamma * nextQ - node.q);
    189           nextQ = node.q;
    190         } else {
    191           node.q = node.q + alpha * (0 + gamma * nextQ - node.q);
    192           nextQ = node.q;
    193         }
    194       }
     200      }
     201      updateChain.Reverse();
     202
     203      //const double alpha = 0.1;
     204      const double gamma = 1;
     205      double alpha;
     206      foreach (var p in updateChain.Zip(updateChain.Skip(1), Tuple.Create)) {
     207        var parent = p.Item1;
     208        var child = p.Item2;
     209        parent.tries++;
     210        alpha = 1.0 / parent.tries;
     211        //alpha = 0.01;
     212        parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q);
     213      }
     214      // reward is recieved only for the last action
     215      var n = updateChain.Last();
     216      n.tries++;
     217      alpha = 1.0 / n.tries;
     218      //alpha = 0.1;
     219      n.q = n.q + alpha * reward;
    195220    }
    196221
Note: See TracChangeset for help on using the changeset viewer.