Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/07/15 09:21:46 (8 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
Files:
8 added
17 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11730 r11732  
    3131  </PropertyGroup>
    3232  <ItemGroup>
     33    <Reference Include="ALGLIB-3.7.0">
     34      <HintPath>..\..\..\trunk\sources\bin\ALGLIB-3.7.0.dll</HintPath>
     35    </Reference>
    3336    <Reference Include="System" />
    3437    <Reference Include="System.Core" />
     
    4245    <Compile Include="BanditHelper.cs" />
    4346    <Compile Include="Bandits\BernoulliBandit.cs" />
     47    <Compile Include="Bandits\GaussianBandit.cs" />
    4448    <Compile Include="Bandits\GaussianMixtureBandit.cs" />
    4549    <Compile Include="Bandits\IBandit.cs" />
    4650    <Compile Include="Bandits\TruncatedNormalBandit.cs" />
     51    <Compile Include="OnlineMeanAndVarianceEstimator.cs" />
     52    <Compile Include="IPolicyActionInfo.cs" />
    4753    <Compile Include="Models\BernoulliModel.cs" />
    4854    <Compile Include="Models\GaussianModel.cs" />
    49     <Compile Include="Models\GaussianMixtureModel.cs" />
    5055    <Compile Include="Models\IModel.cs" />
    51     <Compile Include="Policies\BanditPolicy.cs" />
    52     <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs" />
    53     <Compile Include="Policies\BoltzmannExplorationPolicy.cs" />
    54     <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs" />
    55     <Compile Include="Policies\GenericThompsonSamplingPolicy.cs" />
    56     <Compile Include="Policies\ThresholdAscentPolicy.cs" />
    57     <Compile Include="Policies\UCTPolicy.cs" />
    58     <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs" />
    59     <Compile Include="Policies\Exp3Policy.cs" />
    60     <Compile Include="Policies\EpsGreedyPolicy.cs" />
     56    <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs">
     57      <SubType>Code</SubType>
     58    </Compile>
     59    <Compile Include="Policies\BoltzmannExplorationPolicy.cs">
     60      <SubType>Code</SubType>
     61    </Compile>
     62    <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs">
     63      <SubType>Code</SubType>
     64    </Compile>
     65    <Compile Include="Policies\BernoulliPolicyActionInfo.cs" />
     66    <Compile Include="Policies\ModelPolicyActionInfo.cs" />
     67    <Compile Include="Policies\EpsGreedyPolicy.cs">
     68      <SubType>Code</SubType>
     69    </Compile>
     70    <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs">
     71      <SubType>Code</SubType>
     72    </Compile>
     73    <Compile Include="Policies\GenericThompsonSamplingPolicy.cs">
     74      <SubType>Code</SubType>
     75    </Compile>
     76    <Compile Include="Policies\MeanAndVariancePolicyActionInfo.cs" />
     77    <Compile Include="Policies\DefaultPolicyActionInfo.cs" />
     78    <Compile Include="Policies\EmptyPolicyActionInfo.cs" />
    6179    <Compile Include="Policies\RandomPolicy.cs" />
    6280    <Compile Include="Policies\UCB1Policy.cs" />
    63     <Compile Include="Policies\UCB1TunedPolicy.cs" />
    64     <Compile Include="Policies\UCBNormalPolicy.cs" />
    6581    <Compile Include="IPolicy.cs" />
     82    <Compile Include="Policies\UCB1TunedPolicy.cs">
     83      <SubType>Code</SubType>
     84    </Compile>
     85    <Compile Include="Policies\UCBNormalPolicy.cs">
     86      <SubType>Code</SubType>
     87    </Compile>
     88    <Compile Include="Policies\UCTPolicy.cs">
     89      <SubType>Code</SubType>
     90    </Compile>
    6691    <Compile Include="Properties\AssemblyInfo.cs" />
    6792  </ItemGroup>
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IPolicy.cs

    r11730 r11732  
    88  // this interface represents a policy for reinforcement learning
    99  public interface IPolicy {
    10     IEnumerable<int> Actions { get; }
    11     int SelectAction(); // action selection ...
    12     void UpdateReward(int action, double reward); // ... and reward update are defined as usual
    13 
    14     // policies must also support disabling of potential actions
    15     // for instance if we know that an action in a state has a deterministic
    16     // reward we need to sample it only once
    17     // it is necessary to sample an action only once
    18     void DisableAction(int action);
    19 
    20     // reset causes the policy to be reinitialized to its initial state (as after constructor-call)
    21     void Reset();
    22 
    23     void PrintStats();
     10    int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos);
     11    IPolicyActionInfo CreateActionInfo();
    2412  }
    2513}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/BernoulliModel.cs

    r11730 r11732  
    99namespace HeuristicLab.Algorithms.Bandits.Models {
    1010  public class BernoulliModel : IModel {
    11     private readonly int numActions;
    12     private readonly int[] success;
    13     private readonly int[] failure;
     11    private int success;
     12    private int failure;
    1413
    1514    // parameters of beta prior distribution
     
    1716    private readonly double beta;
    1817
    19     public BernoulliModel(int numActions, double alpha = 1.0, double beta = 1.0) {
    20       this.numActions = numActions;
    21       this.success = new int[numActions];
    22       this.failure = new int[numActions];
     18    public BernoulliModel(double alpha = 1.0, double beta = 1.0) {
    2319      this.alpha = alpha;
    2420      this.beta = beta;
    2521    }
    2622
    27 
    28     public double[] SampleExpectedRewards(Random random) {
     23    public double SampleExpectedReward(Random random) {
    2924      // sample bernoulli mean from beta prior
    30       var theta = new double[numActions];
    31       for (int a = 0; a < numActions; a++) {
    32         if (success[a] == -1)
    33           theta[a] = 0.0;
    34         else {
    35           theta[a] = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
    36         }
    37       }
    38 
    39       // no need to sample we know the exact expected value
    40       // the expected value of a bernoulli variable is just theta
    41       return theta.Select(t => t).ToArray();
     25      return Rand.BetaRand(random, success + alpha, failure + beta);
    4226    }
    4327
    44     public void Update(int action, double reward) {
    45       const double EPSILON = 1E-6;
    46       Debug.Assert(Math.Abs(reward - 0.0) < EPSILON || Math.Abs(reward - 1.0) < EPSILON);
    47       if (Math.Abs(reward - 1.0) < EPSILON) {
    48         success[action]++;
     28    public void Update(double reward) {
     29      Debug.Assert(reward.IsAlmost(1.0) || reward.IsAlmost(0.0));
     30      if (reward.IsAlmost(1.0)) {
     31        success++;
    4932      } else {
    50         failure[action]++;
     33        failure++;
    5134      }
    5235    }
    5336
    54     public void Disable(int action) {
    55       success[action] = -1;
    56     }
    57 
    5837    public void Reset() {
    59       Array.Clear(success, 0, numActions);
    60       Array.Clear(failure, 0, numActions);
     38      success = 0;
     39      failure = 0;
    6140    }
    6241
    6342    public void PrintStats() {
    64       for (int i = 0; i < numActions; i++) {
    65         Console.Write("{0:F2} ", success[i] / (double)failure[i]);
    66       }
     43      Console.Write("{0:F2} ", success / (double)failure);
     44    }
     45
     46    public object Clone() {
     47      return new BernoulliModel() { failure = this.failure, success = this.success };
    6748    }
    6849  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianModel.cs

    r11730 r11732  
    11using System;
    2 using System.Collections.Generic;
    3 using System.Diagnostics;
    4 using System.Linq;
    5 using System.Text;
    6 using System.Threading.Tasks;
    72using HeuristicLab.Common;
    83
    94namespace HeuristicLab.Algorithms.Bandits.Models {
    10   // bayesian estimation of a Gaussian with unknown mean and known variance
     5  // bayesian estimation of a Gaussian with
     6  // 1) unknown mean and known variance
     7  // 2) unknown mean and unknown variance
    118  public class GaussianModel : IModel {
    12     private readonly int numActions;
    13     private readonly int[] tries;
    14     private readonly double[] sumRewards;
    15 
     9    private OnlineMeanAndVarianceEstimator estimator = new OnlineMeanAndVarianceEstimator();
    1610
    1711    // parameters of Gaussian prior for mean
     
    1913    private readonly double meanPriorVariance;
    2014
     15    private readonly bool knownVariance;
    2116    private readonly double rewardVariance = 0.1; // assumed known reward variance
    2217
    23     public GaussianModel(int numActions, double meanPriorMu, double meanPriorVariance) {
    24       this.numActions = numActions;
    25       this.tries = new int[numActions];
    26       this.sumRewards = new double[numActions];
     18    // parameters of Gamma prior for precision (= inverse variance)
     19    private readonly int precisionPriorAlpha;
     20    private readonly double precisionPriorBeta;
     21
     22    // non-informative prior
     23    private const double priorK = 1.0;
     24
     25    // this constructor assumes the variance is known
     26    public GaussianModel(double meanPriorMu, double meanPriorVariance, double rewardVariance = 0.1) {
    2727      this.meanPriorMu = meanPriorMu;
    2828      this.meanPriorVariance = meanPriorVariance;
     29
     30      this.knownVariance = true;
     31      this.rewardVariance = rewardVariance;
     32    }
     33
     34    // this constructor assumes the variance is also unknown
     35    // uses Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution equation 85 - 89
     36    public GaussianModel(double meanPriorMu, double meanPriorVariance, int precisionPriorAlpha, double precisionPriorBeta) {
     37      this.meanPriorMu = meanPriorMu;
     38      this.meanPriorVariance = meanPriorVariance;
     39
     40      this.knownVariance = false;
     41      this.precisionPriorAlpha = precisionPriorAlpha;
     42      this.precisionPriorBeta = precisionPriorBeta;
    2943    }
    3044
    3145
    32     public double[] SampleExpectedRewards(Random random) {
     46    public double SampleExpectedReward(Random random) {
     47      if (knownVariance) {
     48        return SampleExpectedRewardKnownVariance(random);
     49      } else {
     50        return SampleExpectedRewardUnknownVariance(random);
     51      }
     52    }
     53
     54    private double SampleExpectedRewardKnownVariance(Random random) {
    3355      // expected values for reward
    34       var theta = new double[numActions];
     56      // calculate posterior mean and variance (for mean reward)
    3557
    36       for (int a = 0; a < numActions; a++) {
    37         if (tries[a] == -1) {
    38           theta[a] = double.NegativeInfinity; // disabled action
    39         } else {
    40           // calculate posterior mean and variance (for mean reward)
     58      // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
     59      var posteriorMeanVariance = 1.0 / (estimator.N / rewardVariance + 1.0 / meanPriorVariance);
     60      var posteriorMeanMean = posteriorMeanVariance * (meanPriorMu / meanPriorVariance + estimator.Sum / rewardVariance);
    4161
    42           // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
    43           var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / meanPriorVariance);
    44           var posteriorMean = posteriorVariance * (meanPriorMu / meanPriorVariance + sumRewards[a] / rewardVariance);
     62      // sample a mean from the posterior
     63      var posteriorMeanSample = Rand.RandNormal(random) * Math.Sqrt(posteriorMeanVariance) + posteriorMeanMean;
     64      // theta already represents the expected reward value => nothing else to do
     65      return posteriorMeanSample;
    4566
    46           // sample a mean from the posterior
    47           theta[a] = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;
    48           // theta already represents the expected reward value => nothing else to do
    49         }
     67      // return 0.99-quantile value
     68      //return alglib.invnormaldistribution(0.99) * Math.Sqrt(rewardVariance + posteriorMeanVariance) + posteriorMeanMean;
     69    }
     70
     71    // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution page 6 onwards (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
     72    private double SampleExpectedRewardUnknownVariance(Random random) {
     73
     74      var posteriorMean = (priorK * meanPriorMu + estimator.Sum) / (priorK + estimator.N);
     75      var posteriorK = priorK + estimator.N;
     76      var posteriorAlpha = precisionPriorAlpha + estimator.N / 2.0;
     77      double posteriorBeta;
     78      if (estimator.N > 0) {
     79        posteriorBeta = precisionPriorBeta + 0.5 * estimator.N * estimator.Variance + priorK * estimator.N * Math.Pow(estimator.Avg - meanPriorMu, 2) / (2.0 * (priorK + estimator.N));
     80      } else {
     81        posteriorBeta = precisionPriorBeta;
    5082      }
    5183
     84      // sample from the posterior marginal for mu (expected value) equ. 91
     85      // p(µ|D) = T2αn (µ| µn, βn/(αnκn))
     86
     87      // sample from Tk distribution : http://stats.stackexchange.com/a/70270
     88      var t2alpha = alglib.invstudenttdistribution((int)(2 * posteriorAlpha), random.NextDouble());
     89
     90      var theta = t2alpha * posteriorBeta / (posteriorAlpha * posteriorK) + posteriorMean;
    5291      return theta;
     92
     93      //return alglib.invnormaldistribution(random.NextDouble()) * + theta;
     94      //return alglib.invstudenttdistribution((int)(2 * posteriorAlpha), 0.99) * (posteriorBeta*posteriorK + posteriorBeta) / (posteriorAlpha*posteriorK) + posteriorMean;
    5395    }
    5496
    55     public void Update(int action, double reward) {
    56       sumRewards[action] += reward;
    57       tries[action]++;
    58     }
    5997
    60     public void Disable(int action) {
    61       tries[action] = -1;
    62       sumRewards[action] = 0.0;
     98    public void Update(double reward) {
     99      estimator.UpdateReward(reward);
    63100    }
    64101
    65102    public void Reset() {
    66       Array.Clear(tries, 0, numActions);
    67       Array.Clear(sumRewards, 0, numActions);
     103      estimator.Reset();
    68104    }
    69105
    70106    public void PrintStats() {
    71       for (int i = 0; i < numActions; i++) {
    72         Console.Write("{0:F2} ", sumRewards[i] / (double)tries[i]);
    73       }
     107      Console.Write("{0:F2} ", estimator.Avg);
     108    }
     109
     110    public object Clone() {
     111      if (knownVariance)
     112        return new GaussianModel(meanPriorMu, meanPriorVariance, rewardVariance);
     113      else
     114        return new GaussianModel(meanPriorMu, meanPriorVariance, precisionPriorAlpha, precisionPriorBeta);
    74115    }
    75116  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/IModel.cs

    r11730 r11732  
    66
    77namespace HeuristicLab.Algorithms.Bandits {
    8   public interface IModel {
    9     double[] SampleExpectedRewards(Random random);
    10     void Update(int action, double reward);
    11     void Disable(int action);
     8  // represents a model for the reward distribution (of an action given a state)
     9  public interface IModel : ICloneable {
     10    double SampleExpectedReward(Random random);
     11    void Update(double reward);
    1212    void Reset();
    1313    void PrintStats();
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public abstract class BanditPolicy : IPolicy {
     9  public abstract class BanditPolicy<TPolicyActionInfo> : IPolicy<TPolicyActionInfo> where TPolicyActionInfo : IPolicyActionInfo {
    1010    public IEnumerable<int> Actions { get; private set; }
    1111    private readonly int numInitialActions;
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs

    r11730 r11732  
    88
    99namespace HeuristicLab.Algorithms.Bandits {
    10   public class BernoulliThompsonSamplingPolicy : BanditPolicy {
    11     private readonly Random random;
    12     private readonly int[] success;
    13     private readonly int[] failure;
    14 
     10  public class BernoulliThompsonSamplingPolicy : IPolicy {
    1511    // parameters of beta prior distribution
    1612    private readonly double alpha = 1.0;
    1713    private readonly double beta = 1.0;
    1814
    19     public BernoulliThompsonSamplingPolicy(Random random, int numActions)
    20       : base(numActions) {
    21       this.random = random;
    22       this.success = new int[numActions];
    23       this.failure = new int[numActions];
    24     }
     15    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     16      var myActionInfos = actionInfos.OfType<BernoulliPolicyActionInfo>(); // TODO: performance
     17      int bestAction = -1;
     18      double maxTheta = double.NegativeInfinity;
     19      var aIdx = -1;
    2520
    26     public override int SelectAction() {
    27       Debug.Assert(Actions.Any());
    28       var maxTheta = double.NegativeInfinity;
    29       int bestAction = -1;
    30       foreach (var a in Actions) {
    31         var theta = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
     21      foreach (var aInfo in myActionInfos) {
     22        aIdx++;
     23        if (aInfo.Disabled) continue;
     24        var theta = Rand.BetaRand(random, aInfo.NumSuccess + alpha, aInfo.NumFailure + beta);
    3225        if (theta > maxTheta) {
    3326          maxTheta = theta;
    34           bestAction = a;
     27          bestAction = aIdx;
    3528        }
    3629      }
     30      Debug.Assert(bestAction > -1);
    3731      return bestAction;
    3832    }
    3933
    40     public override void UpdateReward(int action, double reward) {
    41       Debug.Assert(Actions.Contains(action));
    42 
    43       if (reward > 0) success[action]++;
    44       else failure[action]++;
     34    public IPolicyActionInfo CreateActionInfo() {
     35      return new BernoulliPolicyActionInfo();
    4536    }
    4637
    47     public override void DisableAction(int action) {
    48       base.DisableAction(action);
    49       success[action] = -1;
    50     }
    51 
    52     public override void Reset() {
    53       base.Reset();
    54       Array.Clear(success, 0, success.Length);
    55       Array.Clear(failure, 0, failure.Length);
    56     }
    57 
    58     public override void PrintStats() {
    59       for (int i = 0; i < success.Length; i++) {
    60         if (success[i] >= 0) {
    61           Console.Write("{0,5:F2}", success[i] / failure[i]);
    62         } else {
    63           Console.Write("{0,5}", "");
    64         }
    65       }
    66       Console.WriteLine();
    67     }
    6838
    6939    public override string ToString() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BoltzmannExplorationPolicy.cs

    r11730 r11732  
    55using System.Text;
    66using System.Threading.Tasks;
     7using HeuristicLab.Common;
    78
    89namespace HeuristicLab.Algorithms.Bandits {
    910  // also called softmax policy
    10   public class BoltzmannExplorationPolicy : BanditPolicy {
    11     private readonly Random random;
    12     private readonly double eps;
    13     private readonly int[] tries;
    14     private readonly double[] sumReward;
     11  public class BoltzmannExplorationPolicy : IPolicy {
    1512    private readonly double beta;
    1613
    17     public BoltzmannExplorationPolicy(Random random, int numActions, double beta)
    18       : base(numActions) {
     14    public BoltzmannExplorationPolicy(double beta) {
    1915      if (beta < 0) throw new ArgumentException();
    20       this.random = random;
    2116      this.beta = beta;
    22       this.tries = new int[numActions];
    23       this.sumReward = new double[numActions];
     17    }
     18    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     19      Debug.Assert(actionInfos.Any());
     20
     21      // select best
     22      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
     23      Debug.Assert(myActionInfos.Any(a => !a.Disabled));
     24      double[] w = new double[myActionInfos.Length];
     25
     26      for (int a = 0; a < myActionInfos.Length; a++) {
     27        if (myActionInfos[a].Disabled) {
     28          w[a] = 0; continue;
     29        }
     30        if (myActionInfos[a].Tries == 0) return a;
     31        var sumReward = myActionInfos[a].SumReward;
     32        var tries = myActionInfos[a].Tries;
     33        var avgReward = sumReward / tries;
     34        w[a] = Math.Exp(beta * avgReward);
     35      }
     36
     37
     38      var bestAction = Enumerable.Range(0, w.Length).SampleProportional(random, w).First();
     39      Debug.Assert(bestAction >= 0);
     40      Debug.Assert(bestAction < w.Length);
     41      Debug.Assert(!myActionInfos[bestAction].Disabled);
     42      return bestAction;
    2443    }
    2544
    26     public override int SelectAction() {
    27       Debug.Assert(Actions.Any());
    28       // select best
    29       var maxReward = double.NegativeInfinity;
    30       int bestAction = -1;
    31       if (Actions.Any(a => tries[a] == 0))
    32         return Actions.First(a => tries[a] == 0);
    33 
    34       var ts = Actions.Select(a => Math.Exp(beta * sumReward[a] / tries[a]));
    35       var r = random.NextDouble() * ts.Sum();
    36 
    37       var agg = 0.0;
    38       foreach (var p in Actions.Zip(ts, Tuple.Create)) {
    39         agg += p.Item2;
    40         if (agg >= r) return p.Item1;
    41       }
    42       throw new InvalidProgramException();
    43     }
    44     public override void UpdateReward(int action, double reward) {
    45       Debug.Assert(Actions.Contains(action));
    46 
    47       tries[action]++;
    48       sumReward[action] += reward;
    49     }
    50 
    51     public override void DisableAction(int action) {
    52       base.DisableAction(action);
    53       sumReward[action] = 0;
    54       tries[action] = -1;
    55     }
    56 
    57     public override void Reset() {
    58       base.Reset();
    59       Array.Clear(tries, 0, tries.Length);
    60       Array.Clear(sumReward, 0, sumReward.Length);
    61     }
    62 
    63     public override void PrintStats() {
    64       for (int i = 0; i < sumReward.Length; i++) {
    65         if (tries[i] >= 0) {
    66           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    67         } else {
    68           Console.Write("{0,5}", "");
    69         }
    70       }
    71       Console.WriteLine();
     45    public IPolicyActionInfo CreateActionInfo() {
     46      return new DefaultPolicyActionInfo();
    7247    }
    7348
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/ChernoffIntervalEstimationPolicy.cs

    r11730 r11732  
    1010International Conference, CP 2006, Nantes, France, September 25-29, 2006. pp 560-574 */
    1111
    12   public class ChernoffIntervalEstimationPolicy : BanditPolicy {
    13     private readonly int[] tries;
    14     private readonly double[] sumReward;
    15     private int totalTries = 0;
     12  public class ChernoffIntervalEstimationPolicy : IPolicy {
    1613    private readonly double delta;
    1714
    18     public ChernoffIntervalEstimationPolicy(int numActions, double delta = 0.01)
    19       : base(numActions) {
     15    public ChernoffIntervalEstimationPolicy(double delta = 0.01) {
    2016      this.delta = delta;
    21       this.tries = new int[numActions];
    22       this.sumReward = new double[numActions];
    2317    }
    24 
    25     public override int SelectAction() {
     18    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     19      Debug.Assert(actionInfos.Any());
     20      // select best
     21      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
     22      int k = myActionInfos.Length;
     23      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
    2624      int bestAction = -1;
    2725      double bestQ = double.NegativeInfinity;
    28       double k = Actions.Count();
    29       Debug.Assert(k > 0);
    30       foreach (var a in Actions) {
    31         if (tries[a] == 0) return a;
     26      for (int a = 0; a < myActionInfos.Length; a++) {
     27        if (myActionInfos[a].Disabled) continue;
     28        if (myActionInfos[a].Tries == 0) return a;
     29
     30        var sumReward = myActionInfos[a].SumReward;
     31        var tries = myActionInfos[a].Tries;
     32
     33        var avgReward = sumReward / tries;
     34
    3235        // page 5 of "A simple distribution-free approach to the max k-armed bandit problem"
    3336        // var alpha = Math.Log(2 * totalTries * k / delta);
    3437        double alpha = Math.Log(2) + Math.Log(totalTries) + Math.Log(k) - Math.Log(delta); // total tries is max tries in the original paper
    35         double mu = sumReward[a] / tries[a];
    36         var q = mu + (alpha + Math.Sqrt(2 * tries[a] * mu * alpha + alpha * alpha)) / tries[a];
     38        var q = avgReward + (alpha + Math.Sqrt(2 * tries * avgReward * alpha + alpha * alpha)) / tries;
    3739        if (q > bestQ) {
    3840          bestQ = q;
     
    4042        }
    4143      }
     44      Debug.Assert(bestAction >= 0);
    4245      return bestAction;
    4346    }
    44     public override void UpdateReward(int action, double reward) {
    45       Debug.Assert(Actions.Contains(action));
    46       totalTries++;
    47       tries[action]++;
    48       sumReward[action] += reward;
     47
     48    public IPolicyActionInfo CreateActionInfo() {
     49      return new DefaultPolicyActionInfo();
    4950    }
    5051
    51     public override void DisableAction(int action) {
    52       base.DisableAction(action);
    53       totalTries -= tries[action];
    54       tries[action] = -1;
    55       sumReward[action] = 0;
    56     }
    57 
    58     public override void Reset() {
    59       base.Reset();
    60       totalTries = 0;
    61       Array.Clear(tries, 0, tries.Length);
    62       Array.Clear(sumReward, 0, sumReward.Length);
    63     }
    64 
    65     public override void PrintStats() {
    66       for (int i = 0; i < sumReward.Length; i++) {
    67         if (tries[i] >= 0) {
    68           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    69         } else {
    70           Console.Write("{0,5}", "");
    71         }
    72       }
    73       Console.WriteLine();
    74     }
    7552    public override string ToString() {
    7653      return string.Format("ChernoffIntervalEstimationPolicy({0:F2})", delta);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class EpsGreedyPolicy : BanditPolicy {
    10     private readonly Random random;
     9  public class EpsGreedyPolicy : IPolicy {
    1110    private readonly double eps;
    12     private readonly int[] tries;
    13     private readonly double[] sumReward;
    1411    private readonly RandomPolicy randomPolicy;
    1512
    16     public EpsGreedyPolicy(Random random, int numActions, double eps)
    17       : base(numActions) {
    18       this.random = random;
     13    public EpsGreedyPolicy(double eps) {
    1914      this.eps = eps;
    20       this.randomPolicy = new RandomPolicy(random, numActions);
    21       this.tries = new int[numActions];
    22       this.sumReward = new double[numActions];
     15      this.randomPolicy = new RandomPolicy();
    2316    }
    24 
    25     public override int SelectAction() {
    26       Debug.Assert(Actions.Any());
     17    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     18      Debug.Assert(actionInfos.Any());
    2719      if (random.NextDouble() > eps) {
    2820        // select best
    29         var bestQ = double.NegativeInfinity;
     21        var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    3022        int bestAction = -1;
    31         foreach (var a in Actions) {
    32           if (tries[a] == 0) return a;
    33           var q = sumReward[a] / tries[a];
    34           if (bestQ < q) {
     23        double bestQ = double.NegativeInfinity;
     24        int aIdx = -1;
     25        foreach (var aInfo in myActionInfos) {
     26
     27          aIdx++;
     28          if (aInfo.Disabled) continue;
     29          if (aInfo.Tries == 0) return aIdx;
     30
     31
     32          var avgReward = aInfo.SumReward / aInfo.Tries;         
     33          //var q = avgReward;
     34          var q = aInfo.MaxReward;
     35          if (q > bestQ) {
    3536            bestQ = q;
    36             bestAction = a;
     37            bestAction = aIdx;
    3738          }
    3839        }
     
    4142      } else {
    4243        // select random
    43         return randomPolicy.SelectAction();
     44        return randomPolicy.SelectAction(random, actionInfos);
    4445      }
    4546    }
    46     public override void UpdateReward(int action, double reward) {
    47       Debug.Assert(Actions.Contains(action));
    4847
    49       randomPolicy.UpdateReward(action, reward); // does nothing
    50       tries[action]++;
    51       sumReward[action] += reward;
     48    public IPolicyActionInfo CreateActionInfo() {
     49      return new DefaultPolicyActionInfo();
    5250    }
    5351
    54     public override void DisableAction(int action) {
    55       base.DisableAction(action);
    56       randomPolicy.DisableAction(action);
    57       sumReward[action] = 0;
    58       tries[action] = -1;
    59     }
    6052
    61     public override void Reset() {
    62       base.Reset();
    63       randomPolicy.Reset();
    64       Array.Clear(tries, 0, tries.Length);
    65       Array.Clear(sumReward, 0, sumReward.Length);
    66     }
    67     public override void PrintStats() {
    68       for (int i = 0; i < sumReward.Length; i++) {
    69         if (tries[i] >= 0) {
    70           Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);
    71         } else {
    72           Console.Write("-", "");
    73         }
    74       }
    75       Console.WriteLine();
    76     }
    7753    public override string ToString() {
    7854      return string.Format("EpsGreedyPolicy({0:F2})", eps);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs

    r11730 r11732  
    11using System;
     2using System.Collections.Generic;
    23using System.Diagnostics;
    34using System.Linq;
     
    56
    67namespace HeuristicLab.Algorithms.Bandits {
    7  
    8   public class GaussianThompsonSamplingPolicy : BanditPolicy {
    9     private readonly Random random;
    10     private readonly double[] sampleMean;
    11     private readonly double[] sampleM2;
    12     private readonly int[] tries;
     8
     9  public class GaussianThompsonSamplingPolicy : IPolicy {
    1310    private bool compatibility;
    1411
     
    2118
    2219
    23     public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
    24       : base(numActions) {
    25       this.random = random;
    26       this.sampleMean = new double[numActions];
    27       this.sampleM2 = new double[numActions];
    28       this.tries = new int[numActions];
     20    public GaussianThompsonSamplingPolicy(bool compatibility = false) {
    2921      this.compatibility = compatibility;
    3022    }
    3123
     24    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     25      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>();
     26      int bestAction = -1;
     27      double bestQ = double.NegativeInfinity;
    3228
    33     public override int SelectAction() {
    34       Debug.Assert(Actions.Any());
    35       var maxTheta = double.NegativeInfinity;
    36       int bestAction = -1;
    37       foreach (var a in Actions) {
    38         if(tries[a] == -1) continue; // skip disabled actions
     29      int aIdx = -1;
     30      foreach (var aInfo in myActionInfos) {
     31        aIdx++;
     32        if (aInfo.Disabled) continue;
     33
     34        var tries = aInfo.Tries;
     35        var sampleMean = aInfo.AvgReward;
     36        var sampleVariance = aInfo.RewardVariance;
     37
    3938        double theta;
    4039        if (compatibility) {
    41           if (tries[a] < 2) return a;
    42           var mu = sampleMean[a];
    43           var variance = sampleM2[a] / tries[a];
     40          if (tries < 2) return aIdx;
     41          var mu = sampleMean;
     42          var variance = sampleVariance;
    4443          var stdDev = Math.Sqrt(variance);
    4544          theta = Rand.RandNormal(random) * stdDev + mu;
     
    4847
    4948          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
    50           var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
    51           var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);
     49          var posteriorVariance = 1.0 / (tries / rewardVariance + 1.0 / priorVariance);
     50          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries * sampleMean / rewardVariance);
    5251
    5352          // sample a mean from the posterior
     
    5655          // theta already represents the expected reward value => nothing else to do
    5756        }
    58         if (theta > maxTheta) {
    59           maxTheta = theta;
    60           bestAction = a;
     57
     58        if (theta > bestQ) {
     59          bestQ = theta;
     60          bestAction = aIdx;
    6161        }
    6262      }
    63       Debug.Assert(Actions.Contains(bestAction));
     63      Debug.Assert(bestAction > -1);
    6464      return bestAction;
    6565    }
    6666
    67     public override void UpdateReward(int action, double reward) {
    68       Debug.Assert(Actions.Contains(action));
    69       tries[action]++;
    70       var delta = reward - sampleMean[action];
    71       sampleMean[action] += delta / tries[action];
    72       sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]);
     67    public IPolicyActionInfo CreateActionInfo() {
     68      return new MeanAndVariancePolicyActionInfo();
    7369    }
    7470
    75     public override void DisableAction(int action) {
    76       base.DisableAction(action);
    77       sampleMean[action] = 0;
    78       sampleM2[action] = 0;
    79       tries[action] = -1;
    80     }
    8171
    82     public override void Reset() {
    83       base.Reset();
    84       Array.Clear(sampleMean, 0, sampleMean.Length);
    85       Array.Clear(sampleM2, 0, sampleM2.Length);
    86       Array.Clear(tries, 0, tries.Length);
    87     }
     72    //public override void UpdateReward(int action, double reward) {
     73    //  Debug.Assert(Actions.Contains(action));
     74    //  tries[action]++;
     75    //  var delta = reward - sampleMean[action];
     76    //  sampleMean[action] += delta / tries[action];
     77    //  sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]);
     78    //}
    8879
    89     public override void PrintStats() {
    90       for (int i = 0; i < sampleMean.Length; i++) {
    91         if (tries[i] >= 0) {
    92           Console.Write(" {0,5:F2} {1}", sampleMean[i] / tries[i], tries[i]);
    93         } else {
    94           Console.Write("{0,5}", "");
    95         }
    96       }
    97       Console.WriteLine();
    98     }
    9980    public override string ToString() {
    10081      return "GaussianThompsonSamplingPolicy";
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GenericThompsonSamplingPolicy.cs

    r11730 r11732  
    88
    99namespace HeuristicLab.Algorithms.Bandits {
    10   public class GenericThompsonSamplingPolicy : BanditPolicy {
    11     private readonly Random random;
     10  public class GenericThompsonSamplingPolicy : IPolicy {
    1211    private readonly IModel model;
    1312
    14     public GenericThompsonSamplingPolicy(Random random, int numActions, IModel model)
    15       : base(numActions) {
    16       this.random = random;
     13    public GenericThompsonSamplingPolicy(IModel model) {
    1714      this.model = model;
    1815    }
    1916
    20     public override int SelectAction() {
    21       Debug.Assert(Actions.Any());
    22       var maxR = double.NegativeInfinity;
     17    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     18      var myActionInfos = actionInfos.OfType<ModelPolicyActionInfo>();
    2319      int bestAction = -1;
    24       var expRewards = model.SampleExpectedRewards(random);
    25       foreach (var a in Actions) {
    26         var r = expRewards[a];
    27         if (r > maxR) {
    28           maxR = r;
    29           bestAction = a;
     20      double bestQ = double.NegativeInfinity;
     21      var aIdx = -1;
     22      foreach (var aInfo in myActionInfos) {
     23        aIdx++;
     24        if (aInfo.Disabled) continue;
     25        //if (aInfo.Tries == 0) return aIdx;
     26        var q = aInfo.SampleExpectedReward(random);
     27        if (q > bestQ) {
     28          bestQ = q;
     29          bestAction = aIdx;
    3030        }
    3131      }
     32      Debug.Assert(bestAction > -1);
    3233      return bestAction;
    3334    }
    3435
    35     public override void UpdateReward(int action, double reward) {
    36       Debug.Assert(Actions.Contains(action));
    37 
    38       model.Update(action, reward);
    39     }
    40 
    41     public override void DisableAction(int action) {
    42       base.DisableAction(action);
    43       model.Disable(action);
    44     }
    45 
    46     public override void Reset() {
    47       base.Reset();
    48       model.Reset();
    49     }
    50 
    51     public override void PrintStats() {
    52       model.PrintStats();
     36    public IPolicyActionInfo CreateActionInfo() {
     37      return new ModelPolicyActionInfo((IModel)model.Clone());
    5338    }
    5439
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs

    r11730 r11732  
    88
    99namespace HeuristicLab.Algorithms.Bandits {
    10   public class RandomPolicy : BanditPolicy {
    11     private readonly Random random;
     10  public class RandomPolicy : IPolicy {
    1211
    13     public RandomPolicy(Random random, int numActions)
    14       : base(numActions) {
    15       this.random = random;
    16     }
    17 
    18     public override int SelectAction() {
    19       Debug.Assert(Actions.Any());
    20       return Actions.SelectRandom(random);
    21     }
    22     public override void UpdateReward(int action, double reward) {
    23       // do nothing
    24     }
    25     public override void PrintStats() {
    26       Console.WriteLine("Random");
    27     }
    2812    public override string ToString() {
    2913      return "RandomPolicy";
    3014    }
     15
     16    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     17      return actionInfos
     18        .Select((a, i) => Tuple.Create(a, i))
     19        .Where(p => !p.Item1.Disabled)
     20        .SelectRandom(random).Item2;
     21    }
     22
     23    public IPolicyActionInfo CreateActionInfo() {
     24      return new EmptyPolicyActionInfo();
     25    }
    3126  }
    3227}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class UCB1Policy : BanditPolicy {
    10     private readonly int[] tries;
    11     private readonly double[] sumReward;
    12     private int totalTries = 0;
    13     public UCB1Policy(int numActions)
    14       : base(numActions) {
    15       this.tries = new int[numActions];
    16       this.sumReward = new double[numActions];
    17     }
    18 
    19     public override int SelectAction() {
     9  public class UCB1Policy : IPolicy {
     10    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     11      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
    2012      int bestAction = -1;
    2113      double bestQ = double.NegativeInfinity;
    22       foreach (var a in Actions) {
    23         if (tries[a] == 0) return a;
    24         var q = sumReward[a] / tries[a] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[a]);
     14      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     15
     16      for (int a = 0; a < myActionInfos.Length; a++) {
     17        if (myActionInfos[a].Disabled) continue;
     18        if (myActionInfos[a].Tries == 0) return a;
     19        var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries);
    2520        if (q > bestQ) {
    2621          bestQ = q;
     
    2823        }
    2924      }
     25      Debug.Assert(bestAction > -1);
    3026      return bestAction;
    3127    }
    32     public override void UpdateReward(int action, double reward) {
    33       Debug.Assert(Actions.Contains(action));
    34       totalTries++;
    35       tries[action]++;
    36       sumReward[action] += reward;
    37     }
    3828
    39     public override void DisableAction(int action) {
    40       base.DisableAction(action);
    41       totalTries -= tries[action];
    42       tries[action] = -1;
    43       sumReward[action] = 0;
    44     }
    45 
    46     public override void Reset() {
    47       base.Reset();
    48       totalTries = 0;
    49       Array.Clear(tries, 0, tries.Length);
    50       Array.Clear(sumReward, 0, sumReward.Length);
    51     }
    52     public override void PrintStats() {
    53       for (int i = 0; i < sumReward.Length; i++) {
    54         if (tries[i] >= 0) {
    55           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    56         } else {
    57           Console.Write("{0,5}", "");
    58         }
    59       }
    60       Console.WriteLine();
     29    public IPolicyActionInfo CreateActionInfo() {
     30      return new DefaultPolicyActionInfo();
    6131    }
    6232    public override string ToString() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class UCB1TunedPolicy : BanditPolicy {
    10     private readonly int[] tries;
    11     private readonly double[] sumReward;
    12     private readonly double[] sumSqrReward;
    13     private int totalTries = 0;
    14     public UCB1TunedPolicy(int numActions)
    15       : base(numActions) {
    16       this.tries = new int[numActions];
    17       this.sumReward = new double[numActions];
    18       this.sumSqrReward = new double[numActions];
    19     }
     9  public class UCB1TunedPolicy : IPolicy {
    2010
    21     private double V(int arm) {
    22       var s = tries[arm];
    23       return sumSqrReward[arm] / s - Math.Pow(sumReward[arm] / s, 2) + Math.Sqrt(2 * Math.Log(totalTries) / s);
    24     }
    25 
    26 
    27     public override int SelectAction() {
    28       Debug.Assert(Actions.Any());
     11    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     12      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance
    2913      int bestAction = -1;
    3014      double bestQ = double.NegativeInfinity;
    31       foreach (var a in Actions) {
    32         if (tries[a] == 0) return a;
    33         var q = sumReward[a] / tries[a] + Math.Sqrt((Math.Log(totalTries) / tries[a]) * Math.Min(1.0 / 4, V(a))); // 1/4 is upper bound of bernoulli distributed variable
     15      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     16
     17      for (int a = 0; a < myActionInfos.Length; a++) {
     18        if (myActionInfos[a].Disabled) continue;
     19        if (myActionInfos[a].Tries == 0) return a;
     20
     21        var sumReward = myActionInfos[a].SumReward;
     22        var tries = myActionInfos[a].Tries;
     23
     24        var avgReward = sumReward / tries;
     25        var q = avgReward + Math.Sqrt((Math.Log(totalTries) / tries) * Math.Min(1.0 / 4, V(myActionInfos[a], totalTries))); // 1/4 is upper bound of bernoulli distributed variable
    3426        if (q > bestQ) {
    3527          bestQ = q;
     
    3729        }
    3830      }
     31      Debug.Assert(bestAction > -1);
    3932      return bestAction;
    4033    }
    41     public override void UpdateReward(int action, double reward) {
    42       Debug.Assert(Actions.Contains(action));
    43       totalTries++;
    44       tries[action]++;
    45       sumReward[action] += reward;
    46       sumSqrReward[action] += reward * reward;
     34
     35    public IPolicyActionInfo CreateActionInfo() {
     36      return new MeanAndVariancePolicyActionInfo();
    4737    }
    4838
    49     public override void DisableAction(int action) {
    50       base.DisableAction(action);
    51       totalTries -= tries[action];
    52       tries[action] = -1;
    53       sumReward[action] = 0;
    54       sumSqrReward[action] = 0;
     39    private double V(MeanAndVariancePolicyActionInfo actionInfo, int totalTries) {
     40      var s = actionInfo.Tries;
     41      return actionInfo.RewardVariance + Math.Sqrt(2 * Math.Log(totalTries) / s);
    5542    }
    5643
    57     public override void Reset() {
    58       base.Reset();
    59       totalTries = 0;
    60       Array.Clear(tries, 0, tries.Length);
    61       Array.Clear(sumReward, 0, sumReward.Length);
    62       Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    63     }
    64     public override void PrintStats() {
    65       for (int i = 0; i < sumReward.Length; i++) {
    66         if (tries[i] >= 0) {
    67           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    68         } else {
    69           Console.Write("{0,5}", "");
    70         }
    71       }
    72       Console.WriteLine();
    73     }
    7444    public override string ToString() {
    7545      return "UCB1TunedPolicy";
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class UCBNormalPolicy : BanditPolicy {
    10     private readonly int[] tries;
    11     private readonly double[] sumReward;
    12     private readonly double[] sumSqrReward;
    13     private int totalTries = 0;
    14     public UCBNormalPolicy(int numActions)
    15       : base(numActions) {
    16       this.tries = new int[numActions];
    17       this.sumReward = new double[numActions];
    18       this.sumSqrReward = new double[numActions];
    19     }
     9  public class UCBNormalPolicy : IPolicy {
    2010
    21     public override int SelectAction() {
    22       Debug.Assert(Actions.Any());
     11    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     12      var myActionInfos = actionInfos.OfType<MeanAndVariancePolicyActionInfo>().ToArray(); // TODO: performance
    2313      int bestAction = -1;
    2414      double bestQ = double.NegativeInfinity;
    25       foreach (var a in Actions) {
    26         if (totalTries <= 1 || tries[a] <= 1 || tries[a] <= Math.Ceiling(8 * Math.Log(totalTries))) return a;
    27         var avgReward = sumReward[a] / tries[a];
    28         var estVariance = 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]);
     15      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     16
     17      for (int a = 0; a < myActionInfos.Length; a++) {
     18        if (myActionInfos[a].Disabled) continue;
     19        if (totalTries <= 1 || myActionInfos[a].Tries <= 1 || myActionInfos[a].Tries <= Math.Ceiling(8 * Math.Log(totalTries))) return a;
     20
     21        var tries = myActionInfos[a].Tries;
     22        var avgReward = myActionInfos[a].AvgReward;
     23        var rewardVariance = myActionInfos[a].RewardVariance;
     24        var estVariance = 16 * rewardVariance * (Math.Log(totalTries - 1) / tries);
    2925        if (estVariance < 0) estVariance = 0; // numerical problems
    3026        var q = avgReward
     
    3531        }
    3632      }
    37       Debug.Assert(Actions.Contains(bestAction));
     33      Debug.Assert(bestAction > -1);
    3834      return bestAction;
    3935    }
    40     public override void UpdateReward(int action, double reward) {
    41       Debug.Assert(Actions.Contains(action));
    42       totalTries++;
    43       tries[action]++;
    44       sumReward[action] += reward;
    45       sumSqrReward[action] += reward * reward;
     36
     37    public IPolicyActionInfo CreateActionInfo() {
     38      return new MeanAndVariancePolicyActionInfo();
    4639    }
    4740
    48     public override void DisableAction(int action) {
    49       base.DisableAction(action);
    50       totalTries -= tries[action];
    51       tries[action] = -1;
    52       sumReward[action] = 0;
    53       sumSqrReward[action] = 0;
    54     }
    55 
    56     public override void Reset() {
    57       base.Reset();
    58       totalTries = 0;
    59       Array.Clear(tries, 0, tries.Length);
    60       Array.Clear(sumReward, 0, sumReward.Length);
    61       Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    62     }
    63     public override void PrintStats() {
    64       for (int i = 0; i < sumReward.Length; i++) {
    65         if (tries[i] >= 0) {
    66           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    67         } else {
    68           Console.Write("{0,5}", "");
    69         }
    70       }
    71       Console.WriteLine();
    72     }
    7341    public override string ToString() {
    7442      return "UCBNormalPolicy";
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCTPolicy.cs

    r11730 r11732  
    88namespace HeuristicLab.Algorithms.Bandits {
    99  /* Kocsis et al. Bandit based Monte-Carlo Planning */
    10   public class UCTPolicy : BanditPolicy {
    11     private readonly int[] tries;
    12     private readonly double[] sumReward;
    13     private int totalTries = 0;
     10  public class UCTPolicy : IPolicy {
    1411    private readonly double c;
    1512
    16     public UCTPolicy(int numActions, double c = 1.0)
    17       : base(numActions) {
    18       this.tries = new int[numActions];
    19       this.sumReward = new double[numActions];
     13    public UCTPolicy(double c = 1.0) {
    2014      this.c = c;
    2115    }
    2216
    23     public override int SelectAction() {
     17
     18    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     19      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
    2420      int bestAction = -1;
    2521      double bestQ = double.NegativeInfinity;
    26       foreach (var a in Actions) {
    27         if (tries[a] == 0) return a;
    28         var q = sumReward[a] / tries[a] + 2 * c * Math.Sqrt(Math.Log(totalTries) / tries[a]);
     22      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     23
     24      for (int a = 0; a < myActionInfos.Length; a++) {
     25        if (myActionInfos[a].Disabled) continue;
     26        if (myActionInfos[a].Tries == 0) return a;
     27        var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + 2 * c * Math.Sqrt(Math.Log(totalTries) / myActionInfos[a].Tries);
    2928        if (q > bestQ) {
    3029          bestQ = q;
     
    3231        }
    3332      }
     33      Debug.Assert(bestAction > -1);
    3434      return bestAction;
    3535    }
    36     public override void UpdateReward(int action, double reward) {
    37       Debug.Assert(Actions.Contains(action));
    38       totalTries++;
    39       tries[action]++;
    40       sumReward[action] += reward;
     36
     37    public IPolicyActionInfo CreateActionInfo() {
     38      return new DefaultPolicyActionInfo();
    4139    }
    4240
    43     public override void DisableAction(int action) {
    44       base.DisableAction(action);
    45       totalTries -= tries[action];
    46       tries[action] = -1;
    47       sumReward[action] = 0;
    48     }
    49 
    50     public override void Reset() {
    51       base.Reset();
    52       totalTries = 0;
    53       Array.Clear(tries, 0, tries.Length);
    54       Array.Clear(sumReward, 0, sumReward.Length);
    55     }
    56     public override void PrintStats() {
    57       for (int i = 0; i < sumReward.Length; i++) {
    58         if (tries[i] >= 0) {
    59           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    60         } else {
    61           Console.Write("{0,5}", "");
    62         }
    63       }
    64       Console.WriteLine();
    65     }
    6641    public override string ToString() {
    6742      return string.Format("UCTPolicy({0:F2})", c);
Note: See TracChangeset for help on using the changeset viewer.