Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs @ 11728

Last change on this file since 11728 was 11727, checked in by gkronber, 9 years ago

#2283: worked on grammatical optimization problem solvers (simple MCTS done)

File size: 1.6 KB
RevLine 
[11727]1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that selects actions by Thompson sampling with a beta prior.
  /// For every available action a value is drawn via
  /// <c>Rand.BetaRand(random, success + alpha, failure + beta)</c> and the action
  /// with the largest draw is chosen.
  /// </summary>
  public class BernoulliThompsonSamplingPolicy : BanditPolicy {
    private readonly Random random;
    // per-action count of rewards > 0 (see UpdateReward)
    private readonly int[] success;
    // per-action count of rewards <= 0 (see UpdateReward)
    private readonly int[] failure;

    // parameters of beta prior distribution
    // (alpha = beta = 1 corresponds to a uniform prior on [0, 1])
    private readonly double alpha = 1.0;
    private readonly double beta = 1.0;

    /// <summary>
    /// Creates a policy for <paramref name="numActions"/> actions with all
    /// success/failure counters initialized to zero.
    /// </summary>
    /// <param name="random">Random number generator passed to <c>Rand.BetaRand</c>; must not be null.</param>
    /// <param name="numActions">Number of actions; forwarded to the base class and used as the counter array length.</param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public BernoulliThompsonSamplingPolicy(Random random, int numActions)
      : base(numActions) {
      // fail fast instead of deferring a NullReferenceException to the first SelectAction call
      if (random == null) throw new ArgumentNullException("random");
      this.random = random;
      this.success = new int[numActions];
      this.failure = new int[numActions];
    }

    /// <summary>
    /// Draws one sample per action in <c>Actions</c> and returns the action
    /// with the largest sample.
    /// </summary>
    /// <returns>
    /// The selected action, or -1 if <c>Actions</c> is empty (only reachable
    /// when the debug assertion is compiled out).
    /// </returns>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      var maxTheta = double.NegativeInfinity;
      int bestAction = -1;
      foreach (var a in Actions) {
        // sample from the posterior: prior parameters plus observed counts
        var theta = Rand.BetaRand(random, success[a] + alpha, failure[a] + beta);
        if (theta > maxTheta) {
          maxTheta = theta;
          bestAction = a;
        }
      }
      return bestAction;
    }

    /// <summary>
    /// Records the outcome of playing <paramref name="action"/>.
    /// </summary>
    /// <param name="action">The action that was played; asserted (debug only) to be in <c>Actions</c>.</param>
    /// <param name="reward">Any value greater than zero counts as a success, everything else as a failure.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));

      if (reward > 0) success[action]++;
      else failure[action]++;
    }

    /// <summary>
    /// Disables <paramref name="action"/> in the base class and marks its
    /// success counter with -1.
    /// </summary>
    /// <param name="action">The action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      // NOTE(review): -1 looks like a "disabled" sentinel; this relies on the base class
      // no longer yielding the action from Actions, otherwise SelectAction would sample
      // with a first parameter of -1 + alpha = 0 - confirm against BanditPolicy.
      success[action] = -1;
    }

    /// <summary>
    /// Resets the base class state and sets all success and failure counters
    /// back to zero (this also overwrites any -1 written by <see cref="DisableAction"/>).
    /// </summary>
    public override void Reset() {
      base.Reset();
      Array.Clear(success, 0, success.Length);
      Array.Clear(failure, 0, failure.Length);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.