using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Thompson-sampling bandit policy for Bernoulli rewards. Per action it keeps
  /// success/failure counts and, on selection, draws one sample from the posterior
  /// Beta(successes + alpha, failures + beta); the action with the largest sample wins.
  /// </summary>
  public class BernoulliThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    private readonly int[] successCount;
    private readonly int[] failureCount;

    // Parameters of the Beta prior; Beta(1, 1) is the uniform prior.
    private readonly double priorAlpha = 1.0;
    private readonly double priorBeta = 1.0;

    public BernoulliThompsonSamplingPolicy(Random random, int numActions)
      : base(numActions) {
      this.random = random;
      successCount = new int[numActions];
      failureCount = new int[numActions];
    }

    /// <summary>
    /// Samples each enabled action's posterior Beta distribution once and returns
    /// the index of the action whose sample is largest.
    /// </summary>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var selected = -1;
      var bestSample = double.NegativeInfinity;
      foreach (var action in Actions) {
        var sample = Rand.BetaRand(random, successCount[action] + priorAlpha, failureCount[action] + priorBeta);
        if (sample > bestSample) {
          bestSample = sample;
          selected = action;
        }
      }
      return selected;
    }

    /// <summary>
    /// Records an observed reward for <paramref name="action"/>: a strictly positive
    /// reward counts as a success, anything else as a failure.
    /// </summary>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      if (reward > 0) {
        successCount[action]++;
      } else {
        failureCount[action]++;
      }
    }

    public override void DisableAction(int action) {
      base.DisableAction(action);
      // -1 marks the disabled action's statistics as invalid.
      // NOTE(review): presumably base.DisableAction removes the action from Actions,
      // so this slot is never sampled again until Reset zeroes it — confirm in BanditPolicy.
      successCount[action] = -1;
    }

    public override void Reset() {
      base.Reset();
      Array.Clear(successCount, 0, successCount.Length);
      Array.Clear(failureCount, 0, failureCount.Length);
    }
  }
}