Changeset 11747 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs
- Timestamp:
- 01/12/15 21:23:01 (9 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs
r11745 r11747 15 15 public int randomTries; 16 16 public int tries; 17 public List<TreeNode> parents; 17 18 public TreeNode[] children; 18 19 public bool done = false; … … 21 22 this.ident = id; 22 23 this.alt = alt; 24 this.parents = new List<TreeNode>(); 23 25 } 24 26 … … 28 30 } 29 31 32 private Dictionary<string, TreeNode> treeNodes; 33 private TreeNode GetTreeNode(string id, ReadonlySequence alt) { 34 TreeNode n; 35 var canonicalId = problem.CanonicalRepresentation(id); 36 if (!treeNodes.TryGetValue(canonicalId, out n)) { 37 n = new TreeNode(canonicalId, alt); 38 tries.TryGetValue(canonicalId, out n.tries); 39 treeNodes[canonicalId] = n; 40 } 41 return n; 42 } 30 43 31 44 public event Action<string, double> FoundNewBestSolution; … … 51 64 this.v = new Dictionary<string, double>(1000000); 52 65 this.tries = new Dictionary<string, int>(1000000); 66 treeNodes = new Dictionary<string, TreeNode>(); 53 67 } 54 68 … … 57 71 InitPolicies(problem.Grammar); 58 72 for (int i = 0; !rootNode.done && i < maxIterations; i++) { 59 var sentence = SampleSentence(problem.Grammar).ToString(); 60 var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen); 61 Debug.Assert(quality >= 0 && quality <= 1.0); 62 DistributeReward(quality); 63 64 RaiseSolutionEvaluated(sentence, quality); 65 66 if (quality > bestQuality) { 67 bestQuality = quality; 68 RaiseFoundNewBestSolution(sentence, quality); 73 bool success; 74 var sentence = SampleSentence(problem.Grammar, out success).ToString(); 75 if (success) { 76 var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen); 77 Debug.Assert(quality >= 0 && quality <= 1.0); 78 DistributeReward(quality); 79 80 RaiseSolutionEvaluated(sentence, quality); 81 82 if (quality > bestQuality) { 83 bestQuality = quality; 84 RaiseFoundNewBestSolution(sentence, quality); 85 } 69 86 } 70 87 } … … 78 95 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality); 79 96 while (n.children != null) { 80 Console.WriteLine("{0 }", n.ident);81 double maxVForRow = n.children.Select(ch => V(ch)).Max();97 Console.WriteLine("{0,-30}", n.ident); 98 double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max(); 82 99 if (maxVForRow == 0) maxVForRow = 1.0; 83 100 84 101 for (int i = 0; i < n.children.Length; i++) { 85 102 var ch = n.children[i]; 86 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);103 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 87 104 Console.Write("{0,5}", ch.alt); 88 105 } … … 90 107 for (int i = 0; i < n.children.Length; i++) { 91 108 var ch = n.children[i]; 92 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);93 Console.Write("{0,5:F2}", V(ch) * 10);109 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 110 Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10); 94 111 } 95 112 Console.WriteLine(); 96 113 for (int i = 0; i < n.children.Length; i++) { 97 114 var ch = n.children[i]; 98 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);115 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 99 116 Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString()); 100 117 } … … 102 119 Console.WriteLine(); 103 120 //n.policy.PrintStats(); 104 n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();121 n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).First(); 105 122 } 106 123 } … … 112 129 this.tries.Clear(); 113 130 114 rootNode = newTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));131 rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$")); 115 132 treeDepth = 0; 116 133 treeSize = 0; 117 134 } 118 135 119 private Sequence SampleSentence(IGrammar grammar ) {136 private Sequence SampleSentence(IGrammar grammar, out bool success) { 120 137 updateChain.Clear(); 121 138 //var startPhrase = new Sequence("a*b+c*d+e*f+E"); 122 139 var startPhrase = new Sequence(grammar.SentenceSymbol); 123 return CompleteSentence(grammar, startPhrase );124 } 125 126 private Sequence CompleteSentence(IGrammar g, Sequence phrase ) {140 return CompleteSentence(grammar, startPhrase, out success); 141 } 142 143 private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) { 127 144 if (phrase.Length > maxLen) throw new ArgumentException(); 128 145 if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException(); … … 136 153 n.randomTries++; 137 154 treeDepth = Math.Max(treeDepth, curDepth); 155 success = true; 138 156 return g.CompleteSentenceRandomly(random, phrase, maxLen); 139 157 } else { … … 153 171 newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt); 154 172 if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1); 155 n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt)); 173 var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alt)); 174 treeNode.parents.Add(n); 175 n.children[i++] = treeNode; 156 176 } 157 177 treeSize += n.children.Length; 178 UpdateDone(n); 179 180 // it could happend that we already finished all variations starting from the branch 181 // stop 182 if (n.done) { 183 success = false; 184 return phrase; 185 } 158 186 } 187 //int selectedAltIdx = SelectRandom(random, n.children); 188 159 189 // => select using eps-greedy 160 190 int selectedAltIdx = SelectEpsGreedy(random, n.children); … … 167 197 168 198 curDepth++; 199 169 200 170 201 // prepare for next iteration 171 202 parent = n; 172 203 n = n.children[selectedAltIdx]; 204 //UpdateTD(parent, n, 0.0); 173 205 } 174 206 } // while … … 181 213 182 214 treeDepth = Math.Max(treeDepth, curDepth); 215 success = true; 183 216 return phrase; 184 217 } 185 218 219 220 //private void UpdateTD(TreeNode parent, TreeNode child, double reward) { 221 // double alpha = 1.0; 222 // var vParent = V(parent); 223 // var vChild = V(child); 224 // if (double.IsInfinity(vParent)) vParent = 0.0; 225 // if (double.IsInfinity(vChild)) vChild = 0.0; 226 // UpdateV(parent, (alpha * (reward + vChild - vParent))); 227 //} 228 186 229 private void DistributeReward(double reward) { 230 187 231 // iterate in reverse order (bottom up) 188 updateChain.Reverse(); 189 232 //updateChain.Reverse(); 233 UpdateDone(updateChain.Last().Item1); 234 //UpdateTD(updateChain.Last().Item2, updateChain.Last().Item1, reward); 235 //return; 236 237 BackPropReward(updateChain.Last().Item1, reward); 238 /* 190 239 foreach (var e in updateChain) { 191 240 var node = e.Item1; 192 var parent = e.Item2;241 //var parent = e.Item2; 193 242 node.tries++; 194 if (node.children != null && node.children.All(c => c.done)) {195 node.done = true;196 }243 //if (node.children != null && node.children.All(c => c.done)) { 244 // node.done = true; 245 //} 197 246 UpdateV(node, reward); 198 247 199 248 // the reward for the parent is either the just recieved reward or the value of the best action so far 200 double value = 0.0;201 if (parent != null) {202 var doneChilds = parent.children.Where(ch => ch.done);203 if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();204 }249 //double value = 0.0; 250 //if (parent != null) { 251 // var doneChilds = parent.children.Where(ch => ch.done); 252 // if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max(); 253 //} 205 254 //if (value > reward) reward = value; 206 } 207 } 255 }*/ 256 } 257 258 private void BackPropReward(TreeNode n, double reward) { 259 n.tries++; 260 UpdateV(n, reward); 261 foreach (var p in n.parents) BackPropReward(p, reward); 262 } 263 264 private void UpdateDone(TreeNode n) { 265 if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true; 266 if (n.done) foreach (var p in n.parents) UpdateDone(p); 267 } 268 208 269 209 270 private Dictionary<string, double> v; … … 219 280 tries.Add(canonicalStr, 1); 220 281 } else { 221 v[canonicalStr] = stateV + 0.005 * (reward - stateV);222 //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);282 //v[canonicalStr] = stateV + 0.005 * (reward - stateV); 283 v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV); 223 284 tries[canonicalStr]++; 224 285 } … … 229 290 //var canonicalStr = n.ident; 230 291 double stateV; 231 292 if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity; 232 293 if (!v.TryGetValue(canonicalStr, out stateV)) { 233 294 return 0.0; … … 237 298 } 238 299 300 private int SelectRandom(Random random, TreeNode[] children) { 301 return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2; 302 } 303 239 304 private int SelectEpsGreedy(Random random, TreeNode[] children) { 240 305 if (random.NextDouble() < 0.2) { 241 242 return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2; 306 return SelectRandom(random, children); 243 307 } else { 244 308 var bestQ = double.NegativeInfinity;
Note: See TracChangeset
for help on using the changeset viewer.