# source:branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs@11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 3.6 KB
Line
using System;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;

namespace HeuristicLab.Algorithms.Bandits {

  /// <summary>
  /// Bandit policy based on Thompson sampling with a Gaussian reward model.
  /// In the default mode each action's mean reward gets a Gaussian prior and the
  /// reward variance is assumed to be known and identical for all actions; an action
  /// is chosen by sampling a mean from every posterior and taking the largest sample.
  /// In compatibility mode the sample is instead drawn from a Gaussian parameterized
  /// by the empirical mean and empirical variance of the rewards seen so far.
  /// </summary>
  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    // running mean of the rewards observed for each action
    private readonly double[] sampleMean;
    // running sum of squared deviations from the mean for each action (Welford's M2)
    private readonly double[] sampleM2;
    // number of rewards observed for each action; -1 marks a disabled action
    private readonly int[] tries;
    // only assigned in the constructor; selects the empirical-variance sampling branch
    private readonly bool compatibility;

    // assumes a Gaussian reward distribution with different means but the same variances for each action
    // the prior for the mean is also Gaussian with the following parameters
    private readonly double rewardVariance = 0.1; // we assume a known variance

    private readonly double priorMean = 0.5;
    private readonly double priorVariance = 1;


    /// <summary>
    /// Creates a policy for <paramref name="numActions"/> actions.
    /// </summary>
    /// <param name="random">Random number generator used for drawing the samples.</param>
    /// <param name="numActions">Number of actions; passed on to the base class and used to size the statistics arrays.</param>
    /// <param name="compatibility">
    /// When true, samples are drawn using the empirical mean and variance of each action
    /// instead of the conjugate Gaussian posterior.
    /// </param>
    public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
      : base(numActions) {
      this.random = random;
      this.sampleMean = new double[numActions];
      this.sampleM2 = new double[numActions];
      this.tries = new int[numActions];
      this.compatibility = compatibility;
    }


    /// <summary>
    /// Draws one sample per available action and returns the action with the largest sample.
    /// In compatibility mode an action with fewer than two observations is returned immediately.
    /// </summary>
    /// <returns>The index of the selected action.</returns>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var maxTheta = double.NegativeInfinity;
      int bestAction = -1;
      foreach (var a in Actions) {
        if (tries[a] == -1) continue; // skip disabled actions
        double theta;
        if (compatibility) {
          // the empirical variance is not meaningful before two observations => try the action first
          if (tries[a] < 2) return a;
          var mu = sampleMean[a];
          var variance = sampleM2[a] / tries[a];
          var stdDev = Math.Sqrt(variance);
          // Rand.RandNormal is presumably a standard normal variate (defined outside this file)
          theta = Rand.RandNormal(random) * stdDev + mu;
        } else {
          // calculate posterior mean and variance (for mean reward)

          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
          var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);

          // sample a mean from the posterior
          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;

          // theta already represents the expected reward value => nothing else to do
        }
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = a;
        }
      }
      Debug.Assert(Actions.Contains(bestAction));
      return bestAction;
    }

    /// <summary>
    /// Incorporates an observed reward into the running statistics of an action.
    /// </summary>
    /// <param name="action">Index of the action that produced the reward.</param>
    /// <param name="reward">The observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      tries[action]++;
      var delta = reward - sampleMean[action];
      sampleMean[action] += delta / tries[action];
      // Welford's online update: M2 grows by delta * (reward - updated mean).
      // The previous M2 must not be added a second time, otherwise the variance estimate is inflated on every update.
      sampleM2[action] += delta * (reward - sampleMean[action]);
    }

    /// <summary>
    /// Disables an action in the base class, clears its statistics and marks it with a try count of -1.
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      sampleMean[action] = 0;
      sampleM2[action] = 0;
      tries[action] = -1;
    }

    /// <summary>
    /// Resets the base class and zeroes all statistics (this also removes the -1 markers of disabled actions).
    /// </summary>
    public override void Reset() {
      base.Reset();
      Array.Clear(sampleMean, 0, sampleMean.Length);
      Array.Clear(sampleM2, 0, sampleM2.Length);
      Array.Clear(tries, 0, tries.Length);
    }

    /// <summary>
    /// Writes the mean reward and the number of tries of every action to the console on one line;
    /// disabled actions are printed as blank padding.
    /// </summary>
    public override void PrintStats() {
      for (int i = 0; i < sampleMean.Length; i++) {
        if (tries[i] >= 0) {
          // sampleMean already holds the running mean, so it must not be divided by the number of tries again
          Console.Write(" {0,5:F2} {1}", sampleMean[i], tries[i]);
        } else {
          Console.Write("{0,5}", "");
        }
      }
      Console.WriteLine();
    }

    /// <summary>Returns the name of this policy.</summary>
    public override string ToString() {
      return "GaussianThompsonSamplingPolicy";
    }
  }
}
Note: See TracBrowser for help on using the repository browser.