using System;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {
  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    private readonly double[] sampleMean;
    private readonly double[] sampleM2;
    private readonly int[] tries;
    private readonly bool compatibility;

    // assumes a Gaussian reward distribution with different means but the same variance for each action
    // the prior for the mean is also Gaussian with the following parameters
    private readonly double rewardVariance = 0.1; // we assume a known variance
    private readonly double priorMean = 0.5;
    private readonly double priorVariance = 1;

    public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
      : base(numActions) {
      this.random = random;
      this.sampleMean = new double[numActions];
      this.sampleM2 = new double[numActions];
      this.tries = new int[numActions];
      this.compatibility = compatibility;
    }

    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var maxTheta = double.NegativeInfinity;
      int bestAction = -1;
      foreach (var a in Actions) {
        if (tries[a] == -1) continue; // skip disabled actions
        double theta;
        if (compatibility) {
          // compatibility mode: sample from a Gaussian fitted directly to the observed rewards
          if (tries[a] < 2) return a; // play each action at least twice so the variance estimate is usable
          var mu = sampleMean[a];
          var variance = sampleM2[a] / tries[a];
          var stdDev = Math.Sqrt(variance);
          theta = Rand.RandNormal(random) * stdDev + mu;
        } else {
          // calculate posterior mean and variance (for the mean reward)
          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution
          // (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
          var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);

          // sample a mean from the posterior
          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;
          // theta already represents the expected reward value => nothing else to do
        }
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = a;
        }
      }
      Debug.Assert(Actions.Contains(bestAction));
      return bestAction;
    }

    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      // online (Welford-style) update of the sample mean and the sum of squared deviations
      tries[action]++;
      var delta = reward - sampleMean[action];
      sampleMean[action] += delta / tries[action];
      sampleM2[action] += delta * (reward - sampleMean[action]);
    }

    public override void DisableAction(int action) {
      base.DisableAction(action);
      sampleMean[action] = 0;
      sampleM2[action] = 0;
      tries[action] = -1;
    }

    public override void Reset() {
      base.Reset();
      Array.Clear(sampleMean, 0, sampleMean.Length);
      Array.Clear(sampleM2, 0, sampleM2.Length);
      Array.Clear(tries, 0, tries.Length);
    }

    public override void PrintStats() {
      for (int i = 0; i < sampleMean.Length; i++) {
        if (tries[i] >= 0) {
          // sampleMean already holds the running mean, so it is printed directly
          Console.Write(" {0,5:F2} {1}", sampleMean[i], tries[i]);
        } else {
          Console.Write("{0,5}", "");
        }
      }
      Console.WriteLine();
    }

    public override string ToString() {
      return "GaussianThompsonSamplingPolicy";
    }
  }
}
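
// --- Usage sketch (not part of the original policy) ---
// A minimal example of how this policy might be driven on a simulated three-armed
// Gaussian bandit. It assumes that the BanditPolicy base class enables actions
// 0..numActions-1 by default and that Rand.RandNormal(random) returns a standard
// normal sample, as it is used above. The demo class name, seed, and arm means
// below are purely illustrative.
namespace HeuristicLab.Algorithms.Bandits.Examples {
  using System;
  using HeuristicLab.Common;

  public static class GaussianThompsonSamplingDemo {
    public static void Run() {
      var random = new Random(31415);
      var policy = new GaussianThompsonSamplingPolicy(random, numActions: 3);

      // hypothetical true arm means; the reward noise matches the policy's assumed rewardVariance of 0.1
      var trueMeans = new[] { 0.2, 0.5, 0.8 };
      var rewardStdDev = Math.Sqrt(0.1);

      for (int step = 0; step < 1000; step++) {
        int action = policy.SelectAction();
        double reward = trueMeans[action] + Rand.RandNormal(random) * rewardStdDev;
        policy.UpdateReward(action, reward);
      }

      // after enough pulls the policy should mostly select the arm with mean 0.8
      policy.PrintStats();
    }
  }
}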