Changeset 11732 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/BernoulliModel.cs
- Timestamp:
- 01/07/15 09:21:46 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/BernoulliModel.cs
r11730 r11732 9 9 namespace HeuristicLab.Algorithms.Bandits.Models { 10 10 public class BernoulliModel : IModel { 11 private readonly int numActions; 12 private readonly int[] success; 13 private readonly int[] failure; 11 private int success; 12 private int failure; 14 13 15 14 // parameters of beta prior distribution … … 17 16 private readonly double beta; 18 17 19 public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) { 20 this.numActions = numActions; 21 this.success = new int[numActions]; 22 this.failure = new int[numActions]; 18 public BernoulliModel(double alpha = 1.0, double beta = 1.0) { 23 19 this.alpha = alpha; 24 20 this.beta = beta; 25 21 } 26 22 27 28 public double[] SampleExpectedRewards(Random random) { 23 public double SampleExpectedReward(Random random) { 29 24 // sample bernoulli mean from beta prior 30 var theta = new double[numActions]; 31 for (int a = 0; a < numActions; a++) { 32 if (success[a] == -1) 33 theta[a] = 0.0; 34 else { 35 theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta); 36 } 37 } 38 39 // no need to sample we know the exact expected value 40 // the expected value of a bernoulli variable is just theta 41 return theta.Select(t => t).ToArray(); 25 return Rand.BetaRand(random, success + alpha, failure + beta); 42 26 } 43 27 44 public void Update(int action, double reward) { 45 const double EPSILON = 1E-6; 46 Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON); 47 if (Math.Abs(reward - 1.0) < EPSILON) { 48 success[action]++; 28 public void Update(double reward) { 29 Debug.Assert(reward.IsAlmost(1.0) || reward.IsAlmost(0.0)); 30 if (reward.IsAlmost(1.0)) { 31 success++; 49 32 } else { 50 failure [action]++;33 failure++; 51 34 } 52 35 } 53 36 54 public void Disable(int action) {55 success[action] = -1;56 }57 58 37 public void Reset() { 59 Array.Clear(success, 0, numActions);60 Array.Clear(failure, 0, numActions);38 success = 0; 39 failure = 0; 61 40 } 62 41 63 42 public void PrintStats() { 64 for (int i = 0; i < numActions; i++) { 65 Console.Write("{0:F2} ", success[i] / (double)failure[i]); 66 } 43 Console.Write("{0:F2} ", success / (double)failure); 44 } 45 46 public object Clone() { 47 return new BernoulliModel() { failure = this.failure, success = this.success }; 67 48 } 68 49 }
Note: See TracChangeset
for help on using the changeset viewer.