source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/Exp3Policy.cs @ 11742

Last change on this file since 11742 was 11742, checked in by gkronber, 6 years ago

#2283 refactoring

File size: 2.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7
namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
  /// <summary>
  /// Bandit policy that keeps one multiplicative weight per action and samples
  /// action a with probability
  /// <c>(1 - gamma) * w[a] / sum(w) + gamma / numActions</c>.
  /// After a reward is observed the weight of the played action is multiplied
  /// by <c>exp(gamma * (reward / p) / numActions)</c>, where p is the
  /// probability with which that action was drawn.
  /// </summary>
  public class Exp3Policy : BanditPolicy {
    // Once a weight grows beyond this bound all weights are divided by the
    // current maximum. Selection probabilities depend only on weight ratios,
    // so rescaling leaves the policy unchanged while preventing overflow to
    // Infinity (and the NaN probabilities that would follow).
    private const double RescaleThreshold = 1e100;

    private readonly Random random;
    private readonly double gamma;
    private readonly double[] w;

    /// <summary>
    /// Creates a policy with all weights initialised to 1.
    /// </summary>
    /// <param name="random">Source of randomness used when sampling an action.</param>
    /// <param name="numActions">Number of actions; forwarded to the base class.</param>
    /// <param name="gamma">Mixing coefficient of the uniform exploration term; must lie in [0, 1].</param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="gamma"/> is NaN or outside [0, 1].</exception>
    public Exp3Policy(Random random, int numActions, double gamma)
      : base(numActions) {
      if (random == null) throw new ArgumentNullException("random");
      // Written in negated form so that NaN is rejected as well.
      if (!(gamma >= 0 && gamma <= 1)) {
        throw new ArgumentOutOfRangeException("gamma", gamma, "gamma must lie in the interval [0, 1].");
      }
      this.random = random;
      this.gamma = gamma;
      this.w = Enumerable.Repeat(1.0, numActions).ToArray();
    }

    /// <summary>
    /// Samples one of the actions currently listed in <c>Actions</c>.
    /// </summary>
    /// <returns>The index of the selected action.</returns>
    /// <exception cref="InvalidOperationException">No action is available.</exception>
    public override int SelectAction() {
      // Materialise once: Actions is enumerated several times below.
      var available = Actions.ToArray();
      if (available.Length == 0) throw new InvalidOperationException("No action available to select.");

      var numActions = available.Length;
      var sumW = w.Sum();
      var r = random.NextDouble();

      // Walk the cumulative distribution over the listed actions only.
      // NOTE(review): this assumes Actions contains exactly the actions that
      // have not been disabled; iterating raw indices of w instead would give
      // a disabled action (weight 0) probability gamma / numActions.
      var sumP = 0.0;
      foreach (var a in available) {
        sumP += (1 - gamma) * w[a] / sumW + gamma / numActions;
        if (r <= sumP) return a;
      }

      // Rounding can leave the cumulative sum marginally below 1; fall back to
      // the last listed action instead of indexing past the end.
      return available[numActions - 1];
    }

    /// <summary>
    /// Updates the weight of <paramref name="action"/> with the
    /// importance-weighted reward estimate <c>reward / p</c>.
    /// </summary>
    /// <param name="action">Index of the action that was played.</param>
    /// <param name="reward">Reward observed for that action.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      var numActions = Actions.Count();
      var p = (1 - gamma) * w[action] / w.Sum() + gamma / numActions;
      var estReward = reward / p;
      w[action] = w[action] * Math.Exp(gamma * estReward / numActions);

      // Weights only ever get multiplied, so without rescaling they grow
      // without bound over long runs.
      if (w[action] > RescaleThreshold) RescaleWeights();
    }

    // Divides every weight by the current maximum so that the largest weight becomes 1.
    private void RescaleWeights() {
      var maxW = w.Max();
      for (int i = 0; i < w.Length; i++) {
        if (double.IsPositiveInfinity(maxW)) {
          // Inf / Inf would be NaN; an infinite weight dominates all finite ones.
          w[i] = double.IsPositiveInfinity(w[i]) ? 1.0 : 0.0;
        } else {
          w[i] = w[i] / maxW;
        }
      }
    }

    /// <summary>
    /// Disables <paramref name="action"/> in the base class and sets its weight to 0.
    /// </summary>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      w[action] = 0;
    }

    /// <summary>
    /// Resets the base class state and sets the weight of every action in
    /// <c>Actions</c> back to 1.
    /// </summary>
    public override void Reset() {
      base.Reset();
      foreach (var a in Actions) w[a] = 1.0;
    }

    /// <summary>
    /// Writes all weights on one console line; entries with a weight of 0 or
    /// less are printed as blanks of the same width.
    /// </summary>
    public override void PrintStats() {
      for (int i = 0; i < w.Length; i++) {
        if (w[i] > 0) {
          Console.Write("{0,5:F2}", w[i]);
        } else {
          Console.Write("{0,5}", "");
        }
      }
      Console.WriteLine();
    }

    /// <summary>Returns the name of this policy.</summary>
    public override string ToString() {
      return "Exp3Policy";
    }
  }
}
Note: See TracBrowser for help on using the repository browser.