using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace HeuristicLab.Algorithms.Bandits {
  /// <summary>
  /// Bandit policy that follows the UCB1-Normal selection rule (Auer, Cesa-Bianchi and Fischer, 2002).
  /// The rule has two parts:
  /// <list type="bullet">
  /// <item>Any arm pulled fewer than ceil(8 ln n) times is played first (n = total pulls so far).</item>
  /// <item>Otherwise the arm maximizing mean + sqrt(16 * sampleVariance * ln(n - 1) / n_arm) is played.</item>
  /// </list>
  /// </summary>
  public class UCBNormalPolicy : BanditPolicy {
    // Number of pulls per arm.
    private readonly int[] tries;
    // Sum of observed rewards per arm.
    private readonly double[] sumReward;
    // Sum of squared observed rewards per arm (used for the sample variance).
    private readonly double[] sumSqrReward;
    // Total number of pulls over all arms since construction or the last Reset().
    private int totalTries = 0;

    /// <summary>Creates a policy with empty statistics for <paramref name="numActions"/> arms.</summary>
    /// <param name="numActions">Number of arms; forwarded to the base class, which exposes it as NumActions.</param>
    public UCBNormalPolicy(int numActions)
      : base(numActions) {
      this.tries = new int[NumActions];
      this.sumReward = new double[NumActions];
      this.sumSqrReward = new double[NumActions];
    }

    /// <summary>
    /// Chooses the next arm to pull.
    /// </summary>
    /// <returns>
    /// If some arm is under-explored, the index of the first arm (in index order) with fewer than
    /// ceil(8 ln n) pulls, where n is the total number of pulls. Otherwise the index of the arm with
    /// the largest upper confidence bound. Returns -1 only if there are no arms, or if every bound is
    /// NaN (e.g. NaN rewards were recorded).
    /// </returns>
    public override int SelectAction() {
      // Terms that depend only on the total number of pulls; computed once instead of per arm.
      // For totalTries == 0 these evaluate to -inf / NaN, but the loop then returns before using them.
      var minTries = Math.Ceiling(8 * Math.Log(totalTries));
      var logTerm = Math.Log(totalTries - 1);

      int bestAction = -1;
      double bestQ = double.NegativeInfinity;
      for (int i = 0; i < NumActions; i++) {
        // Forced exploration: untried arms and arms below the ceil(8 ln n) threshold are played immediately.
        if (totalTries == 0 || tries[i] == 0 || tries[i] < minTries) return i;

        double q;
        if (tries[i] < 2) {
          // The sample variance needs at least two observations; without this guard the bound is NaN
          // (division by tries - 1 == 0, and ln(0) when totalTries == 1), so a single-arm bandit would get -1.
          // Treat the arm as maximally uncertain; an untried arm later in the loop is still returned first.
          q = double.PositiveInfinity;
        } else {
          var avgReward = sumReward[i] / tries[i];
          // Unbiased sample variance. Clamp at 0 because floating-point cancellation can make it slightly
          // negative, which would make Math.Sqrt return NaN and silently exclude the arm.
          var variance = Math.Max(0.0, (sumSqrReward[i] - tries[i] * (avgReward * avgReward)) / (tries[i] - 1));
          q = avgReward + Math.Sqrt(16 * variance * (logTerm / tries[i]));
        }

        if (q > bestQ) {
          bestQ = q;
          bestAction = i;
        }
      }
      return bestAction;
    }

    /// <summary>Records one pull of <paramref name="action"/> that yielded <paramref name="reward"/>.</summary>
    /// <param name="action">Index of the arm that was pulled.</param>
    /// <param name="reward">Observed reward for that pull.</param>
    public override void UpdateReward(int action, double reward) {
      totalTries++;
      tries[action]++;
      sumReward[action] += reward;
      sumSqrReward[action] += reward * reward;
    }

    /// <summary>Discards all collected statistics so the policy starts from scratch.</summary>
    public override void Reset() {
      totalTries = 0;
      Array.Clear(tries, 0, tries.Length);
      Array.Clear(sumReward, 0, sumReward.Length);
      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    }
  }
}