Ignore:
Timestamp:
01/07/15 09:21:46 (5 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs

    r11730 r11732  
    88
    99namespace HeuristicLab.Algorithms.Bandits {
    10   public class BernoulliThompsonSamplingPolicy : BanditPolicy {
    11     private readonly Random random;
    12     private readonly int[] success;
    13     private readonly int[] failure;
    14 
     10  public class BernoulliThompsonSamplingPolicy : IPolicy {
    1511    // parameters of beta prior distribution
    1612    private readonly double alpha = 1.0;
    1713    private readonly double beta = 1.0;
    1814
    19     public BernoulliThompsonSamplingPolicy(Random random, int numActions)
    20       : base(numActions) {
    21       this.random = random;
    22       this.success = new int[numActions];
    23       this.failure = new int[numActions];
    24     }
     15    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     16      var myActionInfos = actionInfos.OfType<BernoulliPolicyActionInfo>(); // TODO: performance
     17      int bestAction = -1;
     18      double maxTheta = double.NegativeInfinity;
     19      var aIdx = -1;
    2520
    26     public override int SelectAction() {
    27       Debug.Assert(Actions.Any());
    28       var maxTheta = double.NegativeInfinity;
    29       int bestAction = -1;
    30       foreach (var a in Actions) {
    31         var theta = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
     21      foreach (var aInfo in myActionInfos) {
     22        aIdx++;
     23        if (aInfo.Disabled) continue;
     24        var theta = Rand.BetaRand(random, aInfo.NumSuccess + alpha, aInfo.NumFailure + beta);
    3225        if (theta > maxTheta) {
    3326          maxTheta = theta;
    34           bestAction = a;
     27          bestAction = aIdx;
    3528        }
    3629      }
     30      Debug.Assert(bestAction > -1);
    3731      return bestAction;
    3832    }
    3933
    40     public override void UpdateReward(int action, double reward) {
    41       Debug.Assert(Actions.Contains(action));
    42 
    43       if (reward > 0) success[action]++;
    44       else failure[action]++;
     34    public IPolicyActionInfo CreateActionInfo() {
     35      return new BernoulliPolicyActionInfo();
    4536    }
    4637
    47     public override void DisableAction(int action) {
    48       base.DisableAction(action);
    49       success[action] = -1;
    50     }
    51 
    52     public override void Reset() {
    53       base.Reset();
    54       Array.Clear(success, 0, success.Length);
    55       Array.Clear(failure, 0, failure.Length);
    56     }
    57 
    58     public override void PrintStats() {
    59       for (int i = 0; i < success.Length; i++) {
    60         if (success[i] >= 0) {
    61           Console.Write("{0,5:F2}", success[i] / failure[i]);
    62         } else {
    63           Console.Write("{0,5}", "");
    64         }
    65       }
    66       Console.WriteLine();
    67     }
    6838
    6939    public override string ToString() {
Note: See TracChangeset for help on using the changeset viewer.