Changeset 12294 for branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization
- Timestamp:
- 04/08/15 10:09:47 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs
r12291 r12294 11 11 // resampling is not prevented 12 12 public sealed class GenericPolicy : IGrammarPolicy { 13 private Dictionary<string, IBanditPolicyActionInfo> stateInfo; // stores the necessary information for bandit policies for each state 13 private Dictionary<string, double> Q; // stores the necessary information for bandit policies for each state 14 private Dictionary<string, int> T; // tries; 15 private Dictionary<string, List<string>> followStates; 14 16 private readonly IProblem problem; 15 private readonly IBanditPolicy banditPolicy;16 17 private readonly HashSet<string> done; // contains all visited chains 17 18 18 public GenericPolicy(IProblem problem , IBanditPolicy banditPolicy) {19 public GenericPolicy(IProblem problem) { 19 20 this.problem = problem; 20 this.banditPolicy = banditPolicy; 21 this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>(); 21 this.Q = new Dictionary<string, double>(); 22 this.T = new Dictionary<string, int>(); 23 this.followStates = new Dictionary<string, List<string>>(); 22 24 this.done = new HashSet<string>(); 23 25 } 24 26 25 private IBanditPolicyActionInfo[] activeAfterStates; // don't allocate each time27 private double[] activeAfterStates; // don't allocate each time 26 28 private int[] actionIndexMap; // don't allocate each time 27 29 … … 37 39 38 40 if (activeAfterStates == null || activeAfterStates.Length < afterStates.Count()) { 39 activeAfterStates = new IBanditPolicyActionInfo[afterStates.Count()];41 activeAfterStates = new double[afterStates.Count()]; 40 42 actionIndexMap = new int[afterStates.Count()]; 43 } 44 if (!followStates.ContainsKey(curState)) { 45 followStates[curState] = new List<string>(afterStates); 41 46 } 42 47 var idx = 0; int originalIdx = 0; 43 48 foreach (var afterState in afterStates) { 44 49 if (!Done(afterState)) { 45 activeAfterStates[idx] = Get StateInfo(afterState);50 activeAfterStates[idx] = GetValue(afterState); 46 51 actionIndexMap[idx] = originalIdx; 47 52 idx++; … … 50 55 } 51 56 52 selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(idx))]; 57 //var eps = Math.Max(500.0 / (GetTries(curState) + 1), 0.01); 58 //var eps = 10.0 / Math.Sqrt(GetTries(curState) + 1); 59 var eps = 0.2; 60 selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(idx), eps)]; 53 61 54 62 return true; 55 63 } 56 64 65 private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) { 66 // select best 57 67 68 // try any of the untries actions randomly 69 // for RoyalSequence it is much better to select the actions in the order of occurrence (all terminal alternatives first) 70 //if (myActionInfos.Any(aInfo => !aInfo.Disabled && aInfo.Tries == 0)) { 71 // return myActionInfos 72 // .Select((aInfo, idx) => new { aInfo, idx }) 73 // .Where(p => !p.aInfo.Disabled) 74 // .Where(p => p.aInfo.Tries == 0) 75 // .SelectRandom(random).idx; 76 //} 58 77 59 private IBanditPolicyActionInfo GetStateInfo(string state) { 60 var s = CalcState(state); 61 IBanditPolicyActionInfo info; 62 if (!stateInfo.TryGetValue(s, out info)) { 63 info = banditPolicy.CreateActionInfo(); 64 stateInfo[s] = info; 65 } 66 return info; 78 var w = from q in qs 79 select Math.Exp(beta * q); 80 81 var bestAction = Enumerable.Range(0, qs.Count()).SampleProportional(random, w); 82 Debug.Assert(bestAction >= 0); 83 return bestAction; 67 84 } 68 85 69 public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) { 70 foreach (var state in stateTrajectory.Reverse()) { 71 GetStateInfo(state).UpdateReward(reward); 86 private int SelectEpsGreedy(Random random, IEnumerable<double> qs, double eps = 0.2) { 87 if (random.NextDouble() >= eps) { // eps == 0 should be equivalent to pure exploitation, eps == 1 is pure exploration 88 // select best 89 var bestActions = new List<int>(); 90 double bestQ = double.NegativeInfinity; 72 91 73 // actually only the last state can be terminal 74 if (problem.Grammar.IsTerminal(state)) { 75 MarkAsDone(state); 92 int aIdx = -1; 93 foreach (var q in qs) { 94 aIdx++; 95 96 if (q > bestQ) { 97 bestActions.Clear(); 98 bestActions.Add(aIdx); 99 bestQ = q; 100 } else if (q.IsAlmost(bestQ)) { 101 bestActions.Add(aIdx); 102 } 76 103 } 104 Debug.Assert(bestActions.Any()); 105 return bestActions.SelectRandom(random); 106 } else { 107 // select random 108 return SelectRandom(random, qs); 77 109 } 78 110 } 79 111 112 private int SelectRandom(Random random, IEnumerable<double> qs) { 113 return qs 114 .Select((aInfo, idx) => Tuple.Create(aInfo, idx)) 115 .SelectRandom(random).Item2; 116 } 117 118 119 public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) { 120 const double gamma = 0.95; 121 const double minAlpha = 0.01; 122 var reverseChains = chainTrajectory.Reverse(); 123 var terminalChain = reverseChains.First(); 124 125 var terminalState = CalcState(terminalChain); 126 T[terminalState] = GetTries(terminalChain) + 1; 127 double alpha = Math.Max(1.0 / GetTries(terminalChain), minAlpha); 128 Q[terminalState] = (1 - alpha) * GetValue(terminalChain) + alpha * reward; 129 130 foreach (var chain in reverseChains.Skip(1)) { 131 132 var maxNextQ = followStates[chain] 133 //.Where(s=>!Done(s)) 134 .Select(GetValue).Max(); 135 T[CalcState(chain)] = GetTries(chain) + 1; 136 137 alpha = Math.Max(1.0 / GetTries(chain), minAlpha); 138 Q[CalcState(chain)] = (1 - alpha) * GetValue(chain) + gamma * alpha * maxNextQ; // direct contribution is zero 139 } 140 if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain); 141 } 142 80 143 public void Reset() { 81 stateInfo.Clear();144 Q.Clear(); 82 145 done.Clear(); 146 followStates.Clear(); 83 147 } 84 148 … … 95 159 public int GetTries(string state) { 96 160 var s = CalcState(state); 97 if ( stateInfo.ContainsKey(s)) return stateInfo[s].Tries;161 if (T.ContainsKey(s)) return T[s]; 98 162 else return 0; 99 163 } 100 164 101 public double GetValue(string state) {102 var s = CalcState( state);103 if ( stateInfo.ContainsKey(s)) return stateInfo[s].Value;165 public double GetValue(string chain) { 166 var s = CalcState(chain); 167 if (Q.ContainsKey(s)) return Q[s]; 104 168 else return 0.0; // TODO: check alternatives 105 169 } … … 111 175 return f.First().Id; 112 176 } 177 178 public void PrintStats() { 179 Console.WriteLine(Q.Values.Max()); 180 var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50); 181 var topQs = Q.Keys.Where(key=>key.Contains(",")).OrderByDescending(key => Q[key]).Take(50); 182 foreach (var t in topTries.Zip(topQs, Tuple.Create)) { 183 var id1 = t.Item1; 184 var id2 = t.Item2; 185 Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]); 186 } 187 188 } 113 189 } 114 190 }
Note: See TracChangeset
for help on using the changeset viewer.