using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {

  public class GaussianThompsonSamplingPolicy : IPolicy {
    private readonly bool compatibility;

    // assumes a Gaussian reward distribution with a different mean for each action but the same, known variance
    // the prior for the mean is also Gaussian, with the following parameters
    private readonly double rewardVariance = 0.1; // we assume a known variance

    private readonly double priorMean = 0.5;
    private readonly double priorVariance = 1;

    public GaussianThompsonSamplingPolicy(bool compatibility = false) {
      this.compatibility = compatibility;
    }

    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>();
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;

      int aIdx = -1;
      foreach (var aInfo in myActionInfos) {
        aIdx++;
        if (aInfo.Disabled) continue;

        var tries = aInfo.Tries;
        var sampleMean = aInfo.AvgReward;
        var sampleVariance = aInfo.RewardVariance;

        double theta;
        if (compatibility) {
          // legacy behavior: sample directly from N(sampleMean, sampleVariance),
          // ignoring the prior; needs at least two tries for a variance estimate
          if (tries < 2) return aIdx;
          var mu = sampleMean;
          var variance = sampleVariance;
          var stdDev = Math.Sqrt(variance);
          theta = Rand.RandNormal(random) * stdDev + mu;
        } else {
          // calculate posterior mean and variance (for the mean reward);
          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution
          // (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
          var posteriorVariance = 1.0 / (tries / rewardVariance + 1.0 / priorVariance);
          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries * sampleMean / rewardVariance);
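
          // For reference: with n = tries observations of sample mean xbar, known reward
          // variance s2 = rewardVariance, and a prior N(m0, v0) on the unknown mean, the
          // posterior over the mean is Gaussian with
          //   variance = 1 / (n/s2 + 1/v0)
          //   mean     = variance * (m0/v0 + n*xbar/s2)
          // i.e. a precision-weighted combination of the prior mean and the sample mean.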

          // sample a mean from the posterior (location-scale transform of a standard normal draw)
          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;

          // theta already represents the expected reward value => nothing else to do
        }

        if (theta > bestQ) {
          bestQ = theta;
          bestAction = aIdx;
        }
      }
      Debug.Assert(bestAction > -1); // holds as long as at least one action is enabled
      return bestAction;
    }

    public IPolicyActionInfo CreateActionInfo() {
      return new MeanAndVariancePolicyActionInfo();
    }

    // superseded by MeanAndVariancePolicyActionInfo, which tracks the sampling
    // statistics per action; kept for reference (Welford's online update, with the
    // erroneous double-accumulation of sampleM2 corrected):
    //public override void UpdateReward(int action, double reward) {
    //  Debug.Assert(Actions.Contains(action));
    //  tries[action]++;
    //  var delta = reward - sampleMean[action];
    //  sampleMean[action] += delta / tries[action];
    //  sampleM2[action] += delta * (reward - sampleMean[action]);
    //}

    public override string ToString() {
      return "GaussianThompsonSamplingPolicy";
    }
  }
}
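
// A minimal usage sketch (not part of the original file), showing how the policy is
// meant to be driven. It assumes a hypothetical PullArm environment function and that
// MeanAndVariancePolicyActionInfo exposes an UpdateReward(double) method, which is
// only hinted at by the commented-out code above and may differ in practice:
//
//   var random = new Random(31415);
//   var policy = new GaussianThompsonSamplingPolicy();
//   var actionInfos = Enumerable.Range(0, nArms)
//                               .Select(_ => policy.CreateActionInfo())
//                               .ToList();
//   for (int step = 0; step < nSteps; step++) {
//     int a = policy.SelectAction(random, actionInfos);
//     double reward = PullArm(a); // hypothetical environment call
//     ((MeanAndVariancePolicyActionInfo)actionInfos[a]).UpdateReward(reward);
//   }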