Changeset 11742 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs
- Timestamp:
- 01/09/15 14:57:28 (10 years ago)
- Location:
- branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies
- Files:
-
- 1 edited
- 1 moved
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs
r11730 r11742 5 5 using System.Text; 6 6 using System.Threading.Tasks; 7 using HeuristicLab.Common; 7 8 8 namespace HeuristicLab.Algorithms.Bandits {9 namespace HeuristicLab.Algorithms.Bandits.BanditPolicies { 9 10 /* see: Streeter and Smith: A simple distribution-free approach to the max k-armed bandit problem, Proceedings of the 12th 10 11 International Conference, CP 2006, Nantes, France, September 25-29, 2006. pp 560-574 */ 11 12 12 public class ThresholdAscentPolicy : BanditPolicy {13 const int numBins = 101;14 const double binSize = 1.0 / (numBins - 1);13 public class ThresholdAscentPolicy : IBanditPolicy { 14 public const int numBins = 101; 15 public const double binSize = 1.0 / (numBins - 1); 15 16 16 // for each arm store the number of observed rewards for each bin of size delta 17 // for delta = 0.01 we have 101 bins 18 // the first bin is freq of rewards >= 0 // all 19 // the second bin is freq of rewards > 0 20 // the third bin is freq of rewards > 0.01 21 // the last bin is for rewards > 0.99 22 // 23 // (also see RewardBin function) 24 private readonly int[,] armRewardHistogram; // for performance reasons we store cumulative counts (freq of rewards > lower threshold) 17 private class ThresholdAscentActionInfo : IBanditPolicyActionInfo { 25 18 19 // for each arm store the number of observed rewards for each bin of size delta 20 // for delta = 0.01 we have 101 bins 21 // the first bin is freq of rewards >= 0 // all 22 // the second bin is freq of rewards > 0 23 // the third bin is freq of rewards > 0.01 24 // the last bin is for rewards > 0.99 25 // 26 // (also see RewardBin function) 27 public int[] rewardHistogram = new int[numBins]; // for performance reasons we store cumulative counts (freq of rewards > lower threshold) 28 public int Tries { get; private set; } 29 public int thresholdBin = 1; 30 public double Value { get { return rewardHistogram[thresholdBin] / (double)Tries; } } 26 31 27 private readonly int[] tries; 32 public bool Disabled { get { return Tries == -1; } } 33 34 public void UpdateReward(double reward) { 35 Tries++; 36 for (var idx = thresholdBin; idx <= RewardBin(reward); idx++) 37 rewardHistogram[idx]++; 38 } 39 40 public void Disable() { 41 Tries = -1; 42 } 43 44 public void Reset() { 45 Tries = 0; 46 thresholdBin = 1; 47 Array.Clear(rewardHistogram, 0, rewardHistogram.Length); 48 } 49 50 public void PrintStats() { 51 if (Tries >= 0) { 52 Console.Write("{0,6}", Tries); 53 } else { 54 Console.Write("{0,6}", ""); 55 } 56 } 57 58 // maps a reward value to it's bin 59 private static int RewardBin(double reward) { 60 Debug.Assert(reward >= 0 && reward <= 1.0); 61 // reward = 0 => 0 62 // ]0.00 .. 0.01] => 1 63 // ]0.01 .. 0.02] => 2 64 // ... 65 // ]0.99 .. 1.00] => 100 66 if (reward <= 0) return 0; 67 return (int)Math.Ceiling((reward / binSize)); 68 } 69 } 70 28 71 private readonly int s; 29 72 private readonly double delta; 30 73 31 private int totalTries = 0; 32 private int thresholdBin; // bin index of current threshold 33 private const double maxTries = 1E6; 34 35 public ThresholdAscentPolicy(int numActions, int s = 100, double delta = 0.05) 36 : base(numActions) { 37 this.thresholdBin = 1; // first bin to check is bin idx 1 == freq of rewards > 0 74 public ThresholdAscentPolicy(int s = 100, double delta = 0.05) { 38 75 this.s = s; 39 76 this.delta = delta; 40 this.armRewardHistogram = new int[numActions, numBins];41 this.tries = new int[numActions];42 77 } 43 78 44 // maps a reward value to it's bin 45 private static int RewardBin(double reward) { 46 Debug.Assert(reward >= 0 && reward <= 1.0); 47 // reward = 0 => 0 48 // ]0.00 .. 0.01] => 1 49 // ]0.01 .. 0.02] => 2 50 // ... 51 // ]0.99 .. 1.00] => 100 52 if (reward <= 0) return 0; 53 return (int)Math.Ceiling((reward / binSize)); 54 } 55 56 57 private double U(double mu, int n, int k) { 79 private double U(double mu, int totalTries, int n, int k) { 58 80 //var alpha = Math.Log(2.0 * totalTries * k / delta); 59 double alpha = Math.Log(2) + Math.Log( maxTries) + Math.Log(k) - Math.Log(delta); // totalTries is max iterations in original paper81 double alpha = Math.Log(2) + Math.Log(totalTries) + Math.Log(k) - Math.Log(delta); 60 82 return mu + (alpha + Math.Sqrt(2 * n * mu * alpha + alpha * alpha)) / n; 61 83 } 62 84 63 85 64 public override int SelectAction() { 65 Debug.Assert(Actions.Any()); 66 UpdateThreshold(); 86 public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) { 87 Debug.Assert(actionInfos.Any()); 88 var myActionInfos = actionInfos.OfType<ThresholdAscentActionInfo>(); 89 UpdateThreshold(myActionInfos); 90 67 91 int bestAction = -1; 68 92 double bestQ = double.NegativeInfinity; 69 int k = Actions.Count(); 70 foreach (var a in Actions) { 71 if (tries[a] == 0) return a; 72 double mu = armRewardHistogram[a, thresholdBin] / (double)tries[a]; // probability of rewards > T 73 double q = U(mu, tries[a], k); 93 int k = myActionInfos.Count(a => !a.Disabled); 94 var totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 95 int aIdx = -1; 96 foreach (var aInfo in myActionInfos) { 97 aIdx++; 98 if (aInfo.Disabled) continue; 99 if (aInfo.Tries == 0) return aIdx; 100 double mu = aInfo.Value; // probability of rewards > T 101 double q = U(mu, totalTries, aInfo.Tries, k); // totalTries is max iterations in original paper 74 102 if (q > bestQ) { 75 103 bestQ = q; 76 bestAction = a ;104 bestAction = aIdx; 77 105 } 78 106 } 79 Debug.Assert( Actions.Contains(bestAction));107 Debug.Assert(bestAction > -1); 80 108 return bestAction; 81 109 } 82 110 83 private void UpdateThreshold() { 84 while (thresholdBin < (numBins - 1) && Actions.Sum(a => armRewardHistogram[a, thresholdBin]) >= s) { 111 112 private void UpdateThreshold(IEnumerable<ThresholdAscentActionInfo> actionInfos) { 113 var thresholdBin = 1; // first bin to check is bin idx 1 == freq of rewards > 0 114 while (thresholdBin < (numBins - 1) && actionInfos.Sum(a => a.rewardHistogram[thresholdBin]) >= s) { 85 115 thresholdBin++; 86 116 // Console.WriteLine("New threshold {0:F2}", T); 87 117 } 118 foreach (var aInfo in actionInfos) { 119 aInfo.thresholdBin = thresholdBin; 120 } 88 121 } 89 122 90 public override void UpdateReward(int action, double reward) { 91 Debug.Assert(Actions.Contains(action)); 92 totalTries++; 93 tries[action]++; 94 // efficiency: we can start at the current threshold bin because all bins below that are not accessed in select-action 95 for (var idx = thresholdBin; idx <= RewardBin(reward); idx++) 96 armRewardHistogram[action, idx]++; 123 124 public IBanditPolicyActionInfo CreateActionInfo() { 125 return new ThresholdAscentActionInfo(); 97 126 } 98 127 99 public override void DisableAction(int action) {100 base.DisableAction(action);101 totalTries -= tries[action];102 tries[action] = -1;103 }104 105 public override void Reset() {106 base.Reset();107 totalTries = 0;108 thresholdBin = 1;109 Array.Clear(tries, 0, tries.Length);110 Array.Clear(armRewardHistogram, 0, armRewardHistogram.Length);111 }112 113 public override void PrintStats() {114 for (int i = 0; i < tries.Length; i++) {115 if (tries[i] >= 0) {116 Console.Write("{0,6}", tries[i]);117 } else {118 Console.Write("{0,6}", "");119 }120 }121 Console.WriteLine();122 }123 128 public override string ToString() { 124 129 return string.Format("ThresholdAscentPolicy({0},{1:F2})", s, delta);
Note: See TracChangeset
for help on using the changeset viewer.