using System;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that keeps, per action, the running sum of rewards, the running sum of
  /// squared rewards and the number of tries, and selects the action whose randomly drawn
  /// score (empirical mean plus a scaled random term) is largest.
  /// </summary>
  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    private readonly double[] sumRewards;    // per-action sum of observed rewards
    private readonly double[] sumSqrRewards; // per-action sum of squared observed rewards
    private readonly int[] tries;            // per-action update count; set to -1 by DisableAction

    /// <summary>
    /// Creates a policy with zeroed statistics for <paramref name="numActions"/> actions.
    /// </summary>
    /// <param name="random">Random source handed to <c>Rand.RandNormal</c> on every draw.</param>
    /// <param name="numActions">Number of actions; forwarded to the base class and used to size the statistics arrays.</param>
    public GaussianThompsonSamplingPolicy(Random random, int numActions)
      : base(numActions) {
      this.random = random;
      this.sumRewards = new double[numActions];
      this.sumSqrRewards = new double[numActions];
      this.tries = new int[numActions];
    }

    /// <summary>
    /// Picks the next action to play.
    /// </summary>
    /// <returns>
    /// The first action in <c>Actions</c> that has zero tries, if one is encountered;
    /// otherwise the action with the largest sampled score. Returns -1 only when no
    /// action produced a score greater than negative infinity (e.g. <c>Actions</c> is empty).
    /// </returns>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var maxTheta = double.NegativeInfinity;
      int bestAction = -1;
      foreach (var a in Actions) {
        // An action without any observation is played immediately.
        if (tries[a] == 0) return a;

        var mu = sumRewards[a] / tries[a];

        // E[x^2] - E[x]^2 is mathematically >= 0, but rounding can push it slightly below
        // zero when the rewards are (nearly) identical. Math.Sqrt would then yield NaN,
        // theta would be NaN, and "theta > maxTheta" would never hold for this action.
        // Clamp at zero so the action stays selectable.
        var variance = Math.Max(0.0, sumSqrRewards[a] / tries[a] - mu * mu);
        var stdDev = Math.Sqrt(variance);

        // NOTE(review): Rand.RandNormal presumably returns a standard-normal sample - confirm.
        var theta = Rand.RandNormal(random) * stdDev + mu;
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = a;
        }
      }
      return bestAction;
    }

    /// <summary>
    /// Records one observed reward for <paramref name="action"/>.
    /// </summary>
    /// <param name="action">Index of the action that was played; asserted (debug builds) to be contained in <c>Actions</c>.</param>
    /// <param name="reward">Observed reward; added to the running sum and, squared, to the running sum of squares.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      sumRewards[action] += reward;
      sumSqrRewards[action] += reward * reward;
      tries[action]++;
    }

    /// <summary>
    /// Disables <paramref name="action"/> via the base class and discards its statistics.
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      sumRewards[action] = 0;
      sumSqrRewards[action] = 0;
      // -1 (rather than 0) keeps the slot distinguishable from an untried action.
      // NOTE(review): assumes the base class removes the action from Actions - confirm.
      tries[action] = -1;
    }

    /// <summary>
    /// Resets the base class state and zeroes all per-action statistics, including the
    /// try counters of previously disabled actions.
    /// </summary>
    public override void Reset() {
      base.Reset();
      Array.Clear(sumRewards, 0, sumRewards.Length);
      Array.Clear(sumSqrRewards, 0, sumSqrRewards.Length);
      Array.Clear(tries, 0, tries.Length);
    }
  }
}