source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs @ 11745

Last change on this file since 11745 was 11745, checked in by gkronber, 7 years ago

#2283: worked on contextual MCTS

File size: 19.1 KB
Line 
1using System;
2using System.Linq;
3using System.Collections.Generic;
4using System.Globalization;
5using HeuristicLab.Algorithms.Bandits;
6using HeuristicLab.Algorithms.Bandits.BanditPolicies;
7using HeuristicLab.Algorithms.Bandits.Models;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
namespace HeuristicLab.Problems.GrammaticalOptimization.Test {
  /// <summary>
  /// Comparison runs of bandit policies on several synthetic bandit problems.
  /// The test methods contain no assertions; each one prints a regret table per policy
  /// to the console so that policies can be compared by inspecting the output.
  /// </summary>
  /// <remarks>
  /// Several experiment calls below are disabled. They reference identifiers that are not
  /// defined in this class (globalRand, seedForPolicy, TestPolicyNormal) and must be
  /// adapted before they can be re-enabled.
  /// </remarks>
  [TestClass]
  public class TestBanditPolicies {
    /// <summary>
    /// Runs the Gaussian-oriented policies (and the random baseline) on <c>GaussianBandit</c>.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForGaussianUnknownVarianceBandit() {
      // use culture-independent number formatting for the printed tables
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;

      // ThresholdAscent only works for rewards in [0..1] so far

      Console.WriteLine("Thompson (Gaussian est variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 1, 1)));
      Console.WriteLine("Thompson (Gaussian fixed variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 0.1)));
      Console.WriteLine("GaussianThompson (compat)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      Console.WriteLine("GaussianThompson"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy());
      Console.WriteLine("UCBNormal"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new UCBNormalPolicy());
      Console.WriteLine("Random"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new RandomPolicy());

    }


    /// <summary>
    /// Runs a range of policies on <c>BernoulliBandit</c>.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForBernoulliBandit() {
      // use culture-independent number formatting for the printed tables
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new BernoulliThompsonSamplingPolicy());
      Console.WriteLine("Generic Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new GenericThompsonSamplingPolicy(new BernoulliModel()));
      Console.WriteLine("Random");
      TestPolicyBernoulli(randSeed, nArms, new RandomPolicy());
      Console.WriteLine("UCB1");
      TestPolicyBernoulli(randSeed, nArms, new UCB1Policy());
      Console.WriteLine("UCB1Tuned");
      TestPolicyBernoulli(randSeed, nArms, new UCB1TunedPolicy());
      Console.WriteLine("UCB1Normal");
      TestPolicyBernoulli(randSeed, nArms, new UCBNormalPolicy());
      Console.WriteLine("Eps(0.01)");
      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.01));
      Console.WriteLine("Eps(0.05)");
      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(1));
      Console.WriteLine("UCT(2)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(2));
      Console.WriteLine("UCT(5)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.1));

      // not applicable to bernoulli rewards
      //Console.WriteLine("ThresholdAscent(10, 0.01)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      //Console.WriteLine("ThresholdAscent(10, 0.05)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      //Console.WriteLine("ThresholdAscent(10, 0.1)   "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      //Console.WriteLine("ThresholdAscent(100, 0.01) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      //Console.WriteLine("ThresholdAscent(100, 0.05) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      //Console.WriteLine("ThresholdAscent(100, 0.1)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      //Console.WriteLine("ThresholdAscent(1000, 0.01)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      //Console.WriteLine("ThresholdAscent(1000, 0.05)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      //Console.WriteLine("ThresholdAscent(1000, 0.1) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
    }

    /// <summary>
    /// Runs the Threshold Ascent and Thompson sampling policies on <c>TruncatedNormalBandit</c>.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForGaussianBandit() {
      // use culture-independent number formatting for the printed tables
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;

      var randSeed = 31415;
      var nArms = 20;
      Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
      Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
      Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
      Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
      Console.WriteLine("Generic Thompson (Gaussian fixed var)"); TestPolicyGaussian(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1)));
      Console.WriteLine("Generic Thompson (Gaussian unknown var)"); TestPolicyGaussian(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));
      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy());

      /*
      Console.WriteLine("Random"); TestPolicyNormal(randSeed, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("UCB1"); TestPolicyNormal(randSeed, nArms, new UCB1Policy(nArms));
      Console.WriteLine("UCB1Tuned"); TestPolicyNormal(randSeed, nArms, new UCB1TunedPolicy(nArms));
      Console.WriteLine("UCB1Normal"); TestPolicyNormal(randSeed, nArms, new UCBNormalPolicy(nArms));
      //Console.WriteLine("Exp3 (gamma=0.01)");
      //TestPolicyNormal(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
      //Console.WriteLine("Exp3 (gamma=0.05)");
      //TestPolicyNormal(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("Eps(0.01)"); TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
      Console.WriteLine("Eps(0.05)"); TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
      //Console.WriteLine("Eps(0.1)");
      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.1));
      //Console.WriteLine("Eps(0.2)");
      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.2));
      //Console.WriteLine("Eps(0.5)");
      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("UCT(0.1)"); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 0.1));
      Console.WriteLine("UCT(0.5)"); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 0.5));
      Console.WriteLine("UCT(1)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 1));
      Console.WriteLine("UCT(2)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 2));
      Console.WriteLine("UCT(5)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 5));
      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));
      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(10,0.05)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
      Console.WriteLine("ThresholdAscent(10,0.1)   "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.05) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
      Console.WriteLine("ThresholdAscent(100,0.1)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.05)"); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
      Console.WriteLine("ThresholdAscent(1000,0.1) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
       */
    }

    /// <summary>
    /// Runs the generic Thompson sampling policy with a Gaussian mixture model on
    /// <c>GaussianMixtureBandit</c>. All other policies are currently disabled.
    /// </summary>
    [TestMethod]
    public void ComparePoliciesForGaussianMixtureBandit() {
      // use culture-independent number formatting for the printed tables
      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
      var randSeed = 31415;
      var nArms = 20;

      Console.WriteLine("Generic Thompson (Gaussian Mixture)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianMixtureModel()));
      // Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
      // Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
      // Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
      // Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
      // Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
      // Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy());
      // Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1)));
      // Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));

      /*
      Console.WriteLine("Random"); TestPolicyGaussianMixture(randSeed, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
      Console.WriteLine("UCB1"); TestPolicyGaussianMixture(randSeed, nArms, new UCB1Policy(nArms));
      Console.WriteLine("UCB1Tuned "); TestPolicyGaussianMixture(randSeed, nArms, new UCB1TunedPolicy(nArms));
      Console.WriteLine("UCB1Normal"); TestPolicyGaussianMixture(randSeed, nArms, new UCBNormalPolicy(nArms));
      Console.WriteLine("Eps(0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
      Console.WriteLine("Eps(0.05) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
      Console.WriteLine("UCT(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 1));
      Console.WriteLine("UCT(2)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 2));
      Console.WriteLine("UCT(5)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 5));
      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));

      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
      Console.WriteLine("ThresholdAscent(10000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10000, 0.01));
       */
    }


    /// <summary>Evaluates <paramref name="policy"/> on freshly created <c>BernoulliBandit</c> instances.</summary>
    private void TestPolicyBernoulli(int randSeed, int nArms, IBanditPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new BernoulliBandit(banditRandom, nActions));
    }
    /// <summary>Evaluates <paramref name="policy"/> on freshly created <c>TruncatedNormalBandit</c> instances.</summary>
    private void TestPolicyGaussian(int randSeed, int nArms, IBanditPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new TruncatedNormalBandit(banditRandom, nActions));
    }
    /// <summary>Evaluates <paramref name="policy"/> on freshly created <c>GaussianMixtureBandit</c> instances.</summary>
    private void TestPolicyGaussianMixture(int randSeed, int nArms, IBanditPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianMixtureBandit(banditRandom, nActions));
    }
    /// <summary>Evaluates <paramref name="policy"/> on freshly created <c>GaussianBandit</c> instances.</summary>
    private void TestPolicyGaussianUnknownVariance(int randSeed, int nArms, IBanditPolicy policy) {
      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianBandit(banditRandom, nActions));
    }


    /// <summary>
    /// Runs <paramref name="policy"/> for several independent repetitions against bandits produced by
    /// <paramref name="banditFactory"/> and prints, for every logged pull count (1, 2, 4, ...),
    /// the average/min/max regret per pull and the mean number of suboptimal pulls.
    /// </summary>
    /// <param name="randSeed">Seed from which the bandit and policy random streams are derived.</param>
    /// <param name="nArms">Number of arms; also the number of action infos created per repetition.</param>
    /// <param name="policy">Policy under test; supplies the action infos and selects the arm to pull.</param>
    /// <param name="banditFactory">Creates the bandit for one repetition from the shared bandit random stream and the arm count.</param>
    private void TestPolicy(int randSeed, int nArms, IBanditPolicy policy, Func<Random, int, IBandit> banditFactory) {
      const int maxIt = 100000; // pulls per run
      const int reps = 10; // independent runs
      var regretForIteration = new Dictionary<int, List<double>>();
      var numberOfPullsOfSuboptimalArmsForExp = new Dictionary<int, double>();
      var numberOfPullsOfSuboptimalArmsForMax = new Dictionary<int, double>();
      var globalRandom = new Random(randSeed);
      var banditRandom = new Random(globalRandom.Next()); // bandits must produce the same rewards for each test
      var policyRandom = new Random(globalRandom.Next());

      // calculate statistics
      for (int r = 0; r < reps; r++) {
        var nextLogStep = 1;
        var b = banditFactory(banditRandom, nArms);
        var totalRegret = 0.0;
        var totalPullsOfSuboptimalArmsExp = 0.0;
        var totalPullsOfSuboptimalArmsMax = 0.0;
        var actionInfos = Enumerable.Range(0, nArms).Select(_ => policy.CreateActionInfo()).ToArray();
        // i is the number of pulls performed once the loop body has run (1-based),
        // so that totalRegret / i is the mean regret per pull at every log step
        for (int i = 1; i <= maxIt; i++) {
          var selectedAction = policy.SelectAction(policyRandom, actionInfos);
          var reward = b.Pull(selectedAction);
          actionInfos[selectedAction].UpdateReward(reward);

          // collect stats
          if (selectedAction != b.OptimalExpectedRewardArm) totalPullsOfSuboptimalArmsExp++;
          if (selectedAction != b.OptimalMaximalRewardArm) totalPullsOfSuboptimalArmsMax++;
          totalRegret += b.OptimalExpectedReward - reward;

          if (i == nextLogStep) {
            nextLogStep *= 2; // log at exponentially growing pull counts

            List<double> regrets;
            if (!regretForIteration.TryGetValue(i, out regrets)) {
              regrets = new List<double>();
              regretForIteration.Add(i, regrets);
            }
            regrets.Add(totalRegret / i);

            Accumulate(numberOfPullsOfSuboptimalArmsForExp, i, totalPullsOfSuboptimalArmsExp);
            Accumulate(numberOfPullsOfSuboptimalArmsForMax, i, totalPullsOfSuboptimalArmsMax);
          }
        }
      }
      // print
      foreach (var p in regretForIteration.Keys.OrderBy(k => k)) {
        var regrets = regretForIteration[p];
        Console.WriteLine("iter {0,8} regret avg {1,7:F5} min {2,7:F5} max {3,7:F5} suboptimal pulls (exp) {4,7:F2} suboptimal pulls (max) {5,7:F2}",
          p,
          regrets.Average(),
          regrets.Min(),
          regrets.Max(),
          numberOfPullsOfSuboptimalArmsForExp[p] / reps,
          numberOfPullsOfSuboptimalArmsForMax[p] / reps
          );
      }
    }

    /// <summary>
    /// Adds <paramref name="value"/> to the running total stored under <paramref name="key"/>;
    /// a missing key is treated as a total of zero.
    /// </summary>
    private static void Accumulate(Dictionary<int, double> totals, int key, double value) {
      double current;
      totals.TryGetValue(key, out current); // leaves current at 0.0 when the key is absent
      totals[key] = current + value;
    }

  }
}
Note: See TracBrowser for help on using the repository browser.