Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs @ 11728

Last change on this file since 11728 was 11727, checked in by gkronber, 8 years ago

#2283: worked on grammatical optimization problem solvers (simple MCTS done)

File size: 1.8 KB
Line 
1using System;
2using System.Diagnostics;
3using System.Linq;
4using HeuristicLab.Common;
5
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that, on each selection, draws one value per active action from
  /// <c>Rand.RandNormal</c> scaled by that action's empirical reward standard deviation and
  /// shifted by its empirical reward mean, then picks the action with the largest draw.
  /// Actions that have not been tried yet are returned first, in the enumeration order of <c>Actions</c>.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the draw is scaled by the standard deviation of the observed rewards, not by the
  /// standard deviation of the estimated mean (which would shrink roughly with 1/sqrt(tries)).
  /// Textbook Gaussian Thompson sampling samples from the posterior of the mean — confirm the wider
  /// distribution used here is intentional.
  /// </remarks>
  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    // Per-action running sums of rewards and of squared rewards; mean and variance are derived from them.
    private readonly double[] sumRewards;
    private readonly double[] sumSqrRewards;
    // Per-action number of rewards received; DisableAction sets it to -1.
    private readonly int[] tries;

    /// <summary>
    /// Creates the policy for <paramref name="numActions"/> actions with all statistics zeroed.
    /// </summary>
    /// <param name="random">Random number generator passed to <c>Rand.RandNormal</c> for every draw.</param>
    /// <param name="numActions">Number of actions; forwarded to the base class and used to size the statistics arrays.</param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public GaussianThompsonSamplingPolicy(Random random, int numActions)
      : base(numActions) {
      // Fail fast here instead of inside a later SelectAction call.
      if (random == null) throw new ArgumentNullException("random");
      this.random = random;
      this.sumRewards = new double[numActions];
      this.sumSqrRewards = new double[numActions];
      this.tries = new int[numActions];
    }

    /// <summary>
    /// Selects the next action to play.
    /// </summary>
    /// <returns>
    /// The first action in <c>Actions</c> with zero tries, if there is one; otherwise the action with the
    /// largest sampled value. Returns -1 if no sample exceeds negative infinity (e.g. <c>Actions</c> is empty).
    /// </returns>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var maxTheta = double.NegativeInfinity;
      int bestAction = -1;
      foreach (var a in Actions) {
        // Play every action once before sampling so that its mean is defined.
        if (tries[a] == 0) return a;
        var mu = sumRewards[a] / tries[a];
        // E[x^2] - E[x]^2 can come out slightly negative through floating-point cancellation
        // (e.g. when all rewards of an action are equal). Clamp to 0 so Math.Sqrt does not yield NaN:
        // a NaN theta never compares greater than maxTheta, so the action could never be chosen and
        // bestAction could stay -1.
        var variance = Math.Max(0.0, sumSqrRewards[a] / tries[a] - mu * mu);
        var stdDev = Math.Sqrt(variance);
        var theta = Rand.RandNormal(random) * stdDev + mu;
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = a;
        }
      }
      return bestAction;
    }

    /// <summary>
    /// Records a reward for <paramref name="action"/> by adding it to the running sums and incrementing its try count.
    /// </summary>
    /// <param name="action">The action the reward belongs to; asserted (debug builds only) to be in <c>Actions</c>.</param>
    /// <param name="reward">The observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));

      sumRewards[action] += reward;
      sumSqrRewards[action] += reward * reward;
      tries[action]++;
    }

    /// <summary>
    /// Disables <paramref name="action"/> via the base class and discards its statistics.
    /// </summary>
    /// <param name="action">The action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      sumRewards[action] = 0;
      sumSqrRewards[action] = 0;
      // -1 distinguishes a disabled action from an enabled but untried one (0).
      tries[action] = -1;
    }

    /// <summary>
    /// Resets the base class and zeroes all per-action statistics, including those of disabled actions.
    /// </summary>
    public override void Reset() {
      base.Reset();
      Array.Clear(sumRewards, 0, sumRewards.Length);
      Array.Clear(sumSqrRewards, 0, sumSqrRewards.Length);
      Array.Clear(tries, 0, tries.Length);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.