Changeset 11732 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
- Timestamp:
- 01/07/15 09:21:46 (10 years ago)
- Location:
- branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
- Files:
-
- 8 added
- 17 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj
r11730 r11732 31 31 </PropertyGroup> 32 32 <ItemGroup> 33 <Reference Include="ALGLIB-3.7.0"> 34 <HintPath>..\..\..\trunk\sources\bin\ALGLIB-3.7.0.dll</HintPath> 35 </Reference> 33 36 <Reference Include="System" /> 34 37 <Reference Include="System.Core" /> … … 42 45 <Compile Include="BanditHelper.cs" /> 43 46 <Compile Include="Bandits\BernoulliBandit.cs" /> 47 <Compile Include="Bandits\GaussianBandit.cs" /> 44 48 <Compile Include="Bandits\GaussianMixtureBandit.cs" /> 45 49 <Compile Include="Bandits\IBandit.cs" /> 46 50 <Compile Include="Bandits\TruncatedNormalBandit.cs" /> 51 <Compile Include="OnlineMeanAndVarianceEstimator.cs" /> 52 <Compile Include="IPolicyActionInfo.cs" /> 47 53 <Compile Include="Models\BernoulliModel.cs" /> 48 54 <Compile Include="Models\GaussianModel.cs" /> 49 <Compile Include="Models\GaussianMixtureModel.cs" />50 55 <Compile Include="Models\IModel.cs" /> 51 <Compile Include="Policies\BanditPolicy.cs" /> 52 <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs" /> 53 <Compile Include="Policies\BoltzmannExplorationPolicy.cs" /> 54 <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs" /> 55 <Compile Include="Policies\GenericThompsonSamplingPolicy.cs" /> 56 <Compile Include="Policies\ThresholdAscentPolicy.cs" /> 57 <Compile Include="Policies\UCTPolicy.cs" /> 58 <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs" /> 59 <Compile Include="Policies\Exp3Policy.cs" /> 60 <Compile Include="Policies\EpsGreedyPolicy.cs" /> 56 <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs"> 57 <SubType>Code</SubType> 58 </Compile> 59 <Compile Include="Policies\BoltzmannExplorationPolicy.cs"> 60 <SubType>Code</SubType> 61 </Compile> 62 <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs"> 63 <SubType>Code</SubType> 64 </Compile> 65 <Compile Include="Policies\BernoulliPolicyActionInfo.cs" /> 66 <Compile Include="Policies\ModelPolicyActionInfo.cs" /> 67 <Compile Include="Policies\EpsGreedyPolicy.cs"> 68 <SubType>Code</SubType> 69 </Compile> 70 <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs"> 71 <SubType>Code</SubType> 72 </Compile> 73 <Compile Include="Policies\GenericThompsonSamplingPolicy.cs"> 74 <SubType>Code</SubType> 75 </Compile> 76 <Compile Include="Policies\MeanAndVariancePolicyActionInfo.cs" /> 77 <Compile Include="Policies\DefaultPolicyActionInfo.cs" /> 78 <Compile Include="Policies\EmptyPolicyActionInfo.cs" /> 61 79 <Compile Include="Policies\RandomPolicy.cs" /> 62 80 <Compile Include="Policies\UCB1Policy.cs" /> 63 <Compile Include="Policies\UCB1TunedPolicy.cs" />64 <Compile Include="Policies\UCBNormalPolicy.cs" />65 81 <Compile Include="IPolicy.cs" /> 82 <Compile Include="Policies\UCB1TunedPolicy.cs"> 83 <SubType>Code</SubType> 84 </Compile> 85 <Compile Include="Policies\UCBNormalPolicy.cs"> 86 <SubType>Code</SubType> 87 </Compile> 88 <Compile Include="Policies\UCTPolicy.cs"> 89 <SubType>Code</SubType> 90 </Compile> 66 91 <Compile Include="Properties\AssemblyInfo.cs" /> 67 92 </ItemGroup> -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IPolicy.cs
r11730 r11732 8 8 // this interface represents a policy for reinforcement learning 9 9 public interface IPolicy { 10 IEnumerable<int> Actions { get; } 11 int SelectAction(); // action selection ... 12 void UpdateReward(int action, double reward); // ... and reward update are defined as usual 13 14 // policies must also support disabling of potential actions 15 // for instance if we know that an action in a state has a deterministic 16 // reward we need to sample it only once 17 // it is necessary to sample an action only once 18 void DisableAction(int action); 19 20 // reset causes the policy to be reinitialized to it's initial state (as after constructor-call) 21 void Reset(); 22 23 void PrintStats(); 10 int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos); 11 IPolicyActionInfo CreateActionInfo(); 24 12 } 25 13 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/BernoulliModel.cs
r11730 r11732 9 9 namespace HeuristicLab.Algorithms.Bandits.Models { 10 10 public class BernoulliModel : IModel { 11 private readonly int numActions; 12 private readonly int[] success; 13 private readonly int[] failure; 11 private int success; 12 private int failure; 14 13 15 14 // parameters of beta prior distribution … … 17 16 private readonly double beta; 18 17 19 public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) { 20 this.numActions = numActions; 21 this.success = new int[numActions]; 22 this.failure = new int[numActions]; 18 public BernoulliModel(double alpha = 1.0, double beta = 1.0) { 23 19 this.alpha = alpha; 24 20 this.beta = beta; 25 21 } 26 22 27 28 public double[] SampleExpectedRewards(Random random) { 23 public double SampleExpectedReward(Random random) { 29 24 // sample bernoulli mean from beta prior 30 var theta = new double[numActions]; 31 for (int a = 0; a < numActions; a++) { 32 if (success[a] == -1) 33 theta[a] = 0.0; 34 else { 35 theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta); 36 } 37 } 38 39 // no need to sample we know the exact expected value 40 // the expected value of a bernoulli variable is just theta 41 return theta.Select(t => t).ToArray(); 25 return Rand.BetaRand(random, success + alpha, failure + beta); 42 26 } 43 27 44 public void Update(int action, double reward) { 45 const double EPSILON = 1E-6; 46 Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON); 47 if (Math.Abs(reward - 1.0) < EPSILON) { 48 success[action]++; 28 public void Update(double reward) { 29 Debug.Assert(reward.IsAlmost(1.0) || reward.IsAlmost(0.0)); 30 if (reward.IsAlmost(1.0)) { 31 success++; 49 32 } else { 50 failure [action]++;33 failure++; 51 34 } 52 35 } 53 36 54 public void Disable(int action) {55 success[action] = -1;56 }57 58 37 public void Reset() { 59 Array.Clear(success, 0, numActions);60 Array.Clear(failure, 0, numActions);38 success = 0; 39 failure = 0; 61 40 } 62 41 63 42 public void PrintStats() { 64 for (int i = 0; i < numActions; i++) { 65 Console.Write("{0:F2} ", success[i] / (double)failure[i]); 66 } 43 Console.Write("{0:F2} ", success / (double)failure); 44 } 45 46 public object Clone() { 47 return new BernoulliModel() { failure = this.failure, success = this.success }; 67 48 } 68 49 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianModel.cs
r11730 r11732 1 1 using System; 2 using System.Collections.Generic;3 using System.Diagnostics;4 using System.Linq;5 using System.Text;6 using System.Threading.Tasks;7 2 using HeuristicLab.Common; 8 3 9 4 namespace HeuristicLab.Algorithms.Bandits.Models { 10 // bayesian estimation of a Gaussian with unknown mean and known variance 5 // bayesian estimation of a Gaussian with 6 // 1) unknown mean and known variance 7 // 2) unknown mean and unknown variance 11 8 public class GaussianModel : IModel { 12 private readonly int numActions; 13 private readonly int[] tries; 14 private readonly double[] sumRewards; 15 9 private OnlineMeanAndVarianceEstimator estimator = new OnlineMeanAndVarianceEstimator(); 16 10 17 11 // parameters of Gaussian prior for mean … … 19 13 private readonly double meanPriorVariance; 20 14 15 private readonly bool knownVariance; 21 16 private readonly double rewardVariance = 0.1; // assumed know reward variance 22 17 23 public GaussianModel(int numActions, double meanPriorMu, double meanPriorVariance) { 24 this.numActions = numActions; 25 this.tries = new int[numActions]; 26 this.sumRewards = new double[numActions]; 18 // parameters of Gamma prior for precision (= inverse variance) 19 private readonly int precisionPriorAlpha; 20 private readonly double precisionPriorBeta; 21 22 // non-informative prior 23 private const double priorK = 1.0; 24 25 // this constructor assumes the variance is known 26 public GaussianModel(double meanPriorMu, double meanPriorVariance, double rewardVariance = 0.1) { 27 27 this.meanPriorMu = meanPriorMu; 28 28 this.meanPriorVariance = meanPriorVariance; 29 30 this.knownVariance = true; 31 this.rewardVariance = rewardVariance; 32 } 33 34 // this constructor assumes the variance is also unknown 35 // uses Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution equation 85 - 89 36 public GaussianModel(double meanPriorMu, double meanPriorVariance, int precisionPriorAlpha, double precisionPriorBeta) { 37 this.meanPriorMu = meanPriorMu; 38 this.meanPriorVariance = meanPriorVariance; 39 40 this.knownVariance = false; 41 this.precisionPriorAlpha = precisionPriorAlpha; 42 this.precisionPriorBeta = precisionPriorBeta; 29 43 } 30 44 31 45 32 public double[] SampleExpectedRewards(Random random) { 46 public double SampleExpectedReward(Random random) { 47 if (knownVariance) { 48 return SampleExpectedRewardKnownVariance(random); 49 } else { 50 return SampleExpectedRewardUnknownVariance(random); 51 } 52 } 53 54 private double SampleExpectedRewardKnownVariance(Random random) { 33 55 // expected values for reward 34 var theta = new double[numActions];56 // calculate posterior mean and variance (for mean reward) 35 57 36 for (int a = 0; a < numActions; a++) { 37 if (tries[a] == -1) { 38 theta[a] = double.NegativeInfinity; // disabled action 39 } else { 40 // calculate posterior mean and variance (for mean reward) 58 // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf) 59 var posteriorMeanVariance = 1.0 / (estimator.N / rewardVariance + 1.0 / meanPriorVariance); 60 var posteriorMeanMean = posteriorMeanVariance * (meanPriorMu / meanPriorVariance + estimator.Sum / rewardVariance); 41 61 42 // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf) 43 var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / meanPriorVariance); 44 var posteriorMean = posteriorVariance * (meanPriorMu / meanPriorVariance + sumRewards[a] / rewardVariance); 62 // sample a mean from the posterior 63 var posteriorMeanSample = Rand.RandNormal(random) * Math.Sqrt(posteriorMeanVariance) + posteriorMeanMean; 64 // theta already represents the expected reward value => nothing else to do 65 return posteriorMeanSample; 45 66 46 // sample a mean from the posterior 47 theta[a] = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean; 48 // theta already represents the expected reward value => nothing else to do 49 } 67 // return 0.99-quantile value 68 //return alglib.invnormaldistribution(0.99) * Math.Sqrt(rewardVariance + posteriorMeanVariance) + posteriorMeanMean; 69 } 70 71 // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution page 6 onwards (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf) 72 private double SampleExpectedRewardUnknownVariance(Random random) { 73 74 var posteriorMean = (priorK * meanPriorMu + estimator.Sum) / (priorK + estimator.N); 75 var posteriorK = priorK + estimator.N; 76 var posteriorAlpha = precisionPriorAlpha + estimator.N / 2.0; 77 double posteriorBeta; 78 if (estimator.N > 0) { 79 posteriorBeta = precisionPriorBeta + 0.5 * estimator.N * estimator.Variance + priorK * estimator.N * Math.Pow(estimator.Avg - meanPriorMu, 2) / (2.0 * (priorK + estimator.N)); 80 } else { 81 posteriorBeta = precisionPriorBeta; 50 82 } 51 83 84 // sample from the posterior marginal for mu (expected value) equ. 91 85 // p(µ|D) = T2αn (µ| µn, βn/(αnκn)) 86 87 // sample from Tk distribution : http://stats.stackexchange.com/a/70270 88 var t2alpha = alglib.invstudenttdistribution((int)(2 * posteriorAlpha), random.NextDouble()); 89 90 var theta = t2alpha * posteriorBeta / (posteriorAlpha * posteriorK) + posteriorMean; 52 91 return theta; 92 93 //return alglib.invnormaldistribution(random.NextDouble()) * + theta; 94 //return alglib.invstudenttdistribution((int)(2 * posteriorAlpha), 0.99) * (posteriorBeta*posteriorK + posteriorBeta) / (posteriorAlpha*posteriorK) + posteriorMean; 53 95 } 54 96 55 public void Update(int action, double reward) {56 sumRewards[action] += reward;57 tries[action]++;58 }59 97 60 public void Disable(int action) { 61 tries[action] = -1; 62 sumRewards[action] = 0.0; 98 public void Update(double reward) { 99 estimator.UpdateReward(reward); 63 100 } 64 101 65 102 public void Reset() { 66 Array.Clear(tries, 0, numActions); 67 Array.Clear(sumRewards, 0, numActions); 103 estimator.Reset(); 68 104 } 69 105 70 106 public void PrintStats() { 71 for (int i = 0; i < numActions; i++) { 72 Console.Write("{0:F2} ", sumRewards[i] / (double)tries[i]); 73 } 107 Console.Write("{0:F2} ", estimator.Avg); 108 } 109 110 public object Clone() { 111 if (knownVariance) 112 return new GaussianModel(meanPriorMu, meanPriorVariance, rewardVariance); 113 else 114 return new GaussianModel(meanPriorMu, meanPriorVariance, precisionPriorAlpha, precisionPriorBeta); 74 115 } 75 116 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/IModel.cs
r11730 r11732 6 6 7 7 namespace HeuristicLab.Algorithms.Bandits { 8 public interface IModel {9 double[] SampleExpectedRewards(Random random);10 void Update(int action, double reward);11 void Disable(int action);8 // represents a model for the reward distribution (of an action given a state) 9 public interface IModel : ICloneable { 10 double SampleExpectedReward(Random random); 11 void Update(double reward); 12 12 void Reset(); 13 13 void PrintStats(); -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public abstract class BanditPolicy : IPolicy{9 public abstract class BanditPolicy<TPolicyActionInfo> : IPolicy<TPolicyActionInfo> where TPolicyActionInfo : IPolicyActionInfo { 10 10 public IEnumerable<int> Actions { get; private set; } 11 11 private readonly int numInitialActions; -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs
r11730 r11732 8 8 9 9 namespace HeuristicLab.Algorithms.Bandits { 10 public class BernoulliThompsonSamplingPolicy : BanditPolicy { 11 private readonly Random random; 12 private readonly int[] success; 13 private readonly int[] failure; 14 10 public class BernoulliThompsonSamplingPolicy : IPolicy { 15 11 // parameters of beta prior distribution 16 12 private readonly double alpha = 1.0; 17 13 private readonly double beta = 1.0; 18 14 19 public BernoulliThompsonSamplingPolicy(Random random, int numActions) 20 : base(numActions) { 21 this.random = random; 22 this.success = new int[numActions]; 23 this.failure = new int[numActions]; 24 } 15 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 16 var myActionInfos = actionInfos.OfType<BernoulliPolicyActionInfo>(); // TODO: performance 17 int bestAction = -1; 18 double maxTheta = double.NegativeInfinity; 19 var aIdx = -1; 25 20 26 public override int SelectAction() { 27 Debug.Assert(Actions.Any()); 28 var maxTheta = double.NegativeInfinity; 29 int bestAction = -1; 30 foreach (var a in Actions) { 31 var theta = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta); 21 foreach (var aInfo in myActionInfos) { 22 aIdx++; 23 if (aInfo.Disabled) continue; 24 var theta = Rand.BetaRand(random, aInfo.NumSuccess + alpha, aInfo.NumFailure + beta); 32 25 if (theta > maxTheta) { 33 26 maxTheta = theta; 34 bestAction = a ;27 bestAction = aIdx; 35 28 } 36 29 } 30 Debug.Assert(bestAction > -1); 37 31 return bestAction; 38 32 } 39 33 40 public override void UpdateReward(int action, double reward) { 41 Debug.Assert(Actions.Contains(action)); 42 43 if (reward > 0) success[action]++; 44 else failure[action]++; 34 public IPolicyActionInfo CreateActionInfo() { 35 return new BernoulliPolicyActionInfo(); 45 36 } 46 37 47 public override void DisableAction(int action) {48 base.DisableAction(action);49 success[action] = -1;50 }51 52 public override void Reset() {53 base.Reset();54 Array.Clear(success, 0, success.Length);55 Array.Clear(failure, 0, failure.Length);56 }57 58 public override void PrintStats() {59 for (int i = 0; i < success.Length; i++) {60 if (success[i] >= 0) {61 Console.Write("{0,5:F2}", success[i] / failure[i]);62 } else {63 Console.Write("{0,5}", "");64 }65 }66 Console.WriteLine();67 }68 38 69 39 public override string ToString() { -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BoltzmannExplorationPolicy.cs
r11730 r11732 5 5 using System.Text; 6 6 using System.Threading.Tasks; 7 using HeuristicLab.Common; 7 8 8 9 namespace HeuristicLab.Algorithms.Bandits { 9 10 // also called softmax policy 10 public class BoltzmannExplorationPolicy : BanditPolicy { 11 private readonly Random random; 12 private readonly double eps; 13 private readonly int[] tries; 14 private readonly double[] sumReward; 11 public class BoltzmannExplorationPolicy : IPolicy { 15 12 private readonly double beta; 16 13 17 public BoltzmannExplorationPolicy(Random random, int numActions, double beta) 18 : base(numActions) { 14 public BoltzmannExplorationPolicy(double beta) { 19 15 if (beta < 0) throw new ArgumentException(); 20 this.random = random;21 16 this.beta = beta; 22 this.tries = new int[numActions]; 23 this.sumReward = new double[numActions]; 17 } 18 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 19 Debug.Assert(actionInfos.Any()); 20 21 // select best 22 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance 23 Debug.Assert(myActionInfos.Any(a => !a.Disabled)); 24 double[] w = new double[myActionInfos.Length]; 25 26 for (int a = 0; a < myActionInfos.Length; a++) { 27 if (myActionInfos[a].Disabled) { 28 w[a] = 0; continue; 29 } 30 if (myActionInfos[a].Tries == 0) return a; 31 var sumReward = myActionInfos[a].SumReward; 32 var tries = myActionInfos[a].Tries; 33 var avgReward = sumReward / tries; 34 w[a] = Math.Exp(beta * avgReward); 35 } 36 37 38 var bestAction = Enumerable.Range(0, w.Length).SampleProportional(random, w).First(); 39 Debug.Assert(bestAction >= 0); 40 Debug.Assert(bestAction < w.Length); 41 Debug.Assert(!myActionInfos[bestAction].Disabled); 42 return bestAction; 24 43 } 25 44 26 public override int SelectAction() { 27 Debug.Assert(Actions.Any()); 28 // select best 29 var maxReward = double.NegativeInfinity; 30 int bestAction = -1; 31 if (Actions.Any(a => tries[a] == 0)) 32 return Actions.First(a => tries[a] == 0); 33 34 var ts = Actions.Select(a => Math.Exp(beta * sumReward[a] / tries[a])); 35 var r = random.NextDouble() * ts.Sum(); 36 37 var agg = 0.0; 38 foreach (var p in Actions.Zip(ts, Tuple.Create)) { 39 agg += p.Item2; 40 if (agg >= r) return p.Item1; 41 } 42 throw new InvalidProgramException(); 43 } 44 public override void UpdateReward(int action, double reward) { 45 Debug.Assert(Actions.Contains(action)); 46 47 tries[action]++; 48 sumReward[action] += reward; 49 } 50 51 public override void DisableAction(int action) { 52 base.DisableAction(action); 53 sumReward[action] = 0; 54 tries[action] = -1; 55 } 56 57 public override void Reset() { 58 base.Reset(); 59 Array.Clear(tries, 0, tries.Length); 60 Array.Clear(sumReward, 0, sumReward.Length); 61 } 62 63 public override void PrintStats() { 64 for (int i = 0; i < sumReward.Length; i++) { 65 if (tries[i] >= 0) { 66 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 67 } else { 68 Console.Write("{0,5}", ""); 69 } 70 } 71 Console.WriteLine(); 45 public IPolicyActionInfo CreateActionInfo() { 46 return new DefaultPolicyActionInfo(); 72 47 } 73 48 -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/ChernoffIntervalEstimationPolicy.cs
r11730 r11732 10 10 International Conference, CP 2006, Nantes, France, September 25-29, 2006. pp 560-574 */ 11 11 12 public class ChernoffIntervalEstimationPolicy : BanditPolicy { 13 private readonly int[] tries; 14 private readonly double[] sumReward; 15 private int totalTries = 0; 12 public class ChernoffIntervalEstimationPolicy : IPolicy { 16 13 private readonly double delta; 17 14 18 public ChernoffIntervalEstimationPolicy(int numActions, double delta = 0.01) 19 : base(numActions) { 15 public ChernoffIntervalEstimationPolicy(double delta = 0.01) { 20 16 this.delta = delta; 21 this.tries = new int[numActions];22 this.sumReward = new double[numActions];23 17 } 24 25 public override int SelectAction() { 18 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 19 Debug.Assert(actionInfos.Any()); 20 // select best 21 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance 22 int k = myActionInfos.Length; 23 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 26 24 int bestAction = -1; 27 25 double bestQ = double.NegativeInfinity; 28 double k = Actions.Count(); 29 Debug.Assert(k > 0); 30 foreach (var a in Actions) { 31 if (tries[a] == 0) return a; 26 for (int a = 0; a < myActionInfos.Length; a++) { 27 if (myActionInfos[a].Disabled) continue; 28 if (myActionInfos[a].Tries == 0) return a; 29 30 var sumReward = myActionInfos[a].SumReward; 31 var tries = myActionInfos[a].Tries; 32 33 var avgReward = sumReward / tries; 34 32 35 // page 5 of "A simple distribution-free appraoch to the max k-armed bandit problem" 33 36 // var alpha = Math.Log(2 * totalTries * k / delta); 34 37 double alpha = Math.Log(2) + Math.Log(totalTries) + Math.Log(k) - Math.Log(delta); // total tries is max tries in the original paper 35 double mu = sumReward[a] / tries[a]; 36 var q = mu + (alpha + Math.Sqrt(2 * tries[a] * mu * alpha + alpha * alpha)) / tries[a]; 38 var q = avgReward + (alpha + Math.Sqrt(2 * tries * avgReward * alpha + alpha * alpha)) / tries; 37 39 if (q > bestQ) { 38 40 bestQ = q; … … 40 42 } 41 43 } 44 Debug.Assert(bestAction >= 0); 42 45 return bestAction; 43 46 } 44 public override void UpdateReward(int action, double reward) { 45 Debug.Assert(Actions.Contains(action)); 46 totalTries++; 47 tries[action]++; 48 sumReward[action] += reward; 47 48 public IPolicyActionInfo CreateActionInfo() { 49 return new DefaultPolicyActionInfo(); 49 50 } 50 51 51 public override void DisableAction(int action) {52 base.DisableAction(action);53 totalTries -= tries[action];54 tries[action] = -1;55 sumReward[action] = 0;56 }57 58 public override void Reset() {59 base.Reset();60 totalTries = 0;61 Array.Clear(tries, 0, tries.Length);62 Array.Clear(sumReward, 0, sumReward.Length);63 }64 65 public override void PrintStats() {66 for (int i = 0; i < sumReward.Length; i++) {67 if (tries[i] >= 0) {68 Console.Write("{0,5:F2}", sumReward[i] / tries[i]);69 } else {70 Console.Write("{0,5}", "");71 }72 }73 Console.WriteLine();74 }75 52 public override string ToString() { 76 53 return string.Format("ChernoffIntervalEstimationPolicy({0:F2})", delta); -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public class EpsGreedyPolicy : BanditPolicy { 10 private readonly Random random; 9 public class EpsGreedyPolicy : IPolicy { 11 10 private readonly double eps; 12 private readonly int[] tries;13 private readonly double[] sumReward;14 11 private readonly RandomPolicy randomPolicy; 15 12 16 public EpsGreedyPolicy(Random random, int numActions, double eps) 17 : base(numActions) { 18 this.random = random; 13 public EpsGreedyPolicy(double eps) { 19 14 this.eps = eps; 20 this.randomPolicy = new RandomPolicy(random, numActions); 21 this.tries = new int[numActions]; 22 this.sumReward = new double[numActions]; 15 this.randomPolicy = new RandomPolicy(); 23 16 } 24 25 public override int SelectAction() { 26 Debug.Assert(Actions.Any()); 17 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 18 Debug.Assert(actionInfos.Any()); 27 19 if (random.NextDouble() > eps) { 28 20 // select best 29 var bestQ = double.NegativeInfinity;21 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>(); 30 22 int bestAction = -1; 31 foreach (var a in Actions) { 32 if (tries[a] == 0) return a; 33 var q = sumReward[a] / tries[a]; 34 if (bestQ < q) { 23 double bestQ = double.NegativeInfinity; 24 int aIdx = -1; 25 foreach (var aInfo in myActionInfos) { 26 27 aIdx++; 28 if (aInfo.Disabled) continue; 29 if (aInfo.Tries == 0) return aIdx; 30 31 32 var avgReward = aInfo.SumReward / aInfo.Tries; 33 //var q = avgReward; 34 var q = aInfo.MaxReward; 35 if (q > bestQ) { 35 36 bestQ = q; 36 bestAction = a ;37 bestAction = aIdx; 37 38 } 38 39 } … … 41 42 } else { 42 43 // select random 43 return randomPolicy.SelectAction( );44 return randomPolicy.SelectAction(random, actionInfos); 44 45 } 45 46 } 46 public override void UpdateReward(int action, double reward) {47 Debug.Assert(Actions.Contains(action));48 47 49 randomPolicy.UpdateReward(action, reward); // does nothing 50 tries[action]++; 51 sumReward[action] += reward; 48 public IPolicyActionInfo CreateActionInfo() { 49 return new DefaultPolicyActionInfo(); 52 50 } 53 51 54 public override void DisableAction(int action) {55 base.DisableAction(action);56 randomPolicy.DisableAction(action);57 sumReward[action] = 0;58 tries[action] = -1;59 }60 52 61 public override void Reset() {62 base.Reset();63 randomPolicy.Reset();64 Array.Clear(tries, 0, tries.Length);65 Array.Clear(sumReward, 0, sumReward.Length);66 }67 public override void PrintStats() {68 for (int i = 0; i < sumReward.Length; i++) {69 if (tries[i] >= 0) {70 Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);71 } else {72 Console.Write("-", "");73 }74 }75 Console.WriteLine();76 }77 53 public override string ToString() { 78 54 return string.Format("EpsGreedyPolicy({0:F2})", eps); -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs
r11730 r11732 1 1 using System; 2 using System.Collections.Generic; 2 3 using System.Diagnostics; 3 4 using System.Linq; … … 5 6 6 7 namespace HeuristicLab.Algorithms.Bandits { 7 8 public class GaussianThompsonSamplingPolicy : BanditPolicy { 9 private readonly Random random; 10 private readonly double[] sampleMean; 11 private readonly double[] sampleM2; 12 private readonly int[] tries; 8 9 public class GaussianThompsonSamplingPolicy : IPolicy { 13 10 private bool compatibility; 14 11 … … 21 18 22 19 23 public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false) 24 : base(numActions) { 25 this.random = random; 26 this.sampleMean = new double[numActions]; 27 this.sampleM2 = new double[numActions]; 28 this.tries = new int[numActions]; 20 public GaussianThompsonSamplingPolicy(bool compatibility = false) { 29 21 this.compatibility = compatibility; 30 22 } 31 23 24 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 25 var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>(); 26 int bestAction = -1; 27 double bestQ = double.NegativeInfinity; 32 28 33 public override int SelectAction() { 34 Debug.Assert(Actions.Any()); 35 var maxTheta = double.NegativeInfinity; 36 int bestAction = -1; 37 foreach (var a in Actions) { 38 if(tries[a] == -1) continue; // skip disabled actions 29 int aIdx = -1; 30 foreach (var aInfo in myActionInfos) { 31 aIdx++; 32 if (aInfo.Disabled) continue; 33 34 var tries = aInfo.Tries; 35 var sampleMean = aInfo.AvgReward; 36 var sampleVariance = aInfo.RewardVariance; 37 39 38 double theta; 40 39 if (compatibility) { 41 if (tries [a] < 2) return a;42 var mu = sampleMean [a];43 var variance = sample M2[a] / tries[a];40 if (tries < 2) return aIdx; 41 var mu = sampleMean; 42 var variance = sampleVariance; 44 43 var stdDev = Math.Sqrt(variance); 45 44 theta = Rand.RandNormal(random) * stdDev + mu; … … 48 47 49 48 // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf) 50 var posteriorVariance = 1.0 / (tries [a]/ rewardVariance + 1.0 / priorVariance);51 var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries [a] * sampleMean[a]/ rewardVariance);49 var posteriorVariance = 1.0 / (tries / rewardVariance + 1.0 / priorVariance); 50 var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries * sampleMean / rewardVariance); 52 51 53 52 // sample a mean from the posterior … … 56 55 // theta already represents the expected reward value => nothing else to do 57 56 } 58 if (theta > maxTheta) { 59 maxTheta = theta; 60 bestAction = a; 57 58 if (theta > bestQ) { 59 bestQ = theta; 60 bestAction = aIdx; 61 61 } 62 62 } 63 Debug.Assert( Actions.Contains(bestAction));63 Debug.Assert(bestAction > -1); 64 64 return bestAction; 65 65 } 66 66 67 public override void UpdateReward(int action, double reward) { 68 Debug.Assert(Actions.Contains(action)); 69 tries[action]++; 70 var delta = reward - sampleMean[action]; 71 sampleMean[action] += delta / tries[action]; 72 sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]); 67 public IPolicyActionInfo CreateActionInfo() { 68 return new MeanAndVariancePolicyActionInfo(); 73 69 } 74 70 75 public override void DisableAction(int action) {76 base.DisableAction(action);77 sampleMean[action] = 0;78 sampleM2[action] = 0;79 tries[action] = -1;80 }81 71 82 public override void Reset() { 83 base.Reset(); 84 Array.Clear(sampleMean, 0, sampleMean.Length); 85 Array.Clear(sampleM2, 0, sampleM2.Length); 86 Array.Clear(tries, 0, tries.Length); 87 } 72 //public override void UpdateReward(int action, double reward) { 73 // Debug.Assert(Actions.Contains(action)); 74 // tries[action]++; 75 // var delta = reward - sampleMean[action]; 76 // sampleMean[action] += delta / tries[action]; 77 // sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]); 78 //} 88 79 89 public override void PrintStats() {90 for (int i = 0; i < sampleMean.Length; i++) {91 if (tries[i] >= 0) {92 Console.Write(" {0,5:F2} {1}", sampleMean[i] / tries[i], tries[i]);93 } else {94 Console.Write("{0,5}", "");95 }96 }97 Console.WriteLine();98 }99 80 public override string ToString() { 100 81 return "GaussianThompsonSamplingPolicy"; -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GenericThompsonSamplingPolicy.cs
r11730 r11732 8 8 9 9 namespace HeuristicLab.Algorithms.Bandits { 10 public class GenericThompsonSamplingPolicy : BanditPolicy { 11 private readonly Random random; 10 public class GenericThompsonSamplingPolicy : IPolicy { 12 11 private readonly IModel model; 13 12 14 public GenericThompsonSamplingPolicy(Random random, int numActions, IModel model) 15 : base(numActions) { 16 this.random = random; 13 public GenericThompsonSamplingPolicy(IModel model) { 17 14 this.model = model; 18 15 } 19 16 20 public override int SelectAction() { 21 Debug.Assert(Actions.Any()); 22 var maxR = double.NegativeInfinity; 17 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 18 var myActionInfos = actionInfos.OfType<ModelPolicyActionInfo>(); 23 19 int bestAction = -1; 24 var expRewards = model.SampleExpectedRewards(random); 25 foreach (var a in Actions) { 26 var r = expRewards[a]; 27 if (r > maxR) { 28 maxR = r; 29 bestAction = a; 20 double bestQ = double.NegativeInfinity; 21 var aIdx = -1; 22 foreach (var aInfo in myActionInfos) { 23 aIdx++; 24 if (aInfo.Disabled) continue; 25 //if (aInfo.Tries == 0) return aIdx; 26 var q = aInfo.SampleExpectedReward(random); 27 if (q > bestQ) { 28 bestQ = q; 29 bestAction = aIdx; 30 30 } 31 31 } 32 Debug.Assert(bestAction > -1); 32 33 return bestAction; 33 34 } 34 35 35 public override void UpdateReward(int action, double reward) { 36 Debug.Assert(Actions.Contains(action)); 37 38 model.Update(action, reward); 39 } 40 41 public override void DisableAction(int action) { 42 base.DisableAction(action); 43 model.Disable(action); 44 } 45 46 public override void Reset() { 47 base.Reset(); 48 model.Reset(); 49 } 50 51 public override void PrintStats() { 52 model.PrintStats(); 36 public IPolicyActionInfo CreateActionInfo() { 37 return new ModelPolicyActionInfo((IModel)model.Clone()); 53 38 } 54 39 -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs
r11730 r11732 8 8 9 9 namespace HeuristicLab.Algorithms.Bandits { 10 public class RandomPolicy : BanditPolicy { 11 private readonly Random random; 10 public class RandomPolicy : IPolicy { 12 11 13 public RandomPolicy(Random random, int numActions)14 : base(numActions) {15 this.random = random;16 }17 18 public override int SelectAction() {19 Debug.Assert(Actions.Any());20 return Actions.SelectRandom(random);21 }22 public override void UpdateReward(int action, double reward) {23 // do nothing24 }25 public override void PrintStats() {26 Console.WriteLine("Random");27 }28 12 public override string ToString() { 29 13 return "RandomPolicy"; 30 14 } 15 16 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 17 return actionInfos 18 .Select((a, i) => Tuple.Create(a, i)) 19 .Where(p => !p.Item1.Disabled) 20 .SelectRandom(random).Item2; 21 } 22 23 public IPolicyActionInfo CreateActionInfo() { 24 return new EmptyPolicyActionInfo(); 25 } 31 26 } 32 27 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public class UCB1Policy : BanditPolicy { 10 private readonly int[] tries; 11 private readonly double[] sumReward; 12 private int totalTries = 0; 13 public UCB1Policy(int numActions) 14 : base(numActions) { 15 this.tries = new int[numActions]; 16 this.sumReward = new double[numActions]; 17 } 18 19 public override int SelectAction() { 9 public class UCB1Policy : IPolicy { 10 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 11 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance 20 12 int bestAction = -1; 21 13 double bestQ = double.NegativeInfinity; 22 foreach (var a in Actions) { 23 if (tries[a] == 0) return a; 24 var q = sumReward[a] / tries[a] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[a]); 14 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 15 16 for (int a = 0; a < myActionInfos.Length; a++) { 17 if (myActionInfos[a].Disabled) continue; 18 if (myActionInfos[a].Tries == 0) return a; 19 var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries); 25 20 if (q > bestQ) { 26 21 bestQ = q; … … 28 23 } 29 24 } 25 Debug.Assert(bestAction > -1); 30 26 return bestAction; 31 27 } 32 public override void UpdateReward(int action, double reward) {33 Debug.Assert(Actions.Contains(action));34 totalTries++;35 tries[action]++;36 sumReward[action] += reward;37 }38 28 39 public override void DisableAction(int action) { 40 base.DisableAction(action); 41 totalTries -= tries[action]; 42 tries[action] = -1; 43 sumReward[action] = 0; 44 } 45 46 public override void Reset() { 47 base.Reset(); 48 totalTries = 0; 49 Array.Clear(tries, 0, tries.Length); 50 Array.Clear(sumReward, 0, sumReward.Length); 51 } 52 public override void PrintStats() { 53 for (int i = 0; i < sumReward.Length; i++) { 54 if (tries[i] >= 0) { 55 Console.Write("{0,5:F2}", sumReward[i] / tries[i]); 56 } else { 57 Console.Write("{0,5}", ""); 58 } 59 } 60 Console.WriteLine(); 29 public IPolicyActionInfo CreateActionInfo() { 30 return new DefaultPolicyActionInfo(); 61 31 } 62 32 public override string ToString() { -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public class UCB1TunedPolicy : BanditPolicy { 10 private readonly int[] tries; 11 private readonly double[] sumReward; 12 private readonly double[] sumSqrReward; 13 private int totalTries = 0; 14 public UCB1TunedPolicy(int numActions) 15 : base(numActions) { 16 this.tries = new int[numActions]; 17 this.sumReward = new double[numActions]; 18 this.sumSqrReward = new double[numActions]; 19 } 9 public class UCB1TunedPolicy : IPolicy { 20 10 21 private double V(int arm) { 22 var s = tries[arm]; 23 return sumSqrReward[arm] / s - Math.Pow(sumReward[arm] / s, 2) + Math.Sqrt(2 * Math.Log(totalTries) / s); 24 } 25 26 27 public override int SelectAction() { 28 Debug.Assert(Actions.Any()); 11 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 12 var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance 29 13 int bestAction = -1; 30 14 double bestQ = double.NegativeInfinity; 31 foreach (var a in Actions) { 32 if (tries[a] == 0) return a; 33 var q = sumReward[a] / tries[a] + Math.Sqrt((Math.Log(totalTries) / tries[a]) * Math.Min(1.0 / 4, V(a))); // 1/4 is upper bound of bernoulli distributed variable 15 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 16 17 for (int a = 0; a < myActionInfos.Length; a++) { 18 if (myActionInfos[a].Disabled) continue; 19 if (myActionInfos[a].Tries == 0) return a; 20 21 var sumReward = myActionInfos[a].SumReward; 22 var tries = myActionInfos[a].Tries; 23 24 var avgReward = sumReward / tries; 25 var q = avgReward + Math.Sqrt((Math.Log(totalTries) / tries) * Math.Min(1.0 / 4, V(myActionInfos[a], totalTries))); // 1/4 is upper bound of bernoulli distributed variable 34 26 if (q > bestQ) { 35 27 bestQ = q; … … 37 29 } 38 30 } 31 Debug.Assert(bestAction > -1); 39 32 return bestAction; 40 33 } 41 public override void UpdateReward(int action, double reward) { 42 Debug.Assert(Actions.Contains(action)); 43 totalTries++; 44 tries[action]++; 45 sumReward[action] += reward; 46 sumSqrReward[action] += reward * reward; 34 35 public IPolicyActionInfo CreateActionInfo() { 36 return new MeanAndVariancePolicyActionInfo(); 47 37 } 48 38 49 public override void DisableAction(int action) { 50 base.DisableAction(action); 51 totalTries -= tries[action]; 52 tries[action] = -1; 53 sumReward[action] = 0; 54 sumSqrReward[action] = 0; 39 private double V(MeanAndVariancePolicyActionInfo actionInfo, int totalTries) { 40 var s = actionInfo.Tries; 41 return actionInfo.RewardVariance + Math.Sqrt(2 * Math.Log(totalTries) / s); 55 42 } 56 43 57 public override void Reset() {58 base.Reset();59 totalTries = 0;60 Array.Clear(tries, 0, tries.Length);61 Array.Clear(sumReward, 0, sumReward.Length);62 Array.Clear(sumSqrReward, 0, sumSqrReward.Length);63 }64 public override void PrintStats() {65 for (int i = 0; i < sumReward.Length; i++) {66 if (tries[i] >= 0) {67 Console.Write("{0,5:F2}", sumReward[i] / tries[i]);68 } else {69 Console.Write("{0,5}", "");70 }71 }72 Console.WriteLine();73 }74 44 public override string ToString() { 75 45 return "UCB1TunedPolicy"; -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs
r11730 r11732 7 7 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 public class UCBNormalPolicy : BanditPolicy { 10 private readonly int[] tries; 11 private readonly double[] sumReward; 12 private readonly double[] sumSqrReward; 13 private int totalTries = 0; 14 public UCBNormalPolicy(int numActions) 15 : base(numActions) { 16 this.tries = new int[numActions]; 17 this.sumReward = new double[numActions]; 18 this.sumSqrReward = new double[numActions]; 19 } 9 public class UCBNormalPolicy : IPolicy { 20 10 21 public override int SelectAction() {22 Debug.Assert(Actions.Any());11 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 12 var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance 23 13 int bestAction = -1; 24 14 double bestQ = double.NegativeInfinity; 25 foreach (var a in Actions) { 26 if (totalTries <= 1 || tries[a] <= 1 || tries[a] <= Math.Ceiling(8 * Math.Log(totalTries))) return a; 27 var avgReward = sumReward[a] / tries[a]; 28 var estVariance = 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]); 15 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 16 17 for (int a = 0; a < myActionInfos.Length; a++) { 18 if (myActionInfos[a].Disabled) continue; 19 if (totalTries <= 1 || myActionInfos[a].Tries <= 1 || myActionInfos[a].Tries <= Math.Ceiling(8 * Math.Log(totalTries))) return a; 20 21 var tries = myActionInfos[a].Tries; 22 var avgReward = myActionInfos[a].AvgReward; 23 var rewardVariance = myActionInfos[a].RewardVariance; 24 var estVariance = 16 * rewardVariance * (Math.Log(totalTries - 1) / tries); 29 25 if (estVariance < 0) estVariance = 0; // numerical problems 30 26 var q = avgReward … … 35 31 } 36 32 } 37 Debug.Assert( Actions.Contains(bestAction));33 Debug.Assert(bestAction > -1); 38 34 return bestAction; 39 35 } 40 public override void UpdateReward(int action, double reward) { 41 Debug.Assert(Actions.Contains(action)); 42 totalTries++; 43 tries[action]++; 44 sumReward[action] += reward; 45 sumSqrReward[action] += reward * reward; 36 37 public IPolicyActionInfo CreateActionInfo() { 38 return new MeanAndVariancePolicyActionInfo(); 46 39 } 47 40 48 public override void DisableAction(int action) {49 base.DisableAction(action);50 totalTries -= tries[action];51 tries[action] = -1;52 sumReward[action] = 0;53 sumSqrReward[action] = 0;54 }55 56 public override void Reset() {57 base.Reset();58 totalTries = 0;59 Array.Clear(tries, 0, tries.Length);60 Array.Clear(sumReward, 0, sumReward.Length);61 Array.Clear(sumSqrReward, 0, sumSqrReward.Length);62 }63 public override void PrintStats() {64 for (int i = 0; i < sumReward.Length; i++) {65 if (tries[i] >= 0) {66 Console.Write("{0,5:F2}", sumReward[i] / tries[i]);67 } else {68 Console.Write("{0,5}", "");69 }70 }71 Console.WriteLine();72 }73 41 public override string ToString() { 74 42 return "UCBNormalPolicy"; -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCTPolicy.cs
r11730 r11732 8 8 namespace HeuristicLab.Algorithms.Bandits { 9 9 /* Kocsis et al. Bandit based Monte-Carlo Planning */ 10 public class UCTPolicy : BanditPolicy { 11 private readonly int[] tries; 12 private readonly double[] sumReward; 13 private int totalTries = 0; 10 public class UCTPolicy : IPolicy { 14 11 private readonly double c; 15 12 16 public UCTPolicy(int numActions, double c = 1.0) 17 : base(numActions) { 18 this.tries = new int[numActions]; 19 this.sumReward = new double[numActions]; 13 public UCTPolicy(double c = 1.0) { 20 14 this.c = c; 21 15 } 22 16 23 public override int SelectAction() { 17 18 public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) { 19 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance 24 20 int bestAction = -1; 25 21 double bestQ = double.NegativeInfinity; 26 foreach (var a in Actions) { 27 if (tries[a] == 0) return a; 28 var q = sumReward[a] / tries[a] + 2 * c * Math.Sqrt(Math.Log(totalTries) / tries[a]); 22 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 23 24 for (int a = 0; a < myActionInfos.Length; a++) { 25 if (myActionInfos[a].Disabled) continue; 26 if (myActionInfos[a].Tries == 0) return a; 27 var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + 2 * c * Math.Sqrt(Math.Log(totalTries) / myActionInfos[a].Tries); 29 28 if (q > bestQ) { 30 29 bestQ = q; … … 32 31 } 33 32 } 33 Debug.Assert(bestAction > -1); 34 34 return bestAction; 35 35 } 36 public override void UpdateReward(int action, double reward) { 37 Debug.Assert(Actions.Contains(action)); 38 totalTries++; 39 tries[action]++; 40 sumReward[action] += reward; 36 37 public IPolicyActionInfo CreateActionInfo() { 38 return new DefaultPolicyActionInfo(); 41 39 } 42 40 43 public override void DisableAction(int action) {44 base.DisableAction(action);45 totalTries -= tries[action];46 tries[action] = -1;47 sumReward[action] = 0;48 }49 50 public override void Reset() {51 base.Reset();52 totalTries = 0;53 Array.Clear(tries, 0, tries.Length);54 Array.Clear(sumReward, 0, sumReward.Length);55 }56 public override void PrintStats() {57 for (int i = 0; i < sumReward.Length; i++) {58 if (tries[i] >= 0) {59 Console.Write("{0,5:F2}", sumReward[i] / tries[i]);60 } else {61 Console.Write("{0,5}", "");62 }63 }64 Console.WriteLine();65 }66 41 public override string ToString() { 67 42 return string.Format("UCTPolicy({0:F2})", c);
Note: See TracChangeset
for help on using the changeset viewer.