using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {

  [Obsolete("Replaced by GenericThompsonSamplingPolicy(GaussianModel(0.5, 1.0, 0.1))")]
  public class GaussianThompsonSamplingPolicy : IBanditPolicy {
    private bool compatibility;

    // assumes a Gaussian reward distribution with different means but the same variances for each action
    // the prior for the mean is also Gaussian with the following parameters
    private readonly double rewardVariance = 0.1; // we assume a known variance

    private readonly double priorMean = 0.5;
    private readonly double priorVariance = 1;


    public GaussianThompsonSamplingPolicy(bool compatibility = false) {
      this.compatibility = compatibility;
    }

    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>();
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;

      int aIdx = -1;
      foreach (var aInfo in myActionInfos) {
        aIdx++;

        var tries = aInfo.Tries;
        var sampleMean = aInfo.AvgReward;
        var sampleVariance = aInfo.RewardVariance;

        double theta;
        if (compatibility) {
          // old code used for old experiments (preserved because it performed very well)
          if (tries < 2) return aIdx;
          var mu = sampleMean;
          var variance = sampleVariance;
          var stdDev = Math.Sqrt(variance);
          theta = Rand.RandNormal(random) * stdDev + mu;
        } else {
          // calculate posterior mean and variance (for mean reward)

          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
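          // with known reward variance sigma^2, prior N(mu0, tau0^2), and n observed rewards with sample mean xbar:
          //   posterior variance: tau_n^2 = 1 / (n / sigma^2 + 1 / tau0^2)
          //   posterior mean:     mu_n    = tau_n^2 * (mu0 / tau0^2 + n * xbar / sigma^2)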
          var posteriorVariance = 1.0 / (tries / rewardVariance + 1.0 / priorVariance);
          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries * sampleMean / rewardVariance);

          // sample a mean from the posterior
          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;

          // theta already represents the expected reward value => nothing else to do
        }

        // ties between theta values are very unlikely (and we don't care how they are broken)
        if (theta > bestQ) {
          bestQ = theta;
          bestAction = aIdx;
        }
      }
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    public IBanditPolicyActionInfo CreateActionInfo() {
      return new MeanAndVariancePolicyActionInfo();
    }


    //public override void UpdateReward(int action, double reward) {
    //  Debug.Assert(Actions.Contains(action));
    //  tries[action]++;
    //  var delta = reward - sampleMean[action];
    //  sampleMean[action] += delta / tries[action];
    //  sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]);
    //}

    public override string ToString() {
      return "GaussianThompsonSamplingPolicy";
    }
  }
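
  // --------------------------------------------------------------------------------------------
  // Illustrative sketch (not part of the original HeuristicLab sources): a self-contained
  // Thompson-sampling loop over Gaussian arms using the same posterior formulas as the
  // non-compatibility branch above. All names in this class (GaussianThompsonSamplingSketch,
  // the arm means, the Box-Muller sampler standing in for Rand.RandNormal) are hypothetical
  // and exist only to illustrate the math.
  internal static class GaussianThompsonSamplingSketch {
    public static int[] Run(Random random, double[] trueMeans, int steps = 1000) {
      const double rewardVariance = 0.1;   // assumed known observation variance
      const double priorMean = 0.5;        // prior N(priorMean, priorVariance) on each arm's mean reward
      const double priorVariance = 1.0;

      int k = trueMeans.Length;
      var tries = new int[k];
      var sumReward = new double[k];

      for (int step = 0; step < steps; step++) {
        // sample a mean reward for each arm from its posterior and pick the arg-max
        int best = -1;
        double bestTheta = double.NegativeInfinity;
        for (int a = 0; a < k; a++) {
          double posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
          double sampleMean = tries[a] > 0 ? sumReward[a] / tries[a] : 0.0;
          double posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean / rewardVariance);
          double theta = StdNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;
          if (theta > bestTheta) { bestTheta = theta; best = a; }
        }

        // pull the chosen arm and record a noisy reward
        double reward = trueMeans[best] + StdNormal(random) * Math.Sqrt(rewardVariance);
        tries[best]++;
        sumReward[best] += reward;
      }
      return tries; // pull counts; most pulls should concentrate on the arm with the largest true mean
    }

    // Box-Muller transform; only a stand-in for Rand.RandNormal from HeuristicLab.Common
    private static double StdNormal(Random random) {
      double u1 = 1.0 - random.NextDouble();
      double u2 = random.NextDouble();
      return Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2);
    }
  }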
}