1  using System;


2  using System.Collections.Generic;


3  using System.Diagnostics;


4  using System.Linq;


5  using System.Text;


6  using System.Threading.Tasks;


7  using HeuristicLab.Common;


8 


namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Bandit model for actions with 0/1 (Bernoulli) rewards. For every action it keeps
  /// a count of observed successes and failures and combines them with the
  /// (alpha, beta) parameters of a beta prior when sampling expected rewards.
  /// </summary>
  public class BernoulliModel : IModel {
    // Marker stored in success[a] for a disabled action. It must be a value that a
    // real success count can never take; counts start at 0 and are only incremented.
    private const int DisabledMarker = -1;

    private readonly int numActions;
    private readonly int[] success;
    private readonly int[] failure;

    // parameters of beta prior distribution
    private readonly double alpha;
    private readonly double beta;

    /// <summary>
    /// Creates a model for <paramref name="numActions"/> actions with all counts at zero.
    /// </summary>
    /// <param name="numActions">Number of actions (length of the count arrays).</param>
    /// <param name="alpha">First parameter of the beta prior; added to the success count when sampling.</param>
    /// <param name="beta">Second parameter of the beta prior; added to the failure count when sampling.</param>
    public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) {
      this.numActions = numActions;
      this.success = new int[numActions];
      this.failure = new int[numActions];
      this.alpha = alpha;
      this.beta = beta;
    }

    /// <summary>
    /// Samples one expected reward per action.
    /// </summary>
    /// <param name="random">Random source handed to <c>Rand.BetaRand</c>.</param>
    /// <returns>
    /// A new array of length <c>numActions</c>. Disabled actions get 0.0; every other
    /// action gets the value returned by
    /// <c>Rand.BetaRand(random, success + alpha, failure + beta)</c>.
    /// </returns>
    public double[] SampleExpectedRewards(Random random) {
      // sample bernoulli mean from the beta distribution given by prior + observed counts
      // NOTE(review): Rand.BetaRand is defined outside this file; presumably it draws
      // from Beta(a, b) - confirm.
      var theta = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        if (success[a] == DisabledMarker) {
          theta[a] = 0.0;
        } else {
          theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
        }
      }

      // no need to sample we know the exact expected value
      // the expected value of a bernoulli variable is just theta
      // (theta is freshly allocated above, so it can be returned without copying)
      return theta;
    }

    /// <summary>
    /// Records one observed reward for <paramref name="action"/>.
    /// </summary>
    /// <param name="action">Index of the action that was played.</param>
    /// <param name="reward">
    /// Observed reward; expected to be 0.0 or 1.0 (checked by a debug assertion).
    /// A reward within the tolerance of 1.0 counts as a success, anything else as a failure.
    /// </param>
    public void Update(int action, double reward) {
      const double EPSILON = 1E-6;
      Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON);
      // NOTE(review): updating a disabled action with reward 1.0 increments the marker
      // from -1 to 0 and thereby re-enables the action - confirm callers never update
      // disabled actions.
      if (Math.Abs(reward - 1.0) < EPSILON) {
        success[action]++;
      } else {
        failure[action]++;
      }
    }

    /// <summary>
    /// Marks <paramref name="action"/> as disabled so that
    /// <see cref="SampleExpectedRewards"/> returns 0.0 for it.
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public void Disable(int action) {
      success[action] = DisabledMarker;
    }

    /// <summary>
    /// Clears all success and failure counts. Because the disabled marker lives in the
    /// success array, this also re-enables every disabled action.
    /// </summary>
    public void Reset() {
      Array.Clear(success, 0, numActions);
      Array.Clear(failure, 0, numActions);
    }

    /// <summary>
    /// Writes the success/failure ratio of every action to the console, formatted with
    /// two decimals and separated by spaces.
    /// </summary>
    public void PrintStats() {
      for (int i = 0; i < numActions; i++) {
        // floating-point division: a zero failure count does not throw, it yields
        // Infinity or NaN
        Console.Write("{0:F2} ", success[i] / (double)failure[i]);
      }
    }
  }
}

