using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
  /// <summary>
  /// Bandit policy that selects an action by drawing one random value per enabled
  /// action and returning the action with the largest draw.
  /// </summary>
  /// <remarks>
  /// Each draw is obtained from <c>Rand.BetaRand</c> with the arguments
  /// <c>NumSuccess + alpha</c> and <c>NumFailure + beta</c>. <c>Rand.BetaRand</c> is
  /// defined outside this file; presumably it samples a Beta distribution, which
  /// would make this Thompson sampling for Bernoulli rewards (as the class name
  /// suggests) - TODO confirm.
  /// </remarks>
  public class BernoulliThompsonSamplingPolicy : IBanditPolicy {
    // parameters of beta prior distribution
    private readonly double alpha = 1.0;
    private readonly double beta = 1.0;

    /// <summary>
    /// Picks the index of the enabled action whose sampled value is largest.
    /// </summary>
    /// <param name="random">Random number generator forwarded to <c>Rand.BetaRand</c>.</param>
    /// <param name="actionInfos">
    /// Per-action statistics. Only elements of type <c>BernoulliPolicyActionInfo</c>
    /// are considered; all other elements are filtered out before indexing.
    /// NOTE(review): the element type is taken to be <c>IBanditPolicyActionInfo</c>
    /// (the type returned by <see cref="CreateActionInfo"/>) - confirm against
    /// the <c>IBanditPolicy</c> declaration.
    /// </param>
    /// <returns>
    /// The zero-based position of the selected action within the filtered sequence
    /// of <c>BernoulliPolicyActionInfo</c> elements, or -1 if no enabled action
    /// produced a draw greater than negative infinity (only reachable when the
    /// debug assertion is compiled out).
    /// </returns>
    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
      var myActionInfos = actionInfos.OfType<BernoulliPolicyActionInfo>();
      int bestAction = -1;
      double maxTheta = double.NegativeInfinity;
      var aIdx = -1;
      foreach (var aInfo in myActionInfos) {
        // Advance the index before the Disabled check so that disabled actions
        // still occupy their position in the returned index space.
        aIdx++;
        if (aInfo.Disabled) continue;
        var theta = Rand.BetaRand(random, aInfo.NumSuccess + alpha, aInfo.NumFailure + beta);
        // Strict comparison: on equal draws the earlier action is kept.
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = aIdx;
        }
      }
      // Debug-only sanity check that at least one action was selectable.
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    /// <summary>
    /// Creates a fresh statistics object for one action of this policy.
    /// </summary>
    /// <returns>A new <c>BernoulliPolicyActionInfo</c> instance.</returns>
    public IBanditPolicyActionInfo CreateActionInfo() {
      return new BernoulliPolicyActionInfo();
    }

    /// <summary>
    /// Returns the fixed display name of this policy.
    /// </summary>
    /// <returns>The string "BernoulliThompsonSamplingPolicy".</returns>
    public override string ToString() {
      return "BernoulliThompsonSamplingPolicy";
    }
  }
}