Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs @ 11711

Last change on this file since 11711 was 11711, checked in by gkronber, 9 years ago

#2283: folders for bandits and policies

File size: 1.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// UCB1-Tuned multi-armed bandit policy. Like UCB1, it picks the arm with the highest upper confidence
  /// bound on its mean reward, but the width of each bound is scaled by an upper estimate of that arm's
  /// reward variance, capped at 1/4.
  /// </summary>
  public class UCB1TunedPolicy : BanditPolicy {
    // Cap for the variance term: 1/4 is the largest possible variance of a Bernoulli-distributed
    // (or any [0, 1]-valued) reward. NOTE(review): presumably rewards are expected to lie in [0, 1].
    private const double MaxVariance = 1.0 / 4;

    private readonly int[] tries;           // number of times each arm has been updated
    private readonly double[] sumReward;    // sum of rewards received per arm
    private readonly double[] sumSqrReward; // sum of squared rewards received per arm
    private int totalTries = 0;             // total number of updates over all arms

    /// <summary>Creates the policy with all per-arm statistics set to zero.</summary>
    /// <param name="numActions">Number of arms; forwarded to the base class.</param>
    public UCB1TunedPolicy(int numActions)
      : base(numActions) {
      this.tries = new int[NumActions];
      this.sumReward = new double[NumActions];
      this.sumSqrReward = new double[NumActions];
    }

    /// <summary>
    /// Upper confidence bound on the reward variance of <paramref name="arm"/>:
    /// V = (1/s) * sum(x^2) - mean^2 + sqrt(2 * ln(t) / s), where s is the arm's number of tries
    /// and t the total number of tries.
    /// </summary>
    /// <param name="arm">Index of an arm that has been tried at least once (s must be &gt; 0).</param>
    /// <param name="logTotalTries">ln(t), the natural logarithm of the total number of tries.</param>
    private double V(int arm, double logTotalTries) {
      double s = tries[arm];
      double mean = sumReward[arm] / s;
      // E[x^2] - E[x]^2 suffers from cancellation and can come out slightly negative (noticeably so for
      // rewards of large magnitude). A negative V would make the square root in SelectAction NaN, and a
      // NaN index never compares greater than bestQ, so clamp the empirical variance at zero.
      double variance = Math.Max(0.0, sumSqrReward[arm] / s - mean * mean);
      return variance + Math.Sqrt(2 * logTotalTries / s);
    }

    /// <summary>
    /// Chooses the next arm. Any arm that has never been tried is returned first (lowest index first).
    /// Once every arm has been tried, returns the arm maximizing
    /// mean + sqrt(ln(t) / s * min(1/4, V)); ties go to the lowest index.
    /// </summary>
    /// <returns>
    /// The selected arm index, or -1 if <c>NumActions</c> is zero (or, NOTE(review), if every index is
    /// NaN, which could only happen with non-finite rewards).
    /// </returns>
    public override int SelectAction() {
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      // Loop-invariant. If totalTries is 0 this is -Infinity, but then every arm is untried and the
      // loop returns before the value is used.
      double logTotalTries = Math.Log(totalTries);
      for (int i = 0; i < NumActions; i++) {
        if (tries[i] == 0) return i;
        double mean = sumReward[i] / tries[i];
        double q = mean + Math.Sqrt(logTotalTries / tries[i] * Math.Min(MaxVariance, V(i, logTotalTries)));
        if (q > bestQ) {
          bestQ = q;
          bestAction = i;
        }
      }
      return bestAction;
    }

    /// <summary>Records <paramref name="reward"/> as the outcome of trying arm <paramref name="action"/>.</summary>
    /// <param name="action">Index of the arm that was tried.</param>
    /// <param name="reward">Observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      totalTries++;
      tries[action]++;
      sumReward[action] += reward;
      sumSqrReward[action] += reward * reward;
    }

    /// <summary>Clears all statistics, returning the policy to its freshly-constructed state.</summary>
    public override void Reset() {
      totalTries = 0;
      Array.Clear(tries, 0, tries.Length);
      Array.Clear(sumReward, 0, sumReward.Length);
      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.