[11727] | 1 | using System;
|
---|
| 2 | using System.Diagnostics;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 |
|
---|
| 6 | namespace HeuristicLab.Algorithms.Bandits {
|
---|
[11730] | 7 |
|
---|
// Thompson sampling bandit policy that models the reward of each action as Gaussian.
//
// Two modes are supported (selected by the constructor flag 'compatibility'):
//  - compatibility == false: the reward variance is assumed to be known (rewardVariance) and the
//    unknown mean of each action has a Gaussian prior (priorMean, priorVariance). An action is
//    chosen by sampling a mean from each action's posterior and taking the largest sample.
//  - compatibility == true: each action is first tried twice; afterwards a value is sampled from
//    a Gaussian parameterized by the empirical mean and empirical variance of the observed rewards.
public class GaussianThompsonSamplingPolicy : BanditPolicy {
  private readonly Random random;

  // Per-action running statistics, updated incrementally in UpdateReward (Welford's algorithm).
  private readonly double[] sampleMean; // running mean of the observed rewards
  private readonly double[] sampleM2;   // running sum of squared deviations from the running mean
  private readonly int[] tries;         // number of observed rewards; -1 marks a disabled action
  private readonly bool compatibility;

  // assumes a Gaussian reward distribution with different means but the same variances for each action
  // the prior for the mean is also Gaussian with the following parameters
  private readonly double rewardVariance = 0.1; // we assume a known variance

  private readonly double priorMean = 0.5;
  private readonly double priorVariance = 1;

  // random:        source of randomness used for sampling
  // numActions:    number of actions (arms); forwarded to the base class and used to size the statistics arrays
  // compatibility: when true, sample from N(sampleMean, empirical variance) instead of the conjugate posterior
  public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
    : base(numActions) {
    this.random = random;
    this.sampleMean = new double[numActions];
    this.sampleM2 = new double[numActions];
    this.tries = new int[numActions];
    this.compatibility = compatibility;
  }

  // Samples a value (theta) for every enabled action and returns the action with the largest sample.
  // In compatibility mode an action with fewer than two observations is returned immediately.
  // Returns -1 if no action produced a sample greater than negative infinity (guarded only by Debug.Assert).
  public override int SelectAction() {
    Debug.Assert(Actions.Any());
    var maxTheta = double.NegativeInfinity;
    int bestAction = -1;
    foreach (var a in Actions) {
      if (tries[a] == -1) continue; // skip disabled actions
      double theta;
      if (compatibility) {
        // at least two observations are required before the empirical variance is used
        if (tries[a] < 2) return a;
        var mu = sampleMean[a];
        var variance = sampleM2[a] / tries[a]; // empirical (population) variance
        var stdDev = Math.Sqrt(variance);
        // NOTE(review): Rand.RandNormal is presumably a standard normal draw; it is scaled and shifted here
        theta = Rand.RandNormal(random) * stdDev + mu;
      } else {
        // calculate posterior mean and variance (for mean reward)

        // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
        var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
        var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);

        // sample a mean from the posterior
        theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;

        // theta already represents the expected reward value => nothing else to do
      }
      if (theta > maxTheta) {
        maxTheta = theta;
        bestAction = a;
      }
    }
    Debug.Assert(Actions.Contains(bestAction));
    return bestAction;
  }

  // Incorporates one observed reward for the given action into the running statistics.
  // The action is expected to be contained in Actions (checked only by Debug.Assert).
  public override void UpdateReward(int action, double reward) {
    Debug.Assert(Actions.Contains(action));
    tries[action]++;
    // Welford's online update: delta uses the old mean, the second factor uses the new mean
    var delta = reward - sampleMean[action];
    sampleMean[action] += delta / tries[action];
    // M2 must only grow by the new squared-deviation term; adding the previous M2 again
    // would double the accumulator on every update and inflate the variance exponentially.
    sampleM2[action] += delta * (reward - sampleMean[action]);
  }

  // Disables the action in the base class, clears its statistics and marks it with tries == -1
  // so that SelectAction and PrintStats skip it.
  public override void DisableAction(int action) {
    base.DisableAction(action);
    sampleMean[action] = 0;
    sampleM2[action] = 0;
    tries[action] = -1;
  }

  // Resets the base class and zeroes all statistics; this also clears the -1 markers of disabled actions.
  public override void Reset() {
    base.Reset();
    Array.Clear(sampleMean, 0, sampleMean.Length);
    Array.Clear(sampleM2, 0, sampleM2.Length);
    Array.Clear(tries, 0, tries.Length);
  }

  // Writes the mean reward and the number of tries of every enabled action to the console
  // (disabled actions are printed as blank padding), followed by a line break.
  public override void PrintStats() {
    for (int i = 0; i < sampleMean.Length; i++) {
      if (tries[i] >= 0) {
        // sampleMean already holds the average reward, so it must not be divided by tries again
        // (that division also produced NaN for actions that have not been tried yet)
        Console.Write(" {0,5:F2} {1}", sampleMean[i], tries[i]);
      } else {
        Console.Write("{0,5}", "");
      }
    }
    Console.WriteLine();
  }

  public override string ToString() {
    return "GaussianThompsonSamplingPolicy";
  }
}
|
---|
| 103 | }
|
---|