source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/UCBNormalPolicy.cs @ 11710

Last change on this file since 11710 was 11710, checked in by gkronber, 8 years ago

#2283: more bandit policies and tests

File size: 1.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that tracks, for every action, the number of tries, the sum of rewards and
  /// the sum of squared rewards, and selects actions by an upper-confidence index made of the
  /// sample mean plus a bonus that grows with the sample variance of the observed rewards.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the index has the shape of UCB1-NORMAL (Auer, Cesa-Bianchi and Fischer, 2002);
  /// confirm against the paper before relying on its regret guarantees.
  /// </remarks>
  public class UCBNormalPolicy : BanditPolicy {
    // Per-action statistics; all three arrays are indexed by action.
    private readonly int[] tries;
    private readonly double[] sumReward;
    private readonly double[] sumSqrReward;
    // Number of rewards recorded over all actions since construction or the last Reset().
    private int totalTries = 0;

    /// <summary>
    /// Creates a policy with empty statistics.
    /// </summary>
    /// <param name="numActions">Number of actions; forwarded to the base class.</param>
    public UCBNormalPolicy(int numActions)
      : base(numActions) {
      this.tries = new int[NumActions];
      this.sumReward = new double[NumActions];
      this.sumSqrReward = new double[NumActions];
    }

    /// <summary>
    /// Selects the next action to try.
    /// </summary>
    /// <returns>
    /// The first action (in index order) that has been tried fewer than ceil(8 * ln(totalTries))
    /// times or not at all; otherwise the action with the largest upper-confidence index.
    /// Returns -1 only if no action yields a comparable index (for example when there are no
    /// actions or the recorded rewards are NaN).
    /// </returns>
    public override int SelectAction() {
      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      for (int i = 0; i < NumActions; i++) {
        // Forced exploration: an under-sampled action is returned immediately.
        if (totalTries == 0 || tries[i] == 0 || tries[i] < Math.Ceiling(8 * Math.Log(totalTries))) return i;
        var q = UpperConfidenceIndex(i);
        if (q > bestQ) {
          bestQ = q;
          bestAction = i;
        }
      }
      return bestAction;
    }

    /// <summary>
    /// Computes the upper-confidence index of an action that has been tried at least once.
    /// </summary>
    /// <param name="arm">Index of the action.</param>
    /// <returns>
    /// Positive infinity when fewer than two rewards were recorded for the action; otherwise the
    /// average reward plus sqrt(16 * sampleVariance * ln(totalTries - 1) / tries).
    /// </returns>
    private double UpperConfidenceIndex(int arm) {
      int n = tries[arm];
      // With a single sample the variance estimate divides by (n - 1) == 0, and this point is
      // only reachable while totalTries == 1, where ln(totalTries - 1) is -infinity as well.
      // Treat the action as maximally uncertain instead of producing NaN.
      if (n < 2) return double.PositiveInfinity;

      var avgReward = sumReward[arm] / n;
      var variance = (sumSqrReward[arm] - n * Math.Pow(avgReward, 2)) / (n - 1);
      // Floating-point cancellation can make the estimate slightly negative for (nearly)
      // constant rewards; Math.Sqrt would then return NaN and the action could never win.
      if (variance < 0) variance = 0;

      // n >= 2 implies totalTries >= 2, so the logarithm below is non-negative.
      return avgReward + Math.Sqrt(16 * variance * (Math.Log(totalTries - 1) / n));
    }

    /// <summary>
    /// Records a reward observed for an action.
    /// </summary>
    /// <param name="action">Index of the action that was tried.</param>
    /// <param name="reward">Observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      totalTries++;
      tries[action]++;
      sumReward[action] += reward;
      sumSqrReward[action] += reward * reward;
    }

    /// <summary>
    /// Clears all recorded statistics so that the policy behaves like a newly constructed one.
    /// </summary>
    public override void Reset() {
      totalTries = 0;
      Array.Clear(tries, 0, tries.Length);
      Array.Clear(sumReward, 0, sumReward.Length);
      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.