source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs @ 11732

Last change on this file since 11732 was 11732, checked in by gkronber, 5 years ago

#2283: refactoring and bug fixes

File size: 17.9 KB
Line 
1using System;
2using System.Linq;
3using System.Collections.Generic;
4using System.Globalization;
5using HeuristicLab.Algorithms.Bandits;
6using HeuristicLab.Algorithms.Bandits.Models;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8
namespace HeuristicLab.Problems.GrammaticalOptimization.Test {
  /// <summary>
  /// Exploratory comparisons of bandit policies on several synthetic bandit problems.
  /// The tests contain no assertions: each one runs a set of policies and prints regret and
  /// suboptimal-pull statistics to the console so that the policies can be compared by
  /// inspecting the test output.
  /// </summary>
  [TestClass]
  public class TestBanditPolicies {
    [TestMethod]
    public void ComparePoliciesForGaussianUnknownVarianceBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;

      // Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new ThresholdAscent(20, 0.01));
      // Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new ThresholdAscent(100, 0.01));
      // Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new ThresholdAscent(500, 0.01));
      // Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new ThresholdAscent(1000, 0.01));
      Console.WriteLine("Thompson (Gaussian fixed variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 1)));
      Console.WriteLine("Thompson (Gaussian est variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 1, 0.1)));
      Console.WriteLine("GaussianThompson (compat)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      Console.WriteLine("GaussianThompson"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy());
      Console.WriteLine("UCBNormal"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new UCBNormalPolicy());
      Console.WriteLine("Random"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new RandomPolicy());
    }


    [TestMethod]
    public void ComparePoliciesForBernoulliBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;
      // NOTE: the Exp3Policy calls below still pass (Random, nArms, gamma) to the constructor, unlike the
      // policies enabled in this file; check the current Exp3Policy constructor before enabling them.
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyBernoulli(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyBernoulli(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new BernoulliThompsonSamplingPolicy());
      Console.WriteLine("Generic Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new GenericThompsonSamplingPolicy(new BernoulliModel()));
      Console.WriteLine("Random");
      TestPolicyBernoulli(randSeed, nArms, new RandomPolicy());
      Console.WriteLine("UCB1");
      TestPolicyBernoulli(randSeed, nArms, new UCB1Policy());
      Console.WriteLine("UCB1Tuned");
      TestPolicyBernoulli(randSeed, nArms, new UCB1TunedPolicy());
      Console.WriteLine("UCB1Normal");
      TestPolicyBernoulli(randSeed, nArms, new UCBNormalPolicy());
      Console.WriteLine("Eps(0.01)");
      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.01));
      Console.WriteLine("Eps(0.05)");
      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(1));
      Console.WriteLine("UCT(2)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(2));
      Console.WriteLine("UCT(5)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.1));

      // not applicable to bernoulli rewards
      // NOTE: these ThresholdAscentPolicy calls still pass nArms to the constructor, unlike the policies
      // enabled in this file; check the current constructor signature before enabling them.
      //Console.WriteLine("ThresholdAscent(10, 0.01)  "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      //Console.WriteLine("ThresholdAscent(10, 0.05)  "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      //Console.WriteLine("ThresholdAscent(10, 0.1)   "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      //Console.WriteLine("ThresholdAscent(100, 0.01) "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      //Console.WriteLine("ThresholdAscent(100, 0.05) "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      //Console.WriteLine("ThresholdAscent(100, 0.1)  "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      //Console.WriteLine("ThresholdAscent(1000, 0.01)"); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      //Console.WriteLine("ThresholdAscent(1000, 0.05)"); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      //Console.WriteLine("ThresholdAscent(1000, 0.1) "); TestPolicyBernoulli(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
    }

    [TestMethod]
    public void ComparePoliciesForGaussianBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;

      var randSeed = 31415;
      var nArms = 20;
      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy());
      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyGaussian(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1)));
      /*
      Console.WriteLine("Random"); TestPolicyGaussian(randSeed, nArms, new RandomPolicy());
      Console.WriteLine("UCB1"); TestPolicyGaussian(randSeed, nArms, new UCB1Policy());
      Console.WriteLine("UCB1Tuned"); TestPolicyGaussian(randSeed, nArms, new UCB1TunedPolicy());
      Console.WriteLine("UCB1Normal"); TestPolicyGaussian(randSeed, nArms, new UCBNormalPolicy());
      // NOTE: the Exp3Policy calls still pass (Random, nArms, gamma); check the current constructor before enabling.
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyGaussian(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyGaussian(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("Eps(0.01)"); TestPolicyGaussian(randSeed, nArms, new EpsGreedyPolicy(0.01));
      Console.WriteLine("Eps(0.05)"); TestPolicyGaussian(randSeed, nArms, new EpsGreedyPolicy(0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyGaussian(randSeed, nArms, new EpsGreedyPolicy(0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyGaussian(randSeed, nArms, new EpsGreedyPolicy(0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyGaussian(randSeed, nArms, new EpsGreedyPolicy(0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyGaussian(randSeed, nArms, new UCTPolicy(0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyGaussian(randSeed, nArms, new UCTPolicy(0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyGaussian(randSeed, nArms, new UCTPolicy(1));
      Console.WriteLine("UCT(2)  "); TestPolicyGaussian(randSeed, nArms, new UCTPolicy(2));
      Console.WriteLine("UCT(5)  "); TestPolicyGaussian(randSeed, nArms, new UCTPolicy(5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyGaussian(randSeed, nArms, new BoltzmannExplorationPolicy(0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyGaussian(randSeed, nArms, new BoltzmannExplorationPolicy(0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussian(randSeed, nArms, new BoltzmannExplorationPolicy(1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussian(randSeed, nArms, new BoltzmannExplorationPolicy(10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussian(randSeed, nArms, new BoltzmannExplorationPolicy(100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyGaussian(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyGaussian(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyGaussian(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.1));
      // NOTE: the ThresholdAscentPolicy calls still pass nArms; check the current constructor before enabling.
      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(10,0.05)  "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      Console.WriteLine("ThresholdAscent(10,0.1)   "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.05) "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      Console.WriteLine("ThresholdAscent(100,0.1)  "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.05)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      Console.WriteLine("ThresholdAscent(1000,0.1) "); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
       */
    }

    [TestMethod]
    public void ComparePoliciesForGaussianMixtureBandit() {
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;
      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy());
      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1)));

      /*
      Console.WriteLine("Random"); TestPolicyGaussianMixture(randSeed, nArms, new RandomPolicy());
      Console.WriteLine("UCB1"); TestPolicyGaussianMixture(randSeed, nArms, new UCB1Policy());
      Console.WriteLine("UCB1Tuned "); TestPolicyGaussianMixture(randSeed, nArms, new UCB1TunedPolicy());
      Console.WriteLine("UCB1Normal"); TestPolicyGaussianMixture(randSeed, nArms, new UCBNormalPolicy());
      Console.WriteLine("Eps(0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(0.01));
      Console.WriteLine("Eps(0.05) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(0.05));
      Console.WriteLine("UCT(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(1));
      Console.WriteLine("UCT(2)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(2));
      Console.WriteLine("UCT(5)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(100));

      // NOTE: the ThresholdAscentPolicy calls still pass nArms; check the current constructor before enabling.
      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(10000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10000, 0.01));
       */
    }


    // Runs the policy on BernoulliBandit instances.
    private void TestPolicyBernoulli(int randSeed, int nArms, IPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new BernoulliBandit(banditRandom, nActions));
    }
    // Runs the policy on TruncatedNormalBandit instances.
    private void TestPolicyGaussian(int randSeed, int nArms, IPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new TruncatedNormalBandit(banditRandom, nActions));
    }
    // Runs the policy on GaussianMixtureBandit instances.
    private void TestPolicyGaussianMixture(int randSeed, int nArms, IPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianMixtureBandit(banditRandom, nActions));
    }
    // Runs the policy on GaussianBandit instances.
    private void TestPolicyGaussianUnknownVariance(int randSeed, int nArms, IPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianBandit(banditRandom, nActions));
    }


    /// <summary>
    /// Runs <paramref name="policy"/> for a fixed number of pulls on several independently created
    /// bandits and prints, for every power-of-two number of pulls, the regret per pull (average, min
    /// and max over the runs) and the average number of pulls of arms other than the optimal arm
    /// (with respect to expected reward and to maximal reward).
    /// </summary>
    /// <param name="randSeed">Seed from which the bandit and policy random number generators are derived.</param>
    /// <param name="nArms">Number of arms of each bandit.</param>
    /// <param name="policy">The policy under test; fresh action infos are created for every run.</param>
    /// <param name="banditFactory">Creates a bandit from a random number generator and the number of arms.</param>
    private void TestPolicy(int randSeed, int nArms, IPolicy policy, Func<Random, int, IBandit> banditFactory) {
      const int maxIt = 100000; // pulls per run
      const int reps = 10; // independent runs
      var regretForIteration = new Dictionary<int, List<double>>();
      var numberOfPullsOfSuboptimalArmsForExp = new Dictionary<int, double>();
      var numberOfPullsOfSuboptimalArmsForMax = new Dictionary<int, double>();
      var globalRandom = new Random(randSeed);
      var banditRandom = new Random(globalRandom.Next()); // bandits must produce the same rewards for each test
      var policyRandom = new Random(globalRandom.Next());

      // calculate statistics
      for (int r = 0; r < reps; r++) {
        var nextLogStep = 1;
        var b = banditFactory(banditRandom, nArms);
        var totalRegret = 0.0;
        var totalPullsOfSuboptimalArmsExp = 0.0;
        var totalPullsOfSuboptimalArmsMax = 0.0;
        var actionInfos = Enumerable.Range(0, nArms).Select(_ => policy.CreateActionInfo()).ToArray();
        // 'pulls' is the number of pulls made so far including the current one, so the totals logged
        // at pulls == nextLogStep cover exactly 'pulls' pulls (counting from 0 would include one extra pull).
        for (int pulls = 1; pulls <= maxIt; pulls++) {
          var selectedAction = policy.SelectAction(policyRandom, actionInfos);
          var reward = b.Pull(selectedAction);
          actionInfos[selectedAction].UpdateReward(reward);

          // collect stats
          if (selectedAction != b.OptimalExpectedRewardArm) totalPullsOfSuboptimalArmsExp++;
          if (selectedAction != b.OptimalMaximalRewardArm) totalPullsOfSuboptimalArmsMax++;
          totalRegret += b.OptimalExpectedReward - reward;

          if (pulls == nextLogStep) {
            nextLogStep *= 2;
            List<double> regrets;
            if (!regretForIteration.TryGetValue(pulls, out regrets)) {
              regrets = new List<double>();
              regretForIteration.Add(pulls, regrets);
            }
            regrets.Add(totalRegret / pulls);

            Accumulate(numberOfPullsOfSuboptimalArmsForExp, pulls, totalPullsOfSuboptimalArmsExp);
            Accumulate(numberOfPullsOfSuboptimalArmsForMax, pulls, totalPullsOfSuboptimalArmsMax);
          }
        }
      }
      // print; the invariant culture is passed explicitly because setting
      // CultureInfo.DefaultThreadCurrentCulture does not reliably change the culture of the
      // thread that is already running the test.
      foreach (var p in regretForIteration.Keys.OrderBy(k => k)) {
        Console.WriteLine(string.Format(CultureInfo.InvariantCulture,
          "iter {0,8} regret avg {1,7:F5} min {2,7:F5} max {3,7:F5} suboptimal pulls (exp) {4,7:F2} suboptimal pulls (max) {5,7:F2}",
          p,
          regretForIteration[p].Average(),
          regretForIteration[p].Min(),
          regretForIteration[p].Max(),
          numberOfPullsOfSuboptimalArmsForExp[p] / (double)reps,
          numberOfPullsOfSuboptimalArmsForMax[p] / (double)reps
          ));
      }
    }

    // Adds value to the entry stored under key, treating a missing entry as 0.
    private static void Accumulate(IDictionary<int, double> sums, int key, double value) {
      double sum;
      sums.TryGetValue(key, out sum); // sum is 0.0 when the key is missing
      sums[key] = sum + value;
    }

  }
}
Note: See TracBrowser for help on using the repository browser.