Changeset 11732 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs
- Timestamp:
- 01/07/15 09:21:46 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public class UCB1Policy : BanditPolicy { 10 private readonly int[] tries; 11 private readonly double[] sumReward; 12 private int totalTries = 0; 13 public UCB1Policy(int numActions) 14 : base(numActions) { 15 this.tries = new int[numActions]; 16 this.sumReward = new double[numActions]; 17 } 18 19 public override int SelectAction() { 9 public class UCB1Policy : IPolicy { 10 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 11 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance 20 12 int bestAction = -1; 21 13 double bestQ = double.NegativeInfinity; 22 foreach (var a in Actions) { 23 if (tries[a] == 0) return a; 24 var q = sumReward[a] / tries[a] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[a]); 14 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 15 16 for (int a = 0; a < myActionInfos.Length; a++) { 17 if (myActionInfos[a].Disabled) continue; 18 if (myActionInfos[a].Tries == 0) return a; 19 var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries); 25 20 if (q > bestQ) { 26 21 bestQ = q; … … 28 23 } 29 24 } 25 Debug.Assert(bestAction > -1); 30 26 return bestAction; 31 27 } 32 public override void UpdateReward(int action, double reward) {33 Debug.Assert(Actions.Contains(action));34 totalTries++;35 tries[action]++;36 sumReward[action] += reward;37 }38 28 39 public override void DisableAction(int action) { 40 base.DisableAction(action); 41 totalTries -= tries[action]; 42 tries[action] = -1; 43 sumReward[action] = 0; 44 } 45 46 public override void Reset() { 47 base.Reset(); 48 totalTries = 0; 49 Array.Clear(tries, 0, tries.Length); 50 Array.Clear(sumReward, 0, sumReward.Length); 51 } 52 public override void PrintStats() { 53 for (int i = 0; i < sumReward.Length; i++) { 54 if (tries[i] >= 0) { 55 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 56 } else { 57 Console.Write("{0,5}", ""); 58 } 59 } 60 Console.WriteLine(); 29 public IPolicyActionInfo CreateActionInfo() { 30 return new DefaultPolicyActionInfo(); 61 31 } 62 32 public override string ToString() {
Note: See TracChangeset
for help on using the changeset viewer.