source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 16.8 KB
Line 
1using System;
2using System.Linq;
3using System.Collections.Generic;
4using System.Globalization;
5using HeuristicLab.Algorithms.Bandits;
6using HeuristicLab.Algorithms.Bandits.Models;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Problems.GrammaticalOptimization.Test {
  /// <summary>
  /// Benchmark-style tests that run several bandit policies against different bandit
  /// implementations and print regret statistics to the console.
  /// The tests contain no assertions; they only fail if a policy or bandit throws.
  /// </summary>
  [TestClass]
  public class TestBanditPolicies {

    /// <summary>
    /// Runs the enabled policies against <see cref="BernoulliBandit"/> instances with 20 arms
    /// and prints the statistics produced by <see cref="TestPolicy"/> for each of them.
    /// Every policy that takes a random number generator receives a fresh <see cref="Random"/>
    /// created from the same seed; the bandits are seeded from the shared <c>globalRand</c>.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForBernoulliBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;

      var globalRand = new Random(31415);
      var seedForPolicy = globalRand.Next();
      var nArms = 20;
      // NOTE(review): the Exp3 labels below say gamma=0.01 / 0.05 but the last argument is 1 in both calls - confirm before re-enabling
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("Thompson (Bernoulli)"); TestPolicyBernoulli(globalRand, nArms, new BernoulliThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("Generic Thompson (Bernoulli)"); TestPolicyBernoulli(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new BernoulliModel(nArms)));
      Console.WriteLine("Random"); TestPolicyBernoulli(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("UCB1"); TestPolicyBernoulli(globalRand, nArms, new UCB1Policy(nArms));
      Console.WriteLine("UCB1Tuned"); TestPolicyBernoulli(globalRand, nArms, new UCB1TunedPolicy(nArms));
      Console.WriteLine("UCB1Normal"); TestPolicyBernoulli(globalRand, nArms, new UCBNormalPolicy(nArms));
      Console.WriteLine("Eps(0.01)"); TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
      Console.WriteLine("Eps(0.05)"); TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 1));
      Console.WriteLine("UCT(2)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 2));
      Console.WriteLine("UCT(5)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));

      // not applicable to bernoulli rewards
      //Console.WriteLine("ThresholdAscent(10, 0.01)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      //Console.WriteLine("ThresholdAscent(10, 0.05)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      //Console.WriteLine("ThresholdAscent(10, 0.1)   "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      //Console.WriteLine("ThresholdAscent(100, 0.01) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      //Console.WriteLine("ThresholdAscent(100, 0.05) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      //Console.WriteLine("ThresholdAscent(100, 0.1)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      //Console.WriteLine("ThresholdAscent(1000, 0.01)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      //Console.WriteLine("ThresholdAscent(1000, 0.05)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      //Console.WriteLine("ThresholdAscent(1000, 0.1) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
    }

    /// <summary>
    /// Runs the enabled policies against <see cref="TruncatedNormalBandit"/> instances with 20 arms
    /// and prints the statistics produced by <see cref="TestPolicy"/> for each of them.
    /// Only the Thompson sampling variants are currently enabled; the remaining
    /// configurations are kept in a block comment.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForNormalBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;

      var globalRand = new Random(31415);
      var seedForPolicy = globalRand.Next();
      var nArms = 20;
      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyNormal(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms, true));
      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyNormal(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyNormal(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new GaussianModel(nArms, 0.5, 1)));
      /*
      Console.WriteLine("Random"); TestPolicyNormal(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("UCB1"); TestPolicyNormal(globalRand, nArms, new UCB1Policy(nArms));
      Console.WriteLine("UCB1Tuned"); TestPolicyNormal(globalRand, nArms, new UCB1TunedPolicy(nArms));
      Console.WriteLine("UCB1Normal"); TestPolicyNormal(globalRand, nArms, new UCBNormalPolicy(nArms));
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyNormal(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyNormal(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("Eps(0.01)"); TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
      Console.WriteLine("Eps(0.05)"); TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 1));
      Console.WriteLine("UCT(2)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 2));
      Console.WriteLine("UCT(5)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));
      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(10,0.05)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      Console.WriteLine("ThresholdAscent(10,0.1)   "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.05) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      Console.WriteLine("ThresholdAscent(100,0.1)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.05)"); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      Console.WriteLine("ThresholdAscent(1000,0.1) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
       */
    }

    /// <summary>
    /// Runs the enabled policies against <see cref="GaussianMixtureBandit"/> instances with 20 arms
    /// and prints the statistics produced by <see cref="TestPolicy"/> for each of them.
    /// Only the Thompson sampling variants are currently enabled; the remaining
    /// configurations are kept in a block comment.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForGaussianMixtureBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;

      var globalRand = new Random(31415);
      var seedForPolicy = globalRand.Next();
      var nArms = 20;
      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms, true));
      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyGaussianMixture(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new GaussianModel(nArms, 0.5, 1)));

      /*
      Console.WriteLine("Random"); TestPolicyGaussianMixture(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("UCB1"); TestPolicyGaussianMixture(globalRand, nArms, new UCB1Policy(nArms));
      Console.WriteLine("UCB1Tuned "); TestPolicyGaussianMixture(globalRand, nArms, new UCB1TunedPolicy(nArms));
      Console.WriteLine("UCB1Normal"); TestPolicyGaussianMixture(globalRand, nArms, new UCBNormalPolicy(nArms));
      Console.WriteLine("Eps(0.01) "); TestPolicyGaussianMixture(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
      Console.WriteLine("Eps(0.05) "); TestPolicyGaussianMixture(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("UCT(1)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 1));
      Console.WriteLine("UCT(2)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 2));
      Console.WriteLine("UCT(5)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));

      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(10000,0.01)"); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10000, 0.01));
       */
    }


    /// <summary>Evaluates <paramref name="policy"/> on freshly created <see cref="BernoulliBandit"/> instances.</summary>
    private void TestPolicyBernoulli(Random globalRand, int nArms, IPolicy policy) {
      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new BernoulliBandit(banditRandom, nActions));
    }

    /// <summary>Evaluates <paramref name="policy"/> on freshly created <see cref="TruncatedNormalBandit"/> instances.</summary>
    private void TestPolicyNormal(Random globalRand, int nArms, IPolicy policy) {
      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new TruncatedNormalBandit(banditRandom, nActions));
    }

    /// <summary>Evaluates <paramref name="policy"/> on freshly created <see cref="GaussianMixtureBandit"/> instances.</summary>
    private void TestPolicyGaussianMixture(Random globalRand, int nArms, IPolicy policy) {
      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new GaussianMixtureBandit(banditRandom, nActions));
    }


    /// <summary>
    /// Runs <paramref name="policy"/> for 30 independent repetitions of 100000 pulls each and
    /// prints, for every power-of-two number of pulls, the average/min/max per-pull regret over
    /// the repetitions and the average number of pulls of suboptimal arms.
    /// </summary>
    /// <param name="globalRand">Source of the seeds for the bandit created in each repetition.</param>
    /// <param name="nArms">Number of arms passed to <paramref name="banditFactory"/>.</param>
    /// <param name="policy">Policy under test; it is reset at the start of every repetition.</param>
    /// <param name="banditFactory">Creates the bandit for a repetition from a random number generator and the number of arms.</param>
    private void TestPolicy(Random globalRand, int nArms, IPolicy policy, Func<Random, int, IBandit> banditFactory) {
      const int maxIt = 100000; // pulls per run
      const int reps = 30; // independent runs
      var regretForIteration = new Dictionary<int, List<double>>();
      var numberOfPullsOfSuboptimalArmsForExp = new Dictionary<int, double>();
      var numberOfPullsOfSuboptimalArmsForMax = new Dictionary<int, double>();
      // calculate statistics
      for (int r = 0; r < reps; r++) {
        var nextLogStep = 1;
        var b = banditFactory(new Random(globalRand.Next()), nArms);
        policy.Reset();
        var totalRegret = 0.0;
        var totalPullsOfSuboptimalArmsExp = 0.0;
        var totalPullsOfSuboptimalArmsMax = 0.0;
        // i is the number of pulls completed at the end of the loop body (1-based),
        // so that the totals logged under key i cover exactly i pulls
        for (int i = 1; i <= maxIt; i++) {
          var selectedAction = policy.SelectAction();
          var reward = b.Pull(selectedAction);
          policy.UpdateReward(selectedAction, reward);

          // collect stats
          if (selectedAction != b.OptimalExpectedRewardArm) totalPullsOfSuboptimalArmsExp++;
          if (selectedAction != b.OptimalMaximalRewardArm) totalPullsOfSuboptimalArmsMax++;
          totalRegret += b.OptimalExpectedReward - reward;

          // log at 1, 2, 4, 8, ... pulls
          if (i == nextLogStep) {
            nextLogStep *= 2;

            List<double> regrets;
            if (!regretForIteration.TryGetValue(i, out regrets)) {
              regrets = new List<double>();
              regretForIteration.Add(i, regrets);
            }
            regrets.Add(totalRegret / i); // average regret per pull

            AddToSum(numberOfPullsOfSuboptimalArmsForExp, i, totalPullsOfSuboptimalArmsExp);
            AddToSum(numberOfPullsOfSuboptimalArmsForMax, i, totalPullsOfSuboptimalArmsMax);
          }
        }
      }
      // print
      foreach (var p in regretForIteration.Keys.OrderBy(k => k)) {
        var regrets = regretForIteration[p];
        Console.WriteLine("iter {0,8} regret avg {1,7:F5} min {2,7:F5} max {3,7:F5} suboptimal pulls (exp) {4,7:F2} suboptimal pulls (max) {5,7:F2}",
          p,
          regrets.Average(),
          regrets.Min(),
          regrets.Max(),
          numberOfPullsOfSuboptimalArmsForExp[p] / reps,
          numberOfPullsOfSuboptimalArmsForMax[p] / reps
          );
      }
    }

    /// <summary>Adds <paramref name="value"/> to the sum stored under <paramref name="key"/>, starting from zero for a new key.</summary>
    private static void AddToSum(Dictionary<int, double> sums, int key, double value) {
      double current;
      sums.TryGetValue(key, out current); // current is 0.0 when the key is not present yet
      sums[key] = current + value;
    }

  }
}
Note: See TracBrowser for help on using the repository browser.