using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that estimates the value of a chain as a linear function of the chain's
  /// features (the sum of weight * feature value, with one learned weight per feature id) and
  /// picks the next chain epsilon-greedily among the follow states that are not yet done.
  /// Weights are adjusted with a temporal-difference style update: towards the best follow-state
  /// value after every selection, and towards the observed reward for the terminal chain of a
  /// trajectory.
  /// </summary>
  /// <remarks>Resampling is not prevented.</remarks>
  public sealed class GenericPolicy : IGrammarPolicy {
    // Step size of the weight updates.
    private const double Alpha = 0.01;
    // Discount applied to the best follow-state value in UpdateValue.
    private const double Gamma = 1.0;
    // Exploration rate used by TrySelect.
    private const double Epsilon = 0.01;

    // Expressions whose estimated values are printed by PrintStats.
    private static readonly string[] ProbeChains = {
      "a*b", "c*d", "a*b+c*d", "e*f", "a*b+c*d+e*f", "a*b+a*b", "c*d+c*d", "a*a",
      "a*b", "a*c", "a*d", "a*e", "a*f", "a*g", "a*h", "a*i", "a*j",
      "a*b", "c*d", "e*f", "a*c", "a*f", "a*i", "a*i*g", "c*f", "c*f*j",
      "b+c", "a+c", "b+d", "a+d",
      "a*b+c*d+e*f", "a*b+c*d+e*f+a", "a*b+c*d+e*f+b", "a*b+c*d+e*f+c", "a*b+c*d+e*f+d",
      "a*b+c*d+e*f+e", "a*b+c*d+e*f+f", "a*b+c*d+e*f+g", "a*b+c*d+e*f+h", "a*b+c*d+e*f+i",
      "a*b+c*d+e*f+j", "a*b+c*d+e*f+a*g*i+c*j*f"
    };

    private readonly Dictionary<string, double> Q; // learned weight for each feature id
    private readonly Dictionary<string, int> T; // tries per feature id; never written by this class
    private readonly Dictionary<string, List<string>> followStates; // first-seen follow states per chain; recorded but not read here
    private readonly IProblem problem;
    private readonly HashSet<string> done; // chains that have been fully explored

    /// <summary>Creates a policy for the given problem.</summary>
    /// <param name="problem">Provides the features of chains and the grammar used to detect terminal chains.</param>
    public GenericPolicy(IProblem problem) {
      this.problem = problem;
      this.Q = new Dictionary<string, double>();
      this.T = new Dictionary<string, int>();
      this.followStates = new Dictionary<string, List<string>>();
      this.done = new HashSet<string>();
    }

    private double[] activeAfterStates; // reused buffer: values of the selectable follow states
    private int[] actionIndexMap; // reused buffer: maps a buffer index back to the index in afterStates

    /// <summary>
    /// Selects one of <paramref name="afterStates"/> as the successor of <paramref name="curState"/>
    /// and afterwards updates the feature weights of <paramref name="curState"/>.
    /// </summary>
    /// <param name="random">Source of randomness for the epsilon-greedy choice.</param>
    /// <param name="curState">The current chain.</param>
    /// <param name="afterStates">Candidate follow chains. The sequence is enumerated only once.</param>
    /// <param name="selectedStateIdx">
    /// Index (into <paramref name="afterStates"/>) of the selected chain, or -1 when nothing could be selected.
    /// </param>
    /// <returns>
    /// <c>false</c> if every candidate is already done (this includes an empty candidate sequence);
    /// in that case <paramref name="curState"/> is marked as done as well. Otherwise <c>true</c>.
    /// </returns>
    public bool TrySelect(System.Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // Materialize once: the candidates are needed for the done-check, counting, the value
      // computation and the TD update, and a lazy sequence would otherwise be re-evaluated each time.
      var candidates = afterStates as IList<string> ?? afterStates.ToList();

      if (candidates.All(Done)) {
        // every follow state has already been visited => the current state is fully explored too
        MarkAsDone(curState);
        selectedStateIdx = -1;
        return false;
      }

      var candidateCount = candidates.Count;
      if (activeAfterStates == null || activeAfterStates.Length < candidateCount) {
        activeAfterStates = new double[candidateCount];
        actionIndexMap = new int[candidateCount];
      }
      if (!followStates.ContainsKey(curState)) {
        followStates[curState] = new List<string>(candidates);
      }

      // Collect the values of the follow states that are still selectable.
      var activeCount = 0;
      for (int i = 0; i < candidateCount; i++) {
        if (Done(candidates[i])) {
          continue;
        }
        activeAfterStates[activeCount] = CalculateValue(candidates[i]);
        actionIndexMap[activeCount] = i;
        activeCount++;
      }

      // activeCount >= 1 here because not all candidates are done.
      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(activeCount), Epsilon)];
      UpdateValue(curState, candidates);
      return true;
    }

    /// <summary>
    /// Returns the estimated value of <paramref name="chain"/>: the sum over its features of the
    /// learned weight (0 for unseen feature ids) times the feature value.
    /// </summary>
    private double CalculateValue(string chain) {
      var sum = 0.0;
      foreach (var f in problem.GetFeatures(chain)) {
        sum += GetValue(f.Id) * f.Value;
      }
      return sum;
    }

    /// <summary>
    /// TD-style update of the weights of <paramref name="curChain"/> towards
    /// <c>Gamma * max(value(alternative))</c>.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="alternatives"/> is empty.</exception>
    private void UpdateValue(string curChain, IEnumerable<string> alternatives) {
      var maxNextValue = alternatives.Select(CalculateValue).Max();
      var delta = Gamma * maxNextValue - CalculateValue(curChain);
      AdjustWeights(curChain, delta);
    }

    /// <summary>Update of the weights of <paramref name="terminalChain"/> towards the observed <paramref name="reward"/>.</summary>
    private void UpdateLastValue(string terminalChain, double reward) {
      var delta = reward - CalculateValue(terminalChain);
      AdjustWeights(terminalChain, delta);
    }

    /// <summary>Adds <c>Alpha * delta * featureValue</c> to the weight of every feature of <paramref name="chain"/>.</summary>
    private void AdjustWeights(string chain, double delta) {
      foreach (var f in problem.GetFeatures(chain)) {
        Q[f.Id] = GetValue(f.Id) + Alpha * delta * f.Value;
      }
    }

    /// <summary>
    /// Samples an index with probability proportional to <c>exp(beta * q)</c>.
    /// </summary>
    /// <param name="random">Source of randomness.</param>
    /// <param name="qs">Values of the actions; must not be empty.</param>
    /// <param name="beta">Inverse temperature; larger values make the choice greedier.</param>
    /// <exception cref="ArgumentException"><paramref name="qs"/> is empty.</exception>
    private int SelectBoltzmann(System.Random random, IEnumerable<double> qs, double beta = 10) {
      var values = qs.ToArray();
      if (values.Length == 0) {
        throw new ArgumentException("qs must contain at least one value.", "qs");
      }
      // Shifting every exponent by the maximum leaves the proportions unchanged but keeps the
      // largest weight at exp(0) = 1, so Math.Exp cannot overflow to +Infinity for large beta * q.
      var maxQ = values.Max();
      var shift = (double.IsNaN(maxQ) || double.IsInfinity(maxQ)) ? 0.0 : maxQ;
      var weights = values.Select(q => Math.Exp(beta * (q - shift)));
      var selectedAction = Enumerable.Range(0, values.Length).SampleProportional(random, weights);
      Debug.Assert(selectedAction >= 0);
      return selectedAction;
    }

    /// <summary>
    /// Epsilon-greedy selection: with probability <paramref name="eps"/> a uniformly random index is
    /// returned, otherwise a random index among those whose value is (almost) maximal.
    /// eps == 0 is pure exploitation, eps == 1 is pure exploration.
    /// </summary>
    /// <param name="random">Source of randomness.</param>
    /// <param name="qs">Values of the actions; may be enumerated twice.</param>
    /// <param name="eps">Exploration probability.</param>
    private int SelectEpsGreedy(System.Random random, IEnumerable<double> qs, double eps = 0.2) {
      if (random.NextDouble() < eps) {
        return SelectRandom(random, qs);
      }

      var bestActions = new List<int>();
      var bestQ = double.NegativeInfinity;
      var aIdx = -1;
      foreach (var q in qs) {
        aIdx++;
        if (q > bestQ) {
          bestActions.Clear();
          bestActions.Add(aIdx);
          bestQ = q;
        } else if (HeuristicLab.Common.Extensions.IsAlmost(q, bestQ)) {
          bestActions.Add(aIdx);
        }
      }

      // Empty only when no value compared greater than -Infinity (e.g. all values are NaN after the
      // weights diverged); choose uniformly instead of failing on an empty candidate list.
      if (bestActions.Count == 0) {
        return SelectRandom(random, qs);
      }
      return bestActions.SelectRandom(random);
    }

    /// <summary>Returns a uniformly random index into <paramref name="qs"/>.</summary>
    private int SelectRandom(System.Random random, IEnumerable<double> qs) {
      return qs
        .Select((q, idx) => Tuple.Create(q, idx))
        .SelectRandom(random).Item2;
    }

    /// <summary>
    /// Updates the weights of the last chain of <paramref name="chainTrajectory"/> towards
    /// <paramref name="reward"/> and marks that chain as done if it is terminal.
    /// </summary>
    /// <exception cref="InvalidOperationException"><paramref name="chainTrajectory"/> is empty.</exception>
    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
      // Only the terminal chain is updated here: intermediate chains were already updated in
      // TrySelect after each step.
      var terminalChain = chainTrajectory.Last();
      UpdateLastValue(terminalChain, reward);
      if (problem.Grammar.IsTerminal(terminalChain)) {
        MarkAsDone(terminalChain);
      }
    }

    /// <summary>Forgets all learned weights, tries, follow states and done markers.</summary>
    public void Reset() {
      Q.Clear();
      T.Clear();
      done.Clear();
      followStates.Clear();
    }

    private bool Done(string chain) {
      return done.Contains(chain);
    }

    private void MarkAsDone(string chain) {
      done.Add(chain);
    }

    /// <summary>
    /// Returns the number of tries recorded for feature <paramref name="fId"/>, or 0 if none.
    /// Nothing in this class records tries, so this currently always returns 0.
    /// </summary>
    public int GetTries(string fId) {
      int tries;
      return T.TryGetValue(fId, out tries) ? tries : 0;
    }

    /// <summary>Returns the learned weight of feature <paramref name="fId"/>, or 0.0 if it has none yet.</summary>
    public double GetValue(string fId) {
      double value;
      return Q.TryGetValue(fId, out value) ? value : 0.0;
    }

    /// <summary>
    /// Writes the largest learned weight (only if any weight exists) followed by the estimated value
    /// of a fixed set of probe expressions to the console.
    /// </summary>
    public void PrintStats() {
      // Enumerable.Max throws on an empty sequence, e.g. right after construction or Reset.
      if (Q.Count > 0) {
        Console.WriteLine(Q.Values.Max());
      }
      foreach (var option in ProbeChains) {
        Console.WriteLine("{0,-10} {1:N5}", option, CalculateValue(option));
      }
    }
  }
}