Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 9 years ago

#2283: several major extensions for grammatical optimization

File size: 2.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7
namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that picks the action with the largest optimistic estimate
  /// "average reward + sqrt(16 * sample variance * ln(totalTries - 1) / tries)".
  /// Actions that have been tried too rarely are selected first (see <see cref="SelectAction"/>).
  /// NOTE(review): the formula looks like UCB1-NORMAL from Auer et al. (2002) - confirm against the paper.
  /// </summary>
  public class UCBNormalPolicy : BanditPolicy {
    // Number of rewards recorded per action; -1 marks an action that has been disabled.
    private readonly int[] tries;
    // Sum of the rewards recorded per action.
    private readonly double[] sumReward;
    // Sum of the squared rewards recorded per action (needed for the variance estimate).
    private readonly double[] sumSqrReward;
    // Number of rewards recorded over all actions that have not been disabled.
    private int totalTries = 0;

    /// <summary>
    /// Creates a policy with zeroed statistics for <paramref name="numActions"/> actions.
    /// </summary>
    /// <param name="numActions">Number of actions; also the length of the per-action statistics arrays.</param>
    public UCBNormalPolicy(int numActions)
      : base(numActions) {
      this.tries = new int[numActions];
      this.sumReward = new double[numActions];
      this.sumSqrReward = new double[numActions];
    }

    /// <summary>
    /// Selects the next action to try.
    /// The first action (in the order of <c>Actions</c>) that has at most one try, or at most
    /// ceil(8 * ln(totalTries)) tries, is returned immediately. Otherwise the action with the
    /// largest optimistic estimate is returned (the first one wins on ties).
    /// </summary>
    /// <returns>
    /// The index of the selected action, or -1 if <c>Actions</c> yields no element or no
    /// estimate compares greater than negative infinity (only guarded by debug assertions).
    /// </returns>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());

      // Both values are loop-invariant. For totalTries <= 1 they are -Infinity/NaN, but they
      // are never read in that case because the totalTries check below short-circuits first.
      double minTries = Math.Ceiling(8 * Math.Log(totalTries));
      double logTotalTries = Math.Log(totalTries - 1);

      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      foreach (var a in Actions) {
        int n = tries[a];
        // Forced exploration; n >= 2 is also required for the (n - 1) divisor below.
        if (totalTries <= 1 || n <= 1 || n <= minTries) return a;

        double avgReward = sumReward[a] / n;
        // Sample variance computed from the running sums.
        double sampleVariance = (sumSqrReward[a] - n * (avgReward * avgReward)) / (n - 1);
        double estVariance = 16 * sampleVariance * (logTotalTries / n);
        if (estVariance < 0) estVariance = 0; // numerical problems

        double q = avgReward + Math.Sqrt(estVariance);
        if (q > bestQ) {
          bestQ = q;
          bestAction = a;
        }
      }
      Debug.Assert(Actions.Contains(bestAction));
      return bestAction;
    }

    /// <summary>
    /// Records one observed reward for the given action.
    /// </summary>
    /// <param name="action">Index of the action that produced the reward; asserted (debug builds only) to be contained in <c>Actions</c>.</param>
    /// <param name="reward">The observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      totalTries++;
      tries[action]++;
      sumReward[action] += reward;
      sumSqrReward[action] += reward * reward;
    }

    /// <summary>
    /// Disables the action in the base policy, removes its tries from the total
    /// and marks its statistics as disabled (tries = -1, sums = 0).
    /// </summary>
    /// <param name="action">Index of the action to disable.</param>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      // Only subtract real counts: an already disabled action holds the -1 marker,
      // and subtracting that would wrongly increase totalTries on a repeated call.
      if (tries[action] > 0) totalTries -= tries[action];
      tries[action] = -1;
      sumReward[action] = 0;
      sumSqrReward[action] = 0;
    }

    /// <summary>
    /// Resets the base policy and clears all statistics, including the -1 markers of disabled actions.
    /// </summary>
    public override void Reset() {
      base.Reset();
      totalTries = 0;
      Array.Clear(tries, 0, tries.Length);
      Array.Clear(sumReward, 0, sumReward.Length);
      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    }

    /// <summary>
    /// Writes the average reward of every action to the console on a single line,
    /// using five characters per action. Disabled actions (tries = -1) are printed as blanks.
    /// </summary>
    public override void PrintStats() {
      for (int i = 0; i < sumReward.Length; i++) {
        if (tries[i] >= 0) {
          // An action without any tries yields 0.0 / 0, i.e. NaN is printed for it.
          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
        } else {
          Console.Write("{0,5}", "");
        }
      }
      Console.WriteLine();
    }

    /// <summary>
    /// Returns the fixed display name of this policy.
    /// </summary>
    public override string ToString() {
      return "UCBNormalPolicy";
    }
  }
}
Note: See TracBrowser for help on using the repository browser.