using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that chooses an action by asking an <see cref="IModel"/> for one
  /// sample of the expected reward per action and taking the action with the largest sample.
  /// All reward bookkeeping is delegated to the model.
  /// </summary>
  public class GenericThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    private readonly IModel model;

    /// <summary>
    /// Creates the policy.
    /// </summary>
    /// <param name="random">Random source handed to the model whenever rewards are sampled.</param>
    /// <param name="numActions">Number of actions; forwarded to the base class constructor.</param>
    /// <param name="model">Model that samples expected rewards and receives reward updates.</param>
    public GenericThompsonSamplingPolicy(Random random, int numActions, IModel model)
      : base(numActions) {
      this.random = random;
      this.model = model;
    }

    /// <summary>
    /// Draws one set of expected-reward samples from the model and returns the entry of
    /// <c>Actions</c> with the largest sample.
    /// </summary>
    /// <returns>
    /// The selected action. Ties go to the action enumerated first because the comparison
    /// is strict. If no sample is greater than negative infinity (for instance when
    /// <c>Actions</c> is empty and assertions are compiled out), -1 is returned.
    /// </returns>
    public override int SelectAction() {
      // Only checked in builds where Debug assertions are active.
      Debug.Assert(Actions.Any());

      var sampledRewards = model.SampleExpectedRewards(random);

      int chosen = -1;
      double highestSample = double.NegativeInfinity;
      foreach (var candidate in Actions) {
        var sample = sampledRewards[candidate];
        // Strict comparison: an equal (or NaN) sample never replaces the current choice.
        if (sample > highestSample) {
          chosen = candidate;
          highestSample = sample;
        }
      }

      return chosen;
    }

    /// <summary>
    /// Passes an observed reward for the given action on to the model.
    /// </summary>
    /// <param name="action">Action the reward belongs to; asserted (debug builds) to be in <c>Actions</c>.</param>
    /// <param name="reward">Observed reward value.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      model.Update(action, reward);
    }

    /// <summary>
    /// Disables the action in the base policy first and then in the model.
    /// </summary>
    /// <param name="action">Action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      model.Disable(action);
    }

    /// <summary>
    /// Resets the base policy first and then the model.
    /// </summary>
    public override void Reset() {
      base.Reset();
      model.Reset();
    }

    /// <summary>
    /// Delegates statistics output to the model.
    /// </summary>
    public override void PrintStats() {
      model.PrintStats();
    }

    /// <summary>
    /// Returns the policy name followed by the model's formatted representation in parentheses.
    /// </summary>
    public override string ToString() {
      return string.Format("GenericThompsonSamplingPolicy({0})", model);
    }
  }
}