using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
  /// <summary>
  /// Exp3 bandit policy. An action a from <c>Actions</c> is drawn with probability
  /// p(a) = (1 - gamma) * w[a] / sum(w) + gamma / K, where K is the number of actions in
  /// <c>Actions</c>. After a reward is observed, the chosen action's weight is multiplied by
  /// exp(gamma * (reward / p(a)) / K).
  /// </summary>
  public class Exp3Policy : BanditPolicy {
    // Upper bound for ln(w[a]). When an update would push a weight above exp(LogMaxWeight), all
    // weights are divided by a common factor. Selection probabilities depend only on the ratios
    // w[a] / sum(w), so this leaves the policy unchanged. It keeps the weights from overflowing to
    // +Infinity, which would turn every probability into NaN.
    private const double LogMaxWeight = 300.0;

    private readonly Random random;
    private readonly double gamma;
    // Exp3 weight per action index; DisableAction sets the weight of a disabled action to 0.
    private readonly double[] w;

    /// <summary>Creates an Exp3 policy with all weights initialised to 1.</summary>
    /// <param name="random">Random number generator used by <see cref="SelectAction"/>.</param>
    /// <param name="numActions">Number of actions (length of the weight vector).</param>
    /// <param name="gamma">Exploration rate; must lie in [0, 1].</param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="gamma"/> is outside [0, 1] or NaN.</exception>
    public Exp3Policy(Random random, int numActions, double gamma)
      : base(numActions) {
      if (random == null) throw new ArgumentNullException("random");
      // Negated range check so that NaN is rejected too (NaN fails every comparison).
      if (!(gamma >= 0 && gamma <= 1))
        throw new ArgumentOutOfRangeException("gamma", gamma, "gamma must be in [0, 1].");
      this.random = random;
      this.gamma = gamma;
      this.w = Enumerable.Repeat(1.0, numActions).ToArray();
    }

    /// <summary>
    /// Samples one of the actions in <c>Actions</c> according to the Exp3 mixture of the weight
    /// distribution and the uniform distribution.
    /// </summary>
    /// <returns>The selected action index.</returns>
    /// <exception cref="InvalidOperationException"><c>Actions</c> is empty.</exception>
    public override int SelectAction() {
      Debug.Assert(Actions.Any());
      // Walk only the actions in Actions, so an index of w outside it (e.g. a disabled action)
      // never receives the exploration mass gamma / K and can never be returned.
      var actions = Actions.ToArray();
      if (actions.Length == 0) throw new InvalidOperationException("No action available to select.");
      var numActions = actions.Length;
      // Disabled actions have weight 0 (see DisableAction), so they add nothing to sumW.
      var sumW = w.Sum();
      var r = random.NextDouble();
      var sumP = 0.0;
      foreach (var a in actions) {
        sumP += (1 - gamma) * w[a] / sumW + gamma / numActions;
        if (r <= sumP) return a;
      }
      // Only reachable if rounding leaves the cumulative probability slightly below r.
      return actions[actions.Length - 1];
    }

    /// <summary>
    /// Exp3 weight update for <paramref name="action"/>. The reward is divided by the action's
    /// current selection probability (importance weighting), and the weight is multiplied by
    /// exp(gamma * estimatedReward / K).
    /// </summary>
    /// <param name="action">The action that was played; expected to be in <c>Actions</c>.</param>
    /// <param name="reward">The observed reward.</param>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));
      var numActions = Actions.Count();
      var p = (1 - gamma) * w[action] / w.Sum() + gamma / numActions;
      var estReward = reward / p;
      // Compute the new weight in log space; the multiplicative updates grow without bound.
      var logW = Math.Log(w[action]) + gamma * estReward / numActions;
      if (logW <= LogMaxWeight) {
        w[action] = Math.Exp(logW);
      } else {
        // Divide every weight by exp(logW). Every other weight is at most exp(LogMaxWeight) < exp(logW),
        // so afterwards all weights are <= 1 while their ratios (and hence probabilities) are unchanged.
        // Math.Log(0) is -Infinity, so zero weights stay exactly 0.
        for (int i = 0; i < w.Length; i++) {
          if (i != action) w[i] = Math.Exp(Math.Log(w[i]) - logW);
        }
        w[action] = 1.0;
      }
    }

    /// <summary>Disables the action and sets its weight to 0 so it no longer contributes to sum(w).</summary>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      w[action] = 0;
    }

    /// <summary>Resets the base policy and sets the weight of every action in <c>Actions</c> back to 1.</summary>
    public override void Reset() {
      base.Reset();
      // NOTE(review): indices not in Actions after base.Reset() keep their current weight;
      // presumably base.Reset() re-enables all actions. Confirm in BanditPolicy.
      foreach (var a in Actions) w[a] = 1.0;
    }

    /// <summary>
    /// Writes all weights to the console on one line; non-positive weights are printed as blanks.
    /// </summary>
    public override void PrintStats() {
      for (int i = 0; i < w.Length; i++) {
        if (w[i] > 0) {
          Console.Write("{0,5:F2}", w[i]);
        } else {
          Console.Write("{0,5}", "");
        }
      }
      Console.WriteLine();
    }

    public override string ToString() {
      return "Exp3Policy";
    }
  }
}