# source:branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/BernoulliModel.cs@11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 2.0 KB
Line
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
7using HeuristicLab.Common;
8
namespace HeuristicLab.Algorithms.Bandits.Models {
  /// <summary>
  /// Reward model for a multi-armed bandit whose rewards are 0 or 1.
  /// For every action the model counts observed successes and failures and
  /// combines them with the (alpha, beta) prior parameters when sampling.
  /// </summary>
  public class BernoulliModel : IModel {
    // Sentinel stored in success[a] to mark action a as disabled.
    private const int DisabledMarker = -1;

    private readonly int numActions;
    private readonly int[] success;
    private readonly int[] failure;

    // parameters of beta prior distribution
    private readonly double alpha;
    private readonly double beta;

    /// <summary>
    /// Creates a model for <paramref name="numActions"/> actions with zeroed counts.
    /// </summary>
    /// <param name="numActions">Number of actions; valid action indices are 0 .. numActions - 1.</param>
    /// <param name="alpha">First prior parameter, added to the success count when sampling.</param>
    /// <param name="beta">Second prior parameter, added to the failure count when sampling.</param>
    public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) {
      this.numActions = numActions;
      this.success = new int[numActions];
      this.failure = new int[numActions];
      this.alpha = alpha;
      this.beta = beta;
    }

    /// <summary>
    /// Samples one expected reward per action.
    /// </summary>
    /// <param name="random">Random number generator handed to <c>Rand.BetaRand</c>.</param>
    /// <returns>
    /// A new array of length numActions. Disabled actions get 0.0; every other
    /// entry is the value returned by
    /// <c>Rand.BetaRand(random, success + alpha, failure + beta)</c>.
    /// </returns>
    public double[] SampleExpectedRewards(Random random) {
      // Sample the Bernoulli mean from the posterior (prior parameters plus observed counts).
      var theta = new double[numActions];
      for (int a = 0; a < numActions; a++) {
        if (IsDisabled(a)) {
          theta[a] = 0.0;
        } else {
          theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
        }
      }

      // No need to sample rewards: the expected value of a Bernoulli variable
      // is just theta. The array is freshly allocated, so it is returned as is.
      return theta;
    }

    /// <summary>
    /// Records one observed reward for <paramref name="action"/>.
    /// </summary>
    /// <param name="action">Index of the action that produced the reward.</param>
    /// <param name="reward">
    /// Expected to be 0 or 1 (checked by a debug assertion only). A value within
    /// 1E-6 of 1.0 counts as a success, anything else as a failure.
    /// </param>
    public void Update(int action, double reward) {
      const double Epsilon = 1E-6;
      Debug.Assert(Math.Abs(reward - 0.0) < Epsilon || Math.Abs(reward - 1.0) < Epsilon);

      // A disabled action is marked by a sentinel in its success counter.
      // Incrementing that counter would turn the sentinel into 0 and silently
      // re-enable the action, so updates to disabled actions are ignored.
      if (IsDisabled(action)) {
        return;
      }

      if (Math.Abs(reward - 1.0) < Epsilon) {
        success[action]++;
      } else {
        failure[action]++;
      }
    }

    /// <summary>
    /// Disables <paramref name="action"/>: it is sampled as 0.0 and ignores
    /// updates until <see cref="Reset"/> is called. The success count of the
    /// action is overwritten by the marker and therefore lost.
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public void Disable(int action) {
      success[action] = DisabledMarker;
    }

    /// <summary>
    /// Sets all success and failure counts back to zero. Because the disabled
    /// marker lives in the success counter, this also re-enables every action.
    /// </summary>
    public void Reset() {
      Array.Clear(success, 0, numActions);
      Array.Clear(failure, 0, numActions);
    }

    /// <summary>
    /// Writes the success/failure ratio of every action to the console with two
    /// decimals, separated by spaces and without a trailing newline.
    /// </summary>
    /// <remarks>
    /// This is a floating-point division: an action without failures is printed
    /// as infinity (or NaN when it has no successes either), and a disabled
    /// action is printed with a negative ratio because of the marker value.
    /// </remarks>
    public void PrintStats() {
      for (int i = 0; i < numActions; i++) {
        Console.Write("{0:F2} ", success[i] / (double)failure[i]);
      }
    }

    // True when the action has been switched off via Disable.
    private bool IsDisabled(int action) {
      return success[action] == DisabledMarker;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.