Changeset 12298 for branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies
- Timestamp:
- 04/10/15 16:12:08 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs
r12295 r12298 48 48 foreach (var afterState in afterStates) { 49 49 if (!Done(afterState)) { 50 if (GetTries(afterState) == 0) 51 activeAfterStates[idx] = double.PositiveInfinity; 52 else 53 activeAfterStates[idx] = GetValue(afterState); 50 activeAfterStates[idx] = CalculateValue(afterState); 54 51 actionIndexMap[idx] = originalIdx; 55 52 idx++; … … 58 55 } 59 56 57 60 58 //var eps = Math.Max(500.0 / (GetTries(curState) + 1), 0.01); 61 59 //var eps = 10.0 / Math.Sqrt(GetTries(curState) + 1); 62 var eps = 0. 2;60 var eps = 0.01; 63 61 selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(idx), eps)]; 64 62 63 UpdateValue(curState, afterStates); 64 65 65 return true; 66 66 } 67 68 private double CalculateValue(string chain) { 69 var features = problem.GetFeatures(chain); 70 var sum = 0.0; 71 foreach (var f in features) { 72 // if (GetTries(f.Id) == 0) 73 // sum = 0.0; 74 // else 75 sum += GetValue(f.Id) * f.Value; 76 } 77 return sum; 78 } 79 80 private void UpdateValue(string curChain, IEnumerable<string> alternatives) { 81 const double gamma = 1; 82 const double alpha = 0.01; 83 var maxNextQ = alternatives 84 .Select(CalculateValue).Max(); 85 86 var delta = gamma * maxNextQ - CalculateValue(curChain); 87 88 foreach (var f in problem.GetFeatures(curChain)) { 89 90 Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value; 91 } 92 } 93 94 private void UpdateLastValue(string terminalChain, double reward) { 95 const double alpha = 0.01; 96 var delta = reward - CalculateValue(terminalChain); 97 foreach (var f in problem.GetFeatures(terminalChain)) { 98 Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value; 99 } 100 } 101 67 102 68 103 private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) { … … 121 156 122 157 public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) { 123 const double gamma = 0.95; 124 const double minAlpha = 0.01; 125 var reverseChains = chainTrajectory.Reverse(); 126 var terminalChain = reverseChains.First(); 127 128 var terminalState = CalcState(terminalChain); 129 T[terminalState] = GetTries(terminalChain) + 1; 130 double alpha = Math.Max(1.0 / GetTries(terminalChain), minAlpha); 131 Q[terminalState] = (1 - alpha) * GetValue(terminalChain) + alpha * reward; 132 133 foreach (var chain in reverseChains.Skip(1)) { 134 135 var maxNextQ = followStates[chain] 136 //.Where(s=>!Done(s)) 137 .Select(GetValue).Max(); 138 T[CalcState(chain)] = GetTries(chain) + 1; 139 140 alpha = Math.Max(1.0 / GetTries(chain), minAlpha); 141 Q[CalcState(chain)] = (1 - alpha) * GetValue(chain) + gamma * alpha * maxNextQ; // direct contribution is zero 142 } 158 // // only updates the last chain because we already update values after each step 159 // var reverseChains = chainTrajectory.Reverse(); 160 // var terminalChain = reverseChains.First(); 161 // 162 // UpdateValue(terminalChain, reward); 163 // 164 // foreach (var chain in reverseChains.Skip(1)) { 165 // 166 // var maxNextQ = followStates[chain] 167 // //.Where(s=>!Done(s)) 168 // .Select(GetValue).Max(); 169 // 170 // UpdateValue(chain, maxNextQ); 171 // } 172 var terminalChain = chainTrajectory.Last(); 173 UpdateLastValue(terminalChain, reward); 143 174 if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain); 144 175 } 176 145 177 146 178 public void Reset() { 147 179 Q.Clear(); 180 T.Clear(); 148 181 done.Clear(); 149 182 followStates.Clear(); … … 160 193 161 194 162 public int GetTries(string state) { 163 var s = CalcState(state); 164 if (T.ContainsKey(s)) return T[s]; 195 public int GetTries(string fId) { 196 if (T.ContainsKey(fId)) return T[fId]; 165 197 else return 0; 166 198 } 167 199 168 public double GetValue(string chain) {169 var s = CalcState(chain);170 if (Q.ContainsKey( s)) return Q[s];200 public double GetValue(string fId) { 201 // var s = CalcState(chain); 202 if (Q.ContainsKey(fId)) return Q[fId]; 171 203 else return 0.0; // TODO: check alternatives 172 204 } 173 205 174 private string CalcState(string chain) {175 var f = problem.GetFeatures(chain);176 // this policy only works for problems that return exactly one feature (the 'state')177 if (f.Skip(1).Any()) throw new ArgumentException();178 return f.First().Id;179 }206 // private string CalcState(string chain) { 207 // var f = problem.GetFeatures(chain); 208 // // this policy only works for problems that return exactly one feature (the 'state') 209 // if (f.Skip(1).Any()) throw new ArgumentException(); 210 // return f.First().Id; 211 // } 180 212 181 213 public void PrintStats() { 182 214 Console.WriteLine(Q.Values.Max()); 183 var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50); 184 var topQs = Q.Keys.Where(key => key.Contains(",")).OrderByDescending(key => Q[key]).Take(50); 185 foreach (var t in topTries.Zip(topQs, Tuple.Create)) { 186 var id1 = t.Item1; 187 var id2 = t.Item2; 188 Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]); 189 } 190 215 // var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50); 216 // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Q[key]).Take(50); 217 // foreach (var t in topTries.Zip(topQs, Tuple.Create)) { 218 // var id1 = t.Item1; 219 // var id2 = t.Item2; 220 // Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]); 221 // } 222 223 foreach (var option in new String[] 224 { 225 "a*b", "c*d", "a*b+c*d", "e*f", "a*b+c*d+e*f", 226 "a*b+a*b", "c*d+c*d", 227 "a*a", "a*b","a*c","a*d","a*e","a*f","a*g","a*h","a*i","a*j", 228 "a*b","c*d","e*f","a*c","a*f","a*i","a*i*g","c*f","c*f*j", 229 "b+c","a+c","b+d","a+d", 230 "a*b+c*d+e*f", "a*b+c*d+e*f+a", "a*b+c*d+e*f+b", "a*b+c*d+e*f+c", "a*b+c*d+e*f+d","a*b+c*d+e*f+e", "a*b+c*d+e*f+f", "a*b+c*d+e*f+g", "a*b+c*d+e*f+h", "a*b+c*d+e*f+i", "a*b+c*d+e*f+j", 231 "a*b+c*d+e*f+a*g*i+c*j*f" 232 }) { 233 Console.WriteLine("{0,-10} {1:N5}", option, CalculateValue(option)); 234 } 235 236 // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Math.Abs(Q[key])).Take(10); 237 // foreach (var t in topQs) { 238 // Console.WriteLine("{0,30} {1:N4}", t, Q[t]); 239 // } 191 240 } 192 241 }
Note: See TracChangeset
for help on using the changeset viewer.