using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Abstract base class for grammar policies (base class for RandomPolicy and TDPolicy).
  /// For every state it has seen, it stores the number of tries, the average reward and
  /// the maximum reward.
  /// </summary>
  /// <remarks>
  /// When <c>useCanonicalState</c> is true, every statistic is keyed by
  /// <c>problem.CanonicalRepresentation(state)</c> instead of the raw state string.
  /// </remarks>
  public abstract class GrammarPolicy : IGrammarPolicy {
    // Per-state statistics. Keys are state strings, canonicalized via CanonicalState.
    protected Dictionary<string, double> avgReward;
    protected Dictionary<string, int> tries;
    protected Dictionary<string, double> maxReward;
    protected readonly bool useCanonicalState;
    protected readonly IProblem problem;

    /// <param name="problem">Problem used to canonicalize states; it is only dereferenced when
    /// <paramref name="useCanonicalState"/> is true.</param>
    /// <param name="useCanonicalState">Whether statistics are keyed by the canonical representation of a state.</param>
    protected GrammarPolicy(IProblem problem, bool useCanonicalState = false) {
      this.useCanonicalState = useCanonicalState;
      this.problem = problem;
      this.tries = new Dictionary<string, int>();
      this.avgReward = new Dictionary<string, double>();
      this.maxReward = new Dictionary<string, double>();
    }

    /// <summary>Chooses one of <paramref name="afterStates"/>; implemented by concrete policies.</summary>
    public abstract bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx);

    /// <summary>
    /// Adds <paramref name="reward"/> to the statistics of every state in <paramref name="stateTrajectory"/>.
    /// A state that occurs several times in the trajectory is updated once per occurrence.
    /// </summary>
    /// <param name="stateTrajectory">States whose statistics receive the reward.</param>
    /// <param name="reward">Observed reward.</param>
    public virtual void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      foreach (var state in stateTrajectory) {
        var s = CanonicalState(state);

        // An unseen state starts with 0 tries.
        int n;
        tries.TryGetValue(s, out n);
        n++;
        tries[s] = n;

        // Incremental arithmetic mean: avg_n = avg_{n-1} + (r - avg_{n-1}) / n (an unseen state starts at 0).
        double avg;
        avgReward.TryGetValue(s, out avg);
        double alpha = 1.0 / n;
        avgReward[s] = avg + alpha * (reward - avg);

        // Seed the maximum with the first observed reward instead of 0, so that negative
        // rewards are reported correctly.
        double max;
        maxReward[s] = maxReward.TryGetValue(s, out max) ? Math.Max(max, reward) : reward;
      }
    }

    /// <summary>Discards all collected statistics.</summary>
    public virtual void Reset() {
      avgReward.Clear();
      maxReward.Clear();
      tries.Clear();
    }

    /// <summary>Average reward observed for <paramref name="state"/>, or 0.0 if it was never updated.</summary>
    public double AvgReward(string state) {
      double avg;
      return avgReward.TryGetValue(CanonicalState(state), out avg) ? avg : 0.0;
    }

    /// <summary>Maximum reward observed for <paramref name="state"/>, or 0.0 if it was never updated.</summary>
    public double MaxReward(string state) {
      double max;
      return maxReward.TryGetValue(CanonicalState(state), out max) ? max : 0.0;
    }

    /// <summary>Number of updates recorded for <paramref name="state"/>, or 0 if it was never updated.</summary>
    public virtual int GetTries(string state) {
      int n;
      return tries.TryGetValue(CanonicalState(state), out n) ? n : 0;
    }

    /// <summary>Value estimate of <paramref name="state"/>; by default its average reward.</summary>
    public virtual double GetValue(string state) {
      return AvgReward(state);
    }

    /// <summary>
    /// Maps <paramref name="state"/> to the key used in the statistics dictionaries.
    /// The key is the problem's canonical representation when <c>useCanonicalState</c> is set,
    /// and the state itself otherwise.
    /// </summary>
    protected string CanonicalState(string state) {
      if (useCanonicalState) return problem.CanonicalRepresentation(state);
      else return state;
    }
  }
}