Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
12/21/14 09:19:54 (9 years ago)
Author:
gkronber
Message:

#2283: more bandit policies and tests

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs

    r11708 r11710  
    10 10  public class TestBanditPolicies {
    11 11    [TestMethod]
    12     public void ComparePolicies() {
     12    public void ComparePoliciesForBernoulliBandit() {
    13 13      System.Threading.Thread.CurrentThread.CurrentCulture = CultureInfo.InvariantCulture;
    14 14      var globalRand = new Random(31415);
     …
    16 16      var nArms = 10;
    17 17      Console.WriteLine("Random");
    18       TestPolicy(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), 10));
     18      TestPolicyBernoulli(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), 10));
     19      Console.WriteLine("UCB1");
     20      TestPolicyBernoulli(globalRand, nArms, new UCB1Policy(10));
     21      Console.WriteLine("UCB1Tuned");
     22      TestPolicyBernoulli(globalRand, nArms, new UCB1TunedPolicy(10));
     23      Console.WriteLine("UCB1Normal");
     24      TestPolicyBernoulli(globalRand, nArms, new UCBNormalPolicy(10));
    19 25      Console.WriteLine("Eps(0.01)");
    20       TestPolicy(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.01));
     26      TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.01));
    21 27      Console.WriteLine("Eps(0.05)");
    22       TestPolicy(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.05));
     28      TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.05));
    23 29      Console.WriteLine("Eps(0.1)");
    24       TestPolicy(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.1));
     30      TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.1));
    25 31      Console.WriteLine("Eps(0.2)");
    26       TestPolicy(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.2));
     32      TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.2));
    27 33      Console.WriteLine("Eps(0.5)");
    28       TestPolicy(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.5));
     34      TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.5));
     35    }
     36    [TestMethod]
     37    public void ComparePoliciesForNormalBandit() {
     38      System.Threading.Thread.CurrentThread.CurrentCulture = CultureInfo.InvariantCulture;
     39      var globalRand = new Random(31415);
     40      var seedForPolicy = globalRand.Next();
     41      var nArms = 10;
     42      Console.WriteLine("Random");
     43      TestPolicyNormal(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), 10));
     44      Console.WriteLine("UCB1");
     45      TestPolicyNormal(globalRand, nArms, new UCB1Policy(10));
     46      Console.WriteLine("UCB1Tuned");
     47      TestPolicyNormal(globalRand, nArms, new UCB1TunedPolicy(10));
     48      Console.WriteLine("UCB1Normal");
     49      TestPolicyNormal(globalRand, nArms, new UCBNormalPolicy(10));
     50      Console.WriteLine("Eps(0.01)");
     51      TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.01));
     52      Console.WriteLine("Eps(0.05)");
     53      TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.05));
     54      Console.WriteLine("Eps(0.1)");
     55      TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.1));
     56      Console.WriteLine("Eps(0.2)");
     57      TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.2));
     58      Console.WriteLine("Eps(0.5)");
     59      TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), 10, 0.5));
    29 60    }
    30 61
    31     private void TestPolicy(Random globalRand, int nArms, IPolicy policy) {
     62    private void TestPolicyBernoulli(Random globalRand, int nArms, IPolicy policy) {
    32 63      var maxIt = 1E6;
    33 64      var reps = 10; // 10 independent runs
     …
    36 67      for (int r = 0; r < reps; r++) {
    37 68        var nextLogStep = 1;
    38         var b = new Bandit(new Random(globalRand.Next()), 10);
     69        var b = new BernoulliBandit(new Random(globalRand.Next()), 10);
    39 70        policy.Reset();
    40 71        var totalRegret = 0.0;
     …
    59 90      }
    60 91    }
     92    private void TestPolicyNormal(Random globalRand, int nArms, IPolicy policy) {
     93      var maxIt = 1E6;
     94      var reps = 10; // 10 independent runs
     95      var avgRegretForIteration = new Dictionary<int, double>();
     96      // calculate statistics
     97      for (int r = 0; r < reps; r++) {
     98        var nextLogStep = 1;
     99        var b = new TruncatedNormalBandit(new Random(globalRand.Next()), 10);
     100        policy.Reset();
     101        var totalRegret = 0.0;
     102
     103        for (int i = 0; i <= maxIt; i++) {
     104          var selectedAction = policy.SelectAction();
     105          var reward = b.Pull(selectedAction);
     106          totalRegret += b.OptimalExpectedReward - reward;
     107          policy.UpdateReward(selectedAction, reward);
     108          if (i == nextLogStep) {
     109            nextLogStep *= 10;
     110            if (!avgRegretForIteration.ContainsKey(i)) {
     111              avgRegretForIteration.Add(i, 0.0);
     112            }
     113            avgRegretForIteration[i] += totalRegret / i;
     114          }
     115        }
     116      }
     117      // print
     118      foreach (var p in avgRegretForIteration.Keys.OrderBy(k => k)) {
     119        Console.WriteLine("{0} {1}", p, avgRegretForIteration[p] / reps); // print avg. of avg. regret
     120      }
     121    }
     122
    61 123  }
    62 124}
Note: See TracChangeset for help on using the changeset viewer.