source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 2.4 KB
RevLine 
[11710]1using System;
2using System.Collections.Generic;
[11727]3using System.Diagnostics;
[11710]4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7
8namespace HeuristicLab.Algorithms.Bandits {
9  public class UCB1TunedPolicy : BanditPolicy {
10    private readonly int[] tries;
11    private readonly double[] sumReward;
12    private readonly double[] sumSqrReward;
13    private int totalTries = 0;
14    public UCB1TunedPolicy(int numActions)
15      : base(numActions) {
[11727]16      this.tries = new int[numActions];
17      this.sumReward = new double[numActions];
18      this.sumSqrReward = new double[numActions];
[11710]19    }
20
21    private double V(int arm) {
22      var s = tries[arm];
23      return sumSqrReward[arm] / s - Math.Pow(sumReward[arm] / s, 2) + Math.Sqrt(2 * Math.Log(totalTries) / s);
24    }
25
26
27    public override int SelectAction() {
[11727]28      Debug.Assert(Actions.Any());
[11710]29      int bestAction = -1;
30      double bestQ = double.NegativeInfinity;
[11727]31      foreach (var a in Actions) {
32        if (tries[a] == 0) return a;
33        var q = sumReward[a] / tries[a] + Math.Sqrt((Math.Log(totalTries) / tries[a]) * Math.Min(1.0 / 4, V(a))); // 1/4 is upper bound of bernoulli distributed variable
[11710]34        if (q > bestQ) {
35          bestQ = q;
[11727]36          bestAction = a;
[11710]37        }
38      }
39      return bestAction;
40    }
41    public override void UpdateReward(int action, double reward) {
[11727]42      Debug.Assert(Actions.Contains(action));
[11710]43      totalTries++;
44      tries[action]++;
45      sumReward[action] += reward;
46      sumSqrReward[action] += reward * reward;
47    }
[11727]48
49    public override void DisableAction(int action) {
50      base.DisableAction(action);
51      totalTries -= tries[action];
52      tries[action] = -1;
53      sumReward[action] = 0;
54      sumSqrReward[action] = 0;
55    }
56
[11710]57    public override void Reset() {
[11727]58      base.Reset();
[11710]59      totalTries = 0;
60      Array.Clear(tries, 0, tries.Length);
61      Array.Clear(sumReward, 0, sumReward.Length);
62      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
63    }
[11730]64    public override void PrintStats() {
65      for (int i = 0; i < sumReward.Length; i++) {
66        if (tries[i] >= 0) {
67          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
68        } else {
69          Console.Write("{0,5}", "");
70        }
71      }
72      Console.WriteLine();
73    }
74    public override string ToString() {
75      return "UCB1TunedPolicy";
76    }
[11710]77  }
78}
Note: See TracBrowser for help on using the repository browser.