source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BoltzmannExplorationPolicy.cs @ 11730

Last change on this file since 11730 was 11730, checked in by gkronber, 5 years ago

#2283: several major extensions for grammatical optimization

File size: 2.3 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7
8namespace HeuristicLab.Algorithms.Bandits {
9  // also called softmax policy
10  public class BoltzmannExplorationPolicy : BanditPolicy {
11    private readonly Random random;
12    private readonly double eps;
13    private readonly int[] tries;
14    private readonly double[] sumReward;
15    private readonly double beta;
16
17    public BoltzmannExplorationPolicy(Random random, int numActions, double beta)
18      : base(numActions) {
19      if (beta < 0) throw new ArgumentException();
20      this.random = random;
21      this.beta = beta;
22      this.tries = new int[numActions];
23      this.sumReward = new double[numActions];
24    }
25
26    public override int SelectAction() {
27      Debug.Assert(Actions.Any());
28      // select best
29      var maxReward = double.NegativeInfinity;
30      int bestAction = -1;
31      if (Actions.Any(a => tries[a] == 0))
32        return Actions.First(a => tries[a] == 0);
33
34      var ts = Actions.Select(a => Math.Exp(beta * sumReward[a] / tries[a]));
35      var r = random.NextDouble() * ts.Sum();
36
37      var agg = 0.0;
38      foreach (var p in Actions.Zip(ts, Tuple.Create)) {
39        agg += p.Item2;
40        if (agg >= r) return p.Item1;
41      }
42      throw new InvalidProgramException();
43    }
44    public override void UpdateReward(int action, double reward) {
45      Debug.Assert(Actions.Contains(action));
46
47      tries[action]++;
48      sumReward[action] += reward;
49    }
50
51    public override void DisableAction(int action) {
52      base.DisableAction(action);
53      sumReward[action] = 0;
54      tries[action] = -1;
55    }
56
57    public override void Reset() {
58      base.Reset();
59      Array.Clear(tries, 0, tries.Length);
60      Array.Clear(sumReward, 0, sumReward.Length);
61    }
62
63    public override void PrintStats() {
64      for (int i = 0; i < sumReward.Length; i++) {
65        if (tries[i] >= 0) {
66          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
67        } else {
68          Console.Write("{0,5}", "");
69        }
70      }
71      Console.WriteLine();
72    }
73
74    public override string ToString() {
75      return string.Format("BoltzmannExplorationPolicy({0:F2})", beta);
76    }
77  }
78}
Note: See TracBrowser for help on using the repository browser.