source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs @ 11742

Last change on this file since 11742 was 11742, checked in by gkronber, 7 years ago

#2283 refactoring

File size: 18.9 KB
Line 
1using System;
2using System.Linq;
3using System.Collections.Generic;
4using System.Globalization;
5using HeuristicLab.Algorithms.Bandits;
6using HeuristicLab.Algorithms.Bandits.BanditPolicies;
7using HeuristicLab.Algorithms.Bandits.Models;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
10namespace HeuristicLab.Problems.GrammaticalOptimization.Test {
11  [TestClass]
12  public class TestBanditPolicies {
13    [TestMethod]
14    public void ComparePoliciesForGaussianUnknownVarianceBandit() {
15      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
16      var randSeed = 31415;
17      var nArms = 20;
18
19      // ThresholdAscent only works for rewards in [0..1] so far
20
21      Console.WriteLine("Thompson (Gaussian est variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 1, 1)));
22      Console.WriteLine("Thompson (Gaussian fixed variance)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0, 1, 0.1)));
23      Console.WriteLine("GaussianThompson (compat)"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
24      Console.WriteLine("GaussianThompson"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new GaussianThompsonSamplingPolicy());
25      Console.WriteLine("UCBNormal"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new UCBNormalPolicy());
26      Console.WriteLine("Random"); TestPolicyGaussianUnknownVariance(randSeed, nArms, new RandomPolicy());
27
28    }
29
30
31    [TestMethod]
32    public void ComparePoliciesForBernoulliBandit() {
33      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
34      var randSeed = 31415;
35      var nArms = 20;
36      //Console.WriteLine("Exp3 (gamma=0.01)");
37      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
38      //Console.WriteLine("Exp3 (gamma=0.05)");
39      //estPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
40      Console.WriteLine("Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new BernoulliThompsonSamplingPolicy());
41      Console.WriteLine("Generic Thompson (Bernoulli)"); TestPolicyBernoulli(randSeed, nArms, new GenericThompsonSamplingPolicy(new BernoulliModel()));
42      Console.WriteLine("Random");
43      TestPolicyBernoulli(randSeed, nArms, new RandomPolicy());
44      Console.WriteLine("UCB1");
45      TestPolicyBernoulli(randSeed, nArms, new UCB1Policy());
46      Console.WriteLine("UCB1Tuned");
47      TestPolicyBernoulli(randSeed, nArms, new UCB1TunedPolicy());
48      Console.WriteLine("UCB1Normal");
49      TestPolicyBernoulli(randSeed, nArms, new UCBNormalPolicy());
50      Console.WriteLine("Eps(0.01)");
51      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.01));
52      Console.WriteLine("Eps(0.05)");
53      TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.05));
54      //Console.WriteLine("Eps(0.1)");
55      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.1));
56      //Console.WriteLine("Eps(0.2)");
57      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.2));
58      //Console.WriteLine("Eps(0.5)");
59      //TestPolicyBernoulli(randSeed, nArms, new EpsGreedyPolicy(0.5));
60      Console.WriteLine("UCT(0.1)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.1));
61      Console.WriteLine("UCT(0.5)"); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(0.5));
62      Console.WriteLine("UCT(1)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(1));
63      Console.WriteLine("UCT(2)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(2));
64      Console.WriteLine("UCT(5)  "); TestPolicyBernoulli(randSeed, nArms, new UCTPolicy(5));
65      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.1));
66      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(0.5));
67      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(1));
68      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(10));
69      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyBernoulli(randSeed, nArms, new BoltzmannExplorationPolicy(100));
70      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.01));
71      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.05));
72      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyBernoulli(randSeed, nArms, new ChernoffIntervalEstimationPolicy(0.1));
73
74      // not applicable to bernoulli rewards
75      //Console.WriteLine("ThresholdAscent(10, 0.01)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
76      //Console.WriteLine("ThresholdAscent(10, 0.05)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
77      //Console.WriteLine("ThresholdAscent(10, 0.1)   "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
78      //Console.WriteLine("ThresholdAscent(100, 0.01) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
79      //Console.WriteLine("ThresholdAscent(100, 0.05) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
80      //Console.WriteLine("ThresholdAscent(100, 0.1)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
81      //Console.WriteLine("ThresholdAscent(1000, 0.01)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
82      //Console.WriteLine("ThresholdAscent(1000, 0.05)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
83      //Console.WriteLine("ThresholdAscent(1000, 0.1) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
84    }
85
86    [TestMethod]
87    public void ComparePoliciesForGaussianBandit() {
88      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
89
90      var randSeed = 31415;
91      var nArms = 20;
92      Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
93      Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
94      Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
95      Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussian(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
96      Console.WriteLine("Generic Thompson (Gaussian fixed var)"); TestPolicyGaussian(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1)));
97      Console.WriteLine("Generic Thompson (Gaussian unknown var)"); TestPolicyGaussian(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));
98      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
99      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussian(randSeed, nArms, new GaussianThompsonSamplingPolicy());
100
101      /*
102      Console.WriteLine("Random"); TestPolicyNormal(randSeed, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
103      Console.WriteLine("UCB1"); TestPolicyNormal(randSeed, nArms, new UCB1Policy(nArms));
104      Console.WriteLine("UCB1Tuned"); TestPolicyNormal(randSeed, nArms, new UCB1TunedPolicy(nArms));
105      Console.WriteLine("UCB1Normal"); TestPolicyNormal(randSeed, nArms, new UCBNormalPolicy(nArms));
106      //Console.WriteLine("Exp3 (gamma=0.01)");
107      //TestPolicyNormal(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
108      //Console.WriteLine("Exp3 (gamma=0.05)");
109      //TestPolicyNormal(randSeed, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
110      Console.WriteLine("Eps(0.01)"); TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
111      Console.WriteLine("Eps(0.05)"); TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
112      //Console.WriteLine("Eps(0.1)");
113      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.1));
114      //Console.WriteLine("Eps(0.2)");
115      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.2));
116      //Console.WriteLine("Eps(0.5)");
117      //TestPolicyNormal(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
118      Console.WriteLine("UCT(0.1)"); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 0.1));
119      Console.WriteLine("UCT(0.5)"); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 0.5));
120      Console.WriteLine("UCT(1)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 1));
121      Console.WriteLine("UCT(2)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 2));
122      Console.WriteLine("UCT(5)  "); TestPolicyNormal(randSeed, nArms, new UCTPolicy(nArms, 5));
123      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
124      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
125      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
126      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
127      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyNormal(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
128      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
129      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
130      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyNormal(randSeed, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));     
131      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
132      Console.WriteLine("ThresholdAscent(10,0.05)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
133      Console.WriteLine("ThresholdAscent(10,0.1)   "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
134      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
135      Console.WriteLine("ThresholdAscent(100,0.05) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
136      Console.WriteLine("ThresholdAscent(100,0.1)  "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
137      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
138      Console.WriteLine("ThresholdAscent(1000,0.05)"); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
139      Console.WriteLine("ThresholdAscent(1000,0.1) "); TestPolicyNormal(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
140       */
141    }
142
143    [TestMethod]
144    public void ComparePoliciesForGaussianMixtureBandit() {
145      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
146      var randSeed = 31415;
147      var nArms = 20;
148      Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
149      Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
150      Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
151      Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
152      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
153      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy());
154      Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1)));
155      Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));
156
157      /*
158      Console.WriteLine("Random"); TestPolicyGaussianMixture(randSeed, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
159      Console.WriteLine("UCB1"); TestPolicyGaussianMixture(randSeed, nArms, new UCB1Policy(nArms));
160      Console.WriteLine("UCB1Tuned "); TestPolicyGaussianMixture(randSeed, nArms, new UCB1TunedPolicy(nArms));
161      Console.WriteLine("UCB1Normal"); TestPolicyGaussianMixture(randSeed, nArms, new UCBNormalPolicy(nArms));
162      Console.WriteLine("Eps(0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
163      Console.WriteLine("Eps(0.05) "); TestPolicyGaussianMixture(randSeed, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
164      Console.WriteLine("UCT(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 1));
165      Console.WriteLine("UCT(2)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 2));
166      Console.WriteLine("UCT(5)  "); TestPolicyGaussianMixture(randSeed, nArms, new UCTPolicy(nArms, 5));
167      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
168      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
169      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussianMixture(randSeed, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
170
171      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
172      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
173      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
174      Console.WriteLine("ThresholdAscent(10000,0.01)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(nArms, 10000, 0.01));
175       */
176    }
177
178
179    private void TestPolicyBernoulli(int randSeed, int nArms, IBanditPolicy policy) {
180      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new BernoulliBandit(banditRandom, nActions));
181    }
182    private void TestPolicyGaussian(int randSeed, int nArms, IBanditPolicy policy) {
183      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new TruncatedNormalBandit(banditRandom, nActions));
184    }
185    private void TestPolicyGaussianMixture(int randSeed, int nArms, IBanditPolicy policy) {
186      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianMixtureBandit(banditRandom, nActions));
187    }
188    private void TestPolicyGaussianUnknownVariance(int randSeed, int nArms, IBanditPolicy policy) {
189      TestPolicy(randSeed, nArms, policy, (banditRandom, nActions) => new GaussianBandit(banditRandom, nActions));
190    }
191
192
193    private void TestPolicy(int randSeed, int nArms, IBanditPolicy policy, Func<Random, int, IBandit> banditFactory) {
194      var maxIt = 1E5;
195      var reps = 10; // independent runs
196      var regretForIteration = new Dictionary<int, List<double>>();
197      var numberOfPullsOfSuboptimalArmsForExp = new Dictionary<int, double>();
198      var numberOfPullsOfSuboptimalArmsForMax = new Dictionary<int, double>();
199      var globalRandom = new Random(randSeed);
200      var banditRandom = new Random(globalRandom.Next()); // bandits must produce the same rewards for each test
201      var policyRandom = new Random(globalRandom.Next());
202
203      // calculate statistics
204      for (int r = 0; r < reps; r++) {
205        var nextLogStep = 1;
206        var b = banditFactory(banditRandom, nArms);
207        var totalRegret = 0.0;
208        var totalPullsOfSuboptimalArmsExp = 0.0;
209        var totalPullsOfSuboptimalArmsMax = 0.0;
210        var actionInfos = Enumerable.Range(0, nArms).Select(_ => policy.CreateActionInfo()).ToArray();
211        for (int i = 0; i <= maxIt; i++) {
212          var selectedAction = policy.SelectAction(policyRandom, actionInfos);
213          var reward = b.Pull(selectedAction);
214          actionInfos[selectedAction].UpdateReward(reward);
215
216          // collect stats
217          if (selectedAction != b.OptimalExpectedRewardArm) totalPullsOfSuboptimalArmsExp++;
218          if (selectedAction != b.OptimalMaximalRewardArm) totalPullsOfSuboptimalArmsMax++;
219          totalRegret += b.OptimalExpectedReward - reward;
220
221          if (i == nextLogStep) {
222            nextLogStep *= 2;
223            if (!regretForIteration.ContainsKey(i)) {
224              regretForIteration.Add(i, new List<double>());
225            }
226            regretForIteration[i].Add(totalRegret / i);
227
228            if (!numberOfPullsOfSuboptimalArmsForExp.ContainsKey(i)) {
229              numberOfPullsOfSuboptimalArmsForExp.Add(i, 0.0);
230            }
231            numberOfPullsOfSuboptimalArmsForExp[i] += totalPullsOfSuboptimalArmsExp;
232
233            if (!numberOfPullsOfSuboptimalArmsForMax.ContainsKey(i)) {
234              numberOfPullsOfSuboptimalArmsForMax.Add(i, 0.0);
235            }
236            numberOfPullsOfSuboptimalArmsForMax[i] += totalPullsOfSuboptimalArmsMax;
237          }
238        }
239      }
240      // print
241      foreach (var p in regretForIteration.Keys.OrderBy(k => k)) {
242        Console.WriteLine("iter {0,8} regret avg {1,7:F5} min {2,7:F5} max {3,7:F5} suboptimal pulls (exp) {4,7:F2} suboptimal pulls (max) {5,7:F2}",
243          p,
244          regretForIteration[p].Average(),
245          regretForIteration[p].Min(),
246          regretForIteration[p].Max(),
247          numberOfPullsOfSuboptimalArmsForExp[p] / (double)reps,
248          numberOfPullsOfSuboptimalArmsForMax[p] / (double)reps
249          );
250      }
251    }
252
253  }
254}
Note: See TracBrowser for help on using the repository browser.