using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Epsilon-greedy bandit policy. With probability roughly <c>1 - eps</c> it exploits:
  /// any enabled action that has never been tried is selected first, otherwise the action
  /// with the highest average reward is selected. Otherwise it explores by delegating to a
  /// <see cref="RandomPolicy"/>.
  /// </summary>
  public class EpsGreedyPolicy : BanditPolicy {
    private readonly Random random;
    private readonly double eps;
    // Number of rewards received per action; set to -1 when the action is disabled.
    private readonly int[] tries;
    // Sum of all rewards received per action.
    private readonly double[] sumReward;
    // Policy used for the exploration step; shares the same Random instance.
    private readonly RandomPolicy randomPolicy;

    /// <summary>Creates an epsilon-greedy policy over <paramref name="numActions"/> actions.</summary>
    /// <param name="random">Random number generator for the explore/exploit decision; also passed to the exploration policy.</param>
    /// <param name="numActions">Number of actions; sizes the per-action statistics.</param>
    /// <param name="eps">
    /// Exploration probability. Because <see cref="Random.NextDouble"/> returns values in [0, 1),
    /// <c>eps &lt; 0</c> never explores and <c>eps &gt;= 1</c> always explores.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="random"/> is null.</exception>
    public EpsGreedyPolicy(Random random, int numActions, double eps)
      : base(numActions) {
      if (random == null) throw new ArgumentNullException("random");
      this.random = random;
      this.eps = eps;
      this.randomPolicy = new RandomPolicy(random, numActions);
      this.tries = new int[numActions];
      this.sumReward = new double[numActions];
    }

    /// <summary>
    /// Selects the next action. Exploits when <c>NextDouble() &gt; eps</c>, otherwise explores
    /// via the random policy.
    /// </summary>
    /// <returns>The index of the selected action.</returns>
    public override int SelectAction() {
      // NOTE(review): Actions comes from BanditPolicy and is presumably the set of currently
      // enabled actions — confirm in the base class.
      Debug.Assert(Actions.Any());
      if (random.NextDouble() > eps) {
        // exploit: pick the action with the best average reward
        var bestQ = double.NegativeInfinity;
        int bestAction = -1;
        foreach (var a in Actions) {
          // an action that has never been tried is selected immediately
          if (tries[a] == 0) return a;
          var q = sumReward[a] / tries[a];
          // Always accept the first candidate so that a valid action is returned even when
          // every mean is -Infinity (or NaN); on ties the earlier action is kept.
          if (bestAction < 0 || q > bestQ) {
            bestQ = q;
            bestAction = a;
          }
        }
        Debug.Assert(bestAction >= 0);
        return bestAction;
      } else {
        // explore: select a random action
        return randomPolicy.SelectAction();
      }
    }

    /// <summary>Records <paramref name="reward"/> for <paramref name="action"/>.</summary>
    public override void UpdateReward(int action, double reward) {
      Debug.Assert(Actions.Contains(action));

      randomPolicy.UpdateReward(action, reward); // does nothing
      tries[action]++;
      sumReward[action] += reward;
    }

    /// <summary>
    /// Disables <paramref name="action"/> in this policy and in the exploration policy, and
    /// marks its statistics as disabled (<c>tries == -1</c>, reward sum cleared).
    /// </summary>
    public override void DisableAction(int action) {
      base.DisableAction(action);
      randomPolicy.DisableAction(action);
      sumReward[action] = 0;
      tries[action] = -1;
    }

    /// <summary>Resets the base policy, the exploration policy and all per-action statistics to zero.</summary>
    public override void Reset() {
      base.Reset();
      randomPolicy.Reset();
      Array.Clear(tries, 0, tries.Length);
      Array.Clear(sumReward, 0, sumReward.Length);
    }

    /// <summary>
    /// Writes one entry per action to the console: average reward and try count for enabled
    /// actions, "-" for disabled ones. An enabled but untried action prints NaN as its average.
    /// </summary>
    public override void PrintStats() {
      for (int i = 0; i < sumReward.Length; i++) {
        if (tries[i] >= 0) {
          Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);
        } else {
          Console.Write("-");
        }
      }
      Console.WriteLine();
    }

    public override string ToString() {
      return string.Format("EpsGreedyPolicy({0:F2})", eps);
    }
  }
}