Changeset 11732 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs
- Timestamp: 01/07/15 09:21:46 (9 years ago)
- Files: 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// UCB1-Tuned bandit policy (Auer, Cesa-Bianchi &amp; Fischer 2002): like UCB1, but the
  /// exploration term is scaled by an upper confidence bound on each arm's reward
  /// variance, capped at 1/4 (the maximum variance of a Bernoulli variable).
  /// Stateless: all per-arm statistics live in MeanAndVariancePolicyActionInfo objects.
  /// </summary>
  public class UCB1TunedPolicy : IPolicy {

    /// <summary>
    /// Picks the enabled action with the highest UCB1-Tuned score.
    /// Untried enabled actions are returned immediately (each arm is played once first).
    /// </summary>
    /// <param name="random">Unused here; part of the IPolicy contract.</param>
    /// <param name="actionInfos">Per-action statistics; entries must be MeanAndVariancePolicyActionInfo.</param>
    /// <returns>Index of the selected action within the materialized info sequence.</returns>
    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
      var infos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance
      // n in the UCB formula: total pulls over all currently enabled arms.
      int totalTries = infos.Where(info => !info.Disabled).Sum(info => info.Tries);

      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      for (int action = 0; action < infos.Length; action++) {
        var info = infos[action];
        if (info.Disabled) continue;
        // Play every enabled arm once before applying the confidence-bound formula.
        if (info.Tries == 0) return action;

        double avgReward = info.SumReward / info.Tries;
        // 1/4 is an upper bound on the variance of a Bernoulli-distributed variable.
        double explorationBonus =
          Math.Sqrt((Math.Log(totalTries) / info.Tries) * Math.Min(1.0 / 4, V(info, totalTries)));
        double q = avgReward + explorationBonus;
        if (q > bestQ) {
          bestQ = q;
          bestAction = action;
        }
      }
      // At least one enabled action must exist, otherwise the loop found nothing.
      Debug.Assert(bestAction > -1);
      return bestAction;
    }

    /// <summary>Creates the statistics object this policy needs for one action.</summary>
    public IPolicyActionInfo CreateActionInfo() {
      return new MeanAndVariancePolicyActionInfo();
    }

    // V_j(s) from the UCB1-Tuned paper: sample variance of arm j plus
    // sqrt(2 ln n / s), an upper confidence bound on the true variance.
    private double V(MeanAndVariancePolicyActionInfo actionInfo, int totalTries) {
      var s = actionInfo.Tries;
      return actionInfo.RewardVariance + Math.Sqrt(2 * Math.Log(totalTries) / s);
    }

    public override string ToString() {
      return "UCB1TunedPolicy";
    }
  }
}
Note: See TracChangeset for help on using the changeset viewer.