Changeset 11730 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies
- Timestamp:
- 01/02/15 16:08:21 (9 years ago)
- Location:
- branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies
- Files:
-
- 5 added
- 9 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs
r11727 r11730 28 28 Actions = Enumerable.Range(0, numInitialActions).ToArray(); 29 29 } 30 31 public abstract void PrintStats(); 30 32 } 31 33 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs
r11727 r11730 55 55 Array.Clear(failure, 0, failure.Length); 56 56 } 57 58 public override void PrintStats() { 59 for (int i = 0; i < success.Length; i++) { 60 if (success[i] >= 0) { 61 Console.Write("{0,5:F2}", success[i] / failure[i]); 62 } else { 63 Console.Write("{0,5}", ""); 64 } 65 } 66 Console.WriteLine(); 67 } 68 69 public override string ToString() { 70 return "BernoulliThompsonSamplingPolicy"; 71 } 57 72 } 58 73 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs
r11727 r11730 27 27 if (random.NextDouble() > eps) { 28 28 // select best 29 var maxReward= double.NegativeInfinity;29 var bestQ = double.NegativeInfinity; 30 30 int bestAction = -1; 31 31 foreach (var a in Actions) { 32 32 if (tries[a] == 0) return a; 33 var avgReward= sumReward[a] / tries[a];34 if ( maxReward < avgReward) {35 maxReward = avgReward;33 var q = sumReward[a] / tries[a]; 34 if (bestQ < q) { 35 bestQ = q; 36 36 bestAction = a; 37 37 } … … 65 65 Array.Clear(sumReward, 0, sumReward.Length); 66 66 } 67 public override void PrintStats() { 68 for (int i = 0; i < sumReward.Length; i++) { 69 if (tries[i] >= 0) { 70 Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]); 71 } else { 72 Console.Write("-", ""); 73 } 74 } 75 Console.WriteLine(); 76 } 77 public override string ToString() { 78 return string.Format("EpsGreedyPolicy({0:F2})", eps); 79 } 67 80 } 68 81 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/Exp3Policy.cs
r11727 r11730 52 52 foreach (var a in Actions) w[a] = 1.0; 53 53 } 54 public override void PrintStats() { 55 for (int i = 0; i < w.Length; i++) { 56 if (w[i] > 0) { 57 Console.Write("{0,5:F2}", w[i]); 58 } else { 59 Console.Write("{0,5}", ""); 60 } 61 } 62 Console.WriteLine(); 63 } 64 public override string ToString() { 65 return "Exp3Policy"; 66 } 54 67 } 55 68 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs
r11727 r11730 5 5 6 6 namespace HeuristicLab.Algorithms.Bandits { 7 7 8 public class GaussianThompsonSamplingPolicy : BanditPolicy { 8 9 private readonly Random random; 9 private readonly double[] s umRewards;10 private readonly double[] s umSqrRewards;10 private readonly double[] sampleMean; 11 private readonly double[] sampleM2; 11 12 private readonly int[] tries; 12 public GaussianThompsonSamplingPolicy(Random random, int numActions) 13 private bool compatibility; 14 15 // assumes a Gaussian reward distribution with different means but the same variances for each action 16 // the prior for the mean is also Gaussian with the following parameters 17 private readonly double rewardVariance = 0.1; // we assume a known variance 18 19 private readonly double priorMean = 0.5; 20 private readonly double priorVariance = 1; 21 22 23 public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false) 13 24 : base(numActions) { 14 25 this.random = random; 15 this.s umRewards= new double[numActions];16 this.s umSqrRewards= new double[numActions];26 this.sampleMean = new double[numActions]; 27 this.sampleM2 = new double[numActions]; 17 28 this.tries = new int[numActions]; 29 this.compatibility = compatibility; 18 30 } 19 31 … … 24 36 int bestAction = -1; 25 37 foreach (var a in Actions) { 26 if (tries[a] == 0) return a; 27 var mu = sumRewards[a] / tries[a]; 28 var stdDev = Math.Sqrt(sumSqrRewards[a] / tries[a] - Math.Pow(mu, 2)); 29 var theta = Rand.RandNormal(random) * stdDev + mu; 38 if(tries[a] == -1) continue; // skip disabled actions 39 double theta; 40 if (compatibility) { 41 if (tries[a] < 2) return a; 42 var mu = sampleMean[a]; 43 var variance = sampleM2[a] / tries[a]; 44 var stdDev = Math.Sqrt(variance); 45 theta = Rand.RandNormal(random) * stdDev + mu; 46 } else { 47 // calculate posterior mean and variance (for mean reward) 48 49 // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution
(http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf) 50 var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance); 51 var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance); 52 53 // sample a mean from the posterior 54 theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean; 55 56 // theta already represents the expected reward value => nothing else to do 57 } 30 58 if (theta > maxTheta) { 31 59 maxTheta = theta; … … 33 61 } 34 62 } 63 Debug.Assert(Actions.Contains(bestAction)); 35 64 return bestAction; 36 65 } … … 38 67 public override void UpdateReward(int action, double reward) { 39 68 Debug.Assert(Actions.Contains(action)); 40 41 sumRewards[action] += reward;42 sumSqrRewards[action] += reward * reward;43 69 tries[action]++; 70 var delta = reward - sampleMean[action]; 71 sampleMean[action] += delta / tries[action]; 72 sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]); 44 73 } 45 74 46 75 public override void DisableAction(int action) { 47 76 base.DisableAction(action); 48 s umRewards[action] = 0;49 s umSqrRewards[action] = 0;77 sampleMean[action] = 0; 78 sampleM2[action] = 0; 50 79 tries[action] = -1; 51 80 } … … 53 82 public override void Reset() { 54 83 base.Reset(); 55 Array.Clear(s umRewards, 0, sumRewards.Length);56 Array.Clear(s umSqrRewards, 0, sumSqrRewards.Length);84 Array.Clear(sampleMean, 0, sampleMean.Length); 85 Array.Clear(sampleM2, 0, sampleM2.Length); 57 86 Array.Clear(tries, 0, tries.Length); 87 } 88 89 public override void PrintStats() { 90 for (int i = 0; i < sampleMean.Length; i++) { 91 if (tries[i] >= 0) { 92 Console.Write(" {0,5:F2} {1}", sampleMean[i] / tries[i], tries[i]); 93 } else { 94 Console.Write("{0,5}", ""); 95 } 96 } 97 Console.WriteLine(); 98 } 99 public override string ToString() { 100 return "GaussianThompsonSamplingPolicy"; 101 } 58 102 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs
r11727 r11730 23 23 // do nothing 24 24 } 25 25 public override void PrintStats() { 26 Console.WriteLine("Random"); 27 } 28 public override string ToString() { 29 return "RandomPolicy"; 30 } 26 31 } 27 32 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs
r11727 r11730 50 50 Array.Clear(sumReward, 0, sumReward.Length); 51 51 } 52 public override void PrintStats() { 53 for (int i = 0; i < sumReward.Length; i++) { 54 if (tries[i] >= 0) { 55 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 56 } else { 57 Console.Write("{0,5}", ""); 58 } 59 } 60 Console.WriteLine(); 61 } 62 public override string ToString() { 63 return "UCB1Policy"; 64 } 52 65 } 53 66 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs
r11727 r11730 62 62 Array.Clear(sumSqrReward, 0, sumSqrReward.Length); 63 63 } 64 public override void PrintStats() { 65 for (int i = 0; i < sumReward.Length; i++) { 66 if (tries[i] >= 0) { 67 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 68 } else { 69 Console.Write("{0,5}", ""); 70 } 71 } 72 Console.WriteLine(); 73 } 74 public override string ToString() { 75 return "UCB1TunedPolicy"; 76 } 64 77 } 65 78 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs
r11727 r11730 24 24 double bestQ = double.NegativeInfinity; 25 25 foreach (var a in Actions) { 26 if (totalTries == 0 || tries[a] == 0 || tries[a] <Math.Ceiling(8 * Math.Log(totalTries))) return a;26 if (totalTries <= 1 || tries[a] <= 1 || tries[a] <= Math.Ceiling(8 * Math.Log(totalTries))) return a; 27 27 var avgReward = sumReward[a] / tries[a]; 28 var estVariance = 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]); 29 if (estVariance < 0) estVariance = 0; // numerical problems 28 30 var q = avgReward 29 + Math.Sqrt( 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]));31 + Math.Sqrt(estVariance); 30 32 if (q > bestQ) { 31 33 bestQ = q; … … 33 35 } 34 36 } 37 Debug.Assert(Actions.Contains(bestAction)); 35 38 return bestAction; 36 39 } … … 58 61 Array.Clear(sumSqrReward, 0, sumSqrReward.Length); 59 62 } 63 public override void PrintStats() { 64 for (int i = 0; i < sumReward.Length; i++) { 65 if (tries[i] >= 0) { 66 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 67 } else { 68 Console.Write("{0,5}", ""); 69 } 70 } 71 Console.WriteLine(); 72 } 73 public override string ToString() { 74 return "UCBNormalPolicy"; 75 } 60 76 } 61 77 }
Note: See TracChangeset
for help on using the changeset viewer.