Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/07/15 09:21:46 (9 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class UCB1TunedPolicy : BanditPolicy {
    10     private readonly int[] tries;
    11     private readonly double[] sumReward;
    12     private readonly double[] sumSqrReward;
    13     private int totalTries = 0;
    14     public UCB1TunedPolicy(int numActions)
    15       : base(numActions) {
    16       this.tries = new int[numActions];
    17       this.sumReward = new double[numActions];
    18       this.sumSqrReward = new double[numActions];
    19     }
     9  public class UCB1TunedPolicy : IPolicy {
    2010
    21     private double V(int arm) {
    22       var s = tries[arm];
    23       return sumSqrReward[arm] / s - Math.Pow(sumReward[arm] / s, 2) + Math.Sqrt(2 * Math.Log(totalTries) / s);
    24     }
    25 
    26 
    27     public override int SelectAction() {
    28       Debug.Assert(Actions.Any());
     11    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     12      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance
    2913      int bestAction = -1;
    3014      double bestQ = double.NegativeInfinity;
    31       foreach (var a in Actions) {
    32         if (tries[a] == 0) return a;
    33         var q = sumReward[a] / tries[a] + Math.Sqrt((Math.Log(totalTries) / tries[a]) * Math.Min(1.0 / 4, V(a))); // 1/4 is upper bound of bernoulli distributed variable
     15      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     16
     17      for (int a = 0; a < myActionInfos.Length; a++) {
     18        if (myActionInfos[a].Disabled) continue;
     19        if (myActionInfos[a].Tries == 0) return a;
     20
     21        var sumReward = myActionInfos[a].SumReward;
     22        var tries = myActionInfos[a].Tries;
     23
     24        var avgReward = sumReward / tries;
     25        var q = avgReward + Math.Sqrt((Math.Log(totalTries) / tries) * Math.Min(1.0 / 4, V(myActionInfos[a], totalTries))); // 1/4 is upper bound of bernoulli distributed variable
    3426        if (q > bestQ) {
    3527          bestQ = q;
     
    3729        }
    3830      }
     31      Debug.Assert(bestAction > -1);
    3932      return bestAction;
    4033    }
    41     public override void UpdateReward(int action, double reward) {
    42       Debug.Assert(Actions.Contains(action));
    43       totalTries++;
    44       tries[action]++;
    45       sumReward[action] += reward;
    46       sumSqrReward[action] += reward * reward;
     34
     35    public IPolicyActionInfo CreateActionInfo() {
     36      return new MeanAndVariancePolicyActionInfo();
    4737    }
    4838
    49     public override void DisableAction(int action) {
    50       base.DisableAction(action);
    51       totalTries -= tries[action];
    52       tries[action] = -1;
    53       sumReward[action] = 0;
    54       sumSqrReward[action] = 0;
     39    private double V(MeanAndVariancePolicyActionInfo actionInfo, int totalTries) {
     40      var s = actionInfo.Tries;
     41      return actionInfo.RewardVariance + Math.Sqrt(2 * Math.Log(totalTries) / s);
    5542    }
    5643
    57     public override void Reset() {
    58       base.Reset();
    59       totalTries = 0;
    60       Array.Clear(tries, 0, tries.Length);
    61       Array.Clear(sumReward, 0, sumReward.Length);
    62       Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    63     }
    64     public override void PrintStats() {
    65       for (int i = 0; i < sumReward.Length; i++) {
    66         if (tries[i] >= 0) {
    67           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    68         } else {
    69           Console.Write("{0,5}", "");
    70         }
    71       }
    72       Console.WriteLine();
    73     }
    7444    public override string ToString() {
    7545      return "UCB1TunedPolicy";
Note: See TracChangeset for help on using the changeset viewer.