using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Beta-Bernoulli posterior model over a fixed set of actions (bandit arms).
  /// Maintains per-action success/failure counts on top of a Beta(alpha, beta)
  /// prior and supports posterior sampling (Thompson sampling style) via
  /// <see cref="SampleExpectedRewards"/>.
  /// </summary>
  public class BernoulliModel : IModel {
    private readonly int numActions;
    // Per-action observation counts. NOTE(review): success[a] == -1 is a magic
    // sentinel meaning "action a is disabled" (see Disable); calling Update on a
    // disabled action would silently corrupt that sentinel — confirm callers
    // never do this.
    private readonly int[] success;
    private readonly int[] failure;

    // Parameters of the Beta prior distribution.
    private readonly double alpha;
    private readonly double beta;

    /// <summary>
    /// Creates a Bernoulli model for <paramref name="numActions"/> actions with
    /// a Beta(<paramref name="alpha"/>, <paramref name="beta"/>) prior
    /// (default Beta(1,1) = uniform).
    /// </summary>
    public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) {
      this.numActions = numActions;
      this.success = new int[numActions];
      this.failure = new int[numActions];
      this.alpha = alpha;
      this.beta = beta;
    }

    /// <summary>
    /// Draws one sample of the expected reward for each action from the Beta
    /// posterior Beta(success + alpha, failure + beta). Disabled actions
    /// (sentinel success == -1) get an expected reward of 0.0.
    /// </summary>
    /// <returns>An array of length numActions with one posterior draw per action.</returns>
    public double[] SampleExpectedRewards(Random random) {
      var theta = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        if (success[a] == -1) {
          theta[a] = 0.0; // disabled action: never selected
        } else {
          theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
        }
      }
      // The expected value of a Bernoulli variable is exactly theta, so the
      // sampled theta values ARE the sampled expected rewards — no further
      // sampling (and, unlike the original, no redundant identity-Select copy)
      // is needed.
      return theta;
    }

    /// <summary>
    /// Records one observed reward for <paramref name="action"/>. The reward
    /// must be (approximately) 0.0 or 1.0; values within 1E-6 of 1.0 count as
    /// success, everything else (asserted to be near 0.0) as failure.
    /// </summary>
    public void Update(int action, double reward) {
      const double EPSILON = 1E-6;
      Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON);
      if (Math.Abs(reward - 1.0) < EPSILON) {
        success[action]++;
      } else {
        failure[action]++;
      }
    }

    /// <summary>
    /// Permanently disables an action by setting the -1 sentinel on its success
    /// count; SampleExpectedRewards then always returns 0.0 for it.
    /// NOTE(review): a subsequent Reset clears the sentinel and re-enables the
    /// action — confirm that is the intended semantics.
    /// </summary>
    public void Disable(int action) {
      success[action] = -1;
    }

    /// <summary>Resets all observation counts (and any disabled-action sentinels) to zero.</summary>
    public void Reset() {
      Array.Clear(success, 0, numActions);
      Array.Clear(failure, 0, numActions);
    }

    /// <summary>
    /// Writes per-action diagnostics to the console.
    /// NOTE(review): this prints success/failure (not the empirical mean
    /// success/(success+failure)) and yields Infinity/NaN when failure == 0 —
    /// output kept as-is for compatibility; confirm whether a mean was intended.
    /// </summary>
    public void PrintStats() {
      for (int i = 0; i < numActions; i++) {
        Console.Write("{0:F2} ", success[i] / (double)failure[i]);
      }
    }
  }
}