Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/07/15 09:21:46 (9 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs

    r11730 r11732  
    7 7
    8 8 namespace HeuristicLab.Algorithms.Bandits {
    9   public class UCB1Policy : BanditPolicy {
    10     private readonly int[] tries;
    11     private readonly double[] sumReward;
    12     private int totalTries = 0;
    13     public UCB1Policy(int numActions)
    14       : base(numActions) {
    15       this.tries = new int[numActions];
    16       this.sumReward = new double[numActions];
    17     }
    18 
    19     public override int SelectAction() {
     9  public class UCB1Policy : IPolicy {
     10    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     11      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
    20 12      int bestAction = -1;
    21 13      double bestQ = double.NegativeInfinity;
    22       foreach (var a in Actions) {
    23         if (tries[a] == 0) return a;
    24         var q = sumReward[a] / tries[a] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[a]);
     14      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     15
     16      for (int a = 0; a < myActionInfos.Length; a++) {
     17        if (myActionInfos[a].Disabled) continue;
     18        if (myActionInfos[a].Tries == 0) return a;
     19        var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries);
    25 20        if (q > bestQ) {
    26 21          bestQ = q;
    27 22          bestAction = a;
    28 23        }
    29 24      }
     25      Debug.Assert(bestAction > -1);
    30 26      return bestAction;
    31 27    }
    32     public override void UpdateReward(int action, double reward) {
    33       Debug.Assert(Actions.Contains(action));
    34       totalTries++;
    35       tries[action]++;
    36       sumReward[action] += reward;
    37     }
    38 28
    39     public override void DisableAction(int action) {
    40       base.DisableAction(action);
    41       totalTries -= tries[action];
    42       tries[action] = -1;
    43       sumReward[action] = 0;
    44     }
    45 
    46     public override void Reset() {
    47       base.Reset();
    48       totalTries = 0;
    49       Array.Clear(tries, 0, tries.Length);
    50       Array.Clear(sumReward, 0, sumReward.Length);
    51     }
    52     public override void PrintStats() {
    53       for (int i = 0; i < sumReward.Length; i++) {
    54         if (tries[i] >= 0) {
    55           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    56         } else {
    57           Console.Write("{0,5}", "");
    58         }
    59       }
    60       Console.WriteLine();
     29    public IPolicyActionInfo CreateActionInfo() {
     30      return new DefaultPolicyActionInfo();
    61 31    }
    62 32    public override string ToString() {
Note: See TracChangeset for help on using the changeset viewer.