[11730] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Diagnostics;
|
---|
| 4 | using System.Linq;
|
---|
| 5 | using System.Text;
|
---|
| 6 | using System.Threading.Tasks;
|
---|
| 7 | using HeuristicLab.Common;
|
---|
| 8 |
|
---|
namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Reward model for a multi-armed bandit with 0/1 rewards. For every action it
  /// keeps a count of observed successes and failures, and samples the action's
  /// expected reward via <c>Rand.BetaRand</c> using those counts offset by the
  /// configured beta prior parameters.
  /// </summary>
  public class BernoulliModel : IModel {
    // Sentinel stored in success[a] to mark action a as disabled.
    private const int DisabledMarker = -1;

    private readonly int numActions;
    private readonly int[] success; // per-action success count, or DisabledMarker
    private readonly int[] failure; // per-action failure count

    // parameters of beta prior distribution
    private readonly double alpha;
    private readonly double beta;

    /// <summary>
    /// Creates a model for <paramref name="numActions"/> actions with all counts at zero.
    /// </summary>
    /// <param name="numActions">Number of actions; determines the length of the count arrays.</param>
    /// <param name="alpha">Value added to the success count when sampling (defaults to 1.0).</param>
    /// <param name="beta">Value added to the failure count when sampling (defaults to 1.0).</param>
    public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) {
      this.numActions = numActions;
      this.success = new int[numActions];
      this.failure = new int[numActions];
      this.alpha = alpha;
      this.beta = beta;
    }

    /// <summary>
    /// Samples one expected-reward value per action.
    /// </summary>
    /// <param name="random">Random source handed to <c>Rand.BetaRand</c>.</param>
    /// <returns>
    /// A newly allocated array of length <c>numActions</c>. Disabled actions get 0.0;
    /// every other action gets the value returned by
    /// <c>Rand.BetaRand(random, success + alpha, failure + beta)</c>.
    /// </returns>
    public double[] SampleExpectedRewards(Random random) {
      // Sample the Bernoulli mean of each action from the beta distribution
      // parameterized by the observed counts plus the prior parameters.
      var theta = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        if (success[a] == DisabledMarker) {
          theta[a] = 0.0;
        } else {
          theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
        }
      }

      // The expected value of a Bernoulli variable is just theta, so the sampled
      // means are returned as-is. theta is local to this call, no copy is needed.
      return theta;
    }

    /// <summary>
    /// Records the observed reward for an action.
    /// </summary>
    /// <param name="action">Index of the action that produced the reward.</param>
    /// <param name="reward">
    /// Expected to be 0.0 or 1.0 (checked with <c>Debug.Assert</c> only). A value
    /// within 1E-6 of 1.0 counts as a success, anything else as a failure.
    /// </param>
    public void Update(int action, double reward) {
      const double EPSILON = 1E-6;
      Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON);

      // A disabled action must stay disabled: incrementing the sentinel would turn
      // -1 into 0 and silently re-enable the action with its history lost.
      if (success[action] == DisabledMarker) return;

      if (Math.Abs(reward - 1.0) < EPSILON) {
        success[action]++;
      } else {
        failure[action]++;
      }
    }

    /// <summary>
    /// Marks an action as disabled by overwriting its success count with the
    /// sentinel. The action then samples as 0.0 and ignores updates until
    /// <see cref="Reset"/> is called.
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public void Disable(int action) {
      success[action] = DisabledMarker;
    }

    /// <summary>
    /// Sets all success and failure counts back to zero. Because the disabled
    /// state is stored in the success counts, this also re-enables every action.
    /// </summary>
    public void Reset() {
      Array.Clear(success, 0, numActions);
      Array.Clear(failure, 0, numActions);
    }

    /// <summary>
    /// Writes the success/failure ratio of every action to the console on a single
    /// line, formatted with two decimals and separated by spaces. No line break is
    /// written at the end.
    /// </summary>
    public void PrintStats() {
      for (int i = 0; i < numActions; i++) {
        // Floating-point division: a failure count of zero yields an infinity or
        // NaN instead of throwing. Disabled actions show a negative value because
        // their success slot holds the sentinel.
        Console.Write("{0:F2} ", success[i] / (double)failure[i]);
      }
    }
  }
}
|
---|