[11742] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Diagnostics;
|
---|
| 4 | using System.Linq;
|
---|
| 5 | using System.Text;
|
---|
| 6 | using HeuristicLab.Algorithms.Bandits;
|
---|
[11745] | 7 | using HeuristicLab.Common;
|
---|
[11742] | 8 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
| 9 |
|
---|
| 10 | namespace HeuristicLab.Algorithms.GrammaticalOptimization {
|
---|
public class MctsContextualSampler {
  // A node of the search graph. A node is identified by the canonical representation of a
  // phrase prefix: the phrase up to and including its first non-terminal (or the whole phrase
  // once it is terminal). Different derivation paths that reach an equivalent prefix share one
  // node, so a node can have several parents and the parent links form a graph, not a tree.
  private class TreeNode {
    public string ident;            // canonical id of the phrase prefix this node stands for
    public ReadonlySequence alt;    // the alternative applied to reach this node (used for display)
    public int randomTries;         // random completions started from this node so far
    public int tries;               // rewards back-propagated through this node
    public List<TreeNode> parents;  // every node that has this node among its children
    public TreeNode[] children;     // one child per admissible alternative; null until expanded
    public bool done = false;       // set when the node is a finished sentence or all children are done

    public TreeNode(string id, ReadonlySequence alt) {
      this.ident = id;
      this.alt = alt;
      this.parents = new List<TreeNode>();
    }

    public override string ToString() {
      return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
    }
  }

  private Dictionary<string, TreeNode> treeNodes;

  // Returns the node registered for the canonical form of id. On first use it creates and
  // registers the node, seeding its try count from the tries table.
  private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
    var canonicalId = problem.CanonicalRepresentation(id);
    TreeNode n;
    if (treeNodes.TryGetValue(canonicalId, out n)) return n;

    n = new TreeNode(canonicalId, alt);
    tries.TryGetValue(canonicalId, out n.tries);
    treeNodes[canonicalId] = n;
    return n;
  }

  /// <summary>Raised when a sampled sentence beats the best normalized quality of the current run.</summary>
  public event Action<string, double> FoundNewBestSolution;
  /// <summary>Raised for every successfully sampled and evaluated sentence.</summary>
  public event Action<string, double> SolutionEvaluated;

  private readonly int maxLen;
  private readonly IProblem problem;
  private readonly Random random;
  private readonly int randomTries;

  // (node, parent) pairs visited while sampling the current sentence, in descent order.
  private List<Tuple<TreeNode, TreeNode>> updateChain;
  private TreeNode rootNode;

  /// <summary>Largest derivation depth reached in the search graph during the current run.</summary>
  public int treeDepth;
  /// <summary>Number of child slots created by node expansions in the current run (shared nodes count once per parent).</summary>
  public int treeSize;
  private double bestQuality;

  // Running mean reward and observation count per canonical state.
  private Dictionary<string, double> v;
  private Dictionary<string, int> tries;

  /// <summary>
  /// Creates a sampler that searches the derivations of <paramref name="problem"/>'s grammar.
  /// </summary>
  /// <param name="problem">Problem that supplies the grammar, evaluation and canonical representation.</param>
  /// <param name="maxLen">Maximum sentence length.</param>
  /// <param name="random">Random source for rollouts and eps-greedy exploration.</param>
  /// <param name="randomTries">Random completions to run from a node before it is expanded.</param>
  /// <exception cref="ArgumentNullException"><paramref name="problem"/> or <paramref name="random"/> is null.</exception>
  public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    if (problem == null) throw new ArgumentNullException("problem");
    if (random == null) throw new ArgumentNullException("random");
    this.maxLen = maxLen;
    this.problem = problem;
    this.random = random;
    this.randomTries = randomTries;
    this.v = new Dictionary<string, double>(1000000);
    this.tries = new Dictionary<string, int>(1000000);
    treeNodes = new Dictionary<string, TreeNode>();
  }

  /// <summary>
  /// Samples and evaluates up to <paramref name="maxIterations"/> sentences. It stops early once
  /// every sentence reachable from the root has been generated (the root is done). All search
  /// state is discarded again before returning.
  /// </summary>
  public void Run(int maxIterations) {
    bestQuality = double.MinValue;
    InitPolicies(problem.Grammar);
    for (int i = 0; !rootNode.done && i < maxIterations; i++) {
      bool success;
      var sentence = SampleSentence(problem.Grammar, out success).ToString();
      if (!success) continue; // the chosen branch turned out to be exhausted; nothing to evaluate

      // quality is normalized by the best known quality; the code expects it to lie in [0, 1]
      var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
      Debug.Assert(quality >= 0 && quality <= 1.0);
      DistributeReward(quality);

      RaiseSolutionEvaluated(sentence, quality);

      if (quality > bestQuality) {
        bestQuality = quality;
        RaiseFoundNewBestSolution(sentence, quality);
      }
    }

    // clean up: drop the search graph and statistics of this run and eagerly reclaim them
    InitPolicies(problem.Grammar); GC.Collect();
  }

  /// <summary>
  /// Writes a summary line and then, level by level, the children of the most-tried unfinished
  /// path (alternative, scaled value, and try count or "X" when done) to the console.
  /// </summary>
  public void PrintStats() {
    var n = rootNode;
    if (n == null) return; // Run has not been called yet, so there is nothing to print
    Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
    while (n != null && n.children != null) {
      Console.WriteLine("{0,-30}", n.ident);
      // DefaultIfEmpty: a node without admissible alternatives has an empty children array
      double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).DefaultIfEmpty(0.0).Max();
      if (maxVForRow == 0) maxVForRow = 1.0;

      PrintChildRow(n.children, maxVForRow, ch => Console.Write("{0,5}", ch.alt));
      Console.WriteLine();
      PrintChildRow(n.children, maxVForRow, ch => Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10));
      Console.WriteLine();
      PrintChildRow(n.children, maxVForRow, ch => Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString()));
      Console.ForegroundColor = ConsoleColor.White;
      Console.WriteLine();

      // follow the most-tried unfinished child; stop when every child is done
      n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).FirstOrDefault();
    }
  }

  // Writes one console cell per child, coloured by the child's value relative to maxVForRow.
  private void PrintChildRow(TreeNode[] children, double maxVForRow, Action<TreeNode> writeCell) {
    foreach (var ch in children) {
      Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
      writeCell(ch);
    }
  }

  // Resets all per-run state: the update chain, value/try tables, the node cache, and the tree
  // metrics. It then creates a fresh root for the sentence symbol.
  private void InitPolicies(IGrammar grammar) {
    this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
    this.v.Clear();
    this.tries.Clear();
    // Without this, GetTreeNode would hand back the previous run's root, together with its
    // children, done flags and counters, and the old graph would never become collectable.
    this.treeNodes.Clear();

    rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    treeDepth = 0;
    treeSize = 0;
  }

  // Derives one sentence starting from the grammar's sentence symbol.
  private Sequence SampleSentence(IGrammar grammar, out bool success) {
    updateChain.Clear();
    var startPhrase = new Sequence(grammar.SentenceSymbol);
    return CompleteSentence(grammar, startPhrase, out success);
  }

  // Descends the search graph from the root, replacing the leftmost non-terminal at each step.
  // - A node that has had fewer than randomTries random completions is completed randomly.
  // - Otherwise the node is expanded (once) and a child is picked eps-greedily.
  // success is false only when the expansion shows that the current node is already exhausted.
  private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
    if (phrase.Length > maxLen) throw new ArgumentException();
    if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    TreeNode parent = null;
    TreeNode n = rootNode;
    var curDepth = 0;
    while (!phrase.IsTerminal) {
      updateChain.Add(Tuple.Create(n, parent));

      if (n.randomTries < randomTries) {
        n.randomTries++;
        treeDepth = Math.Max(treeDepth, curDepth);
        success = true;
        return g.CompleteSentenceRandomly(random, phrase, maxLen);
      }

      char nt = phrase.FirstNonTerminal;

      int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
      Debug.Assert(maxLenOfReplacement > 0);

      // Materialized once because it is counted, iterated and indexed below.
      // NOTE(review): selection indexes alts with positions into n.children. That assumes
      // this filtered list matches the one used when n was first expanded. A shared node
      // reached via a phrase of different total length could see a different filter result.
      var alts = g.GetAlternatives(nt).Where(a => g.MinPhraseLength(a) <= maxLenOfReplacement).ToArray();

      if (n.randomTries == randomTries && n.children == null) {
        // create (or look up) a node for each alternative
        n.children = new TreeNode[alts.Length];
        for (int i = 0; i < alts.Length; i++) {
          var newPhrase = new Sequence(phrase);
          newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alts[i]);
          if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
          var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alts[i]));
          treeNode.parents.Add(n);
          n.children[i] = treeNode;
        }
        treeSize += n.children.Length;
        UpdateDone(n);

        // all variations below this node may already have been generated via shared nodes
        if (n.done) {
          success = false;
          return phrase;
        }
      }

      int selectedAltIdx = SelectEpsGreedy(random, n.children);
      Sequence selectedAlt = alts[selectedAltIdx];

      // replace nt with alt
      phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);

      curDepth++;

      // prepare for next iteration
      parent = n;
      n = n.children[selectedAltIdx];
    } // while

    updateChain.Add(Tuple.Create(n, parent));

    // the last node is a leaf node (sentence is done), so we never need to visit this node again
    n.done = true;

    treeDepth = Math.Max(treeDepth, curDepth);
    success = true;
    return phrase;
  }

  // Propagates the done flag and the reward upward from the last node of the sampled path.
  private void DistributeReward(double reward) {
    var last = updateChain.Last().Item1;
    UpdateDone(last);
    BackPropReward(last, reward);
  }

  // Records one observation of reward for n and for every node reachable from n via parent
  // links. Each node is updated at most once per call, even when several paths (or duplicate
  // parent entries) lead to it. This keeps tries equal to the number of samples through a node.
  private void BackPropReward(TreeNode n, double reward) {
    var visited = new HashSet<TreeNode>();
    var pending = new Stack<TreeNode>();
    pending.Push(n);
    while (pending.Count > 0) {
      var node = pending.Pop();
      if (!visited.Add(node)) continue;
      node.tries++;
      UpdateV(node, reward);
      foreach (var p in node.parents) pending.Push(p);
    }
  }

  // Marks n done when all of its children are done, and repeats the check for the ancestors of
  // every node that is done. Parents are re-checked once per finished node, so the walk
  // terminates on shared or cyclic parent links.
  private void UpdateDone(TreeNode n) {
    var propagated = new HashSet<TreeNode>();
    var pending = new Stack<TreeNode>();
    pending.Push(n);
    while (pending.Count > 0) {
      var node = pending.Pop();
      if (!node.done && node.children != null && node.children.All(c => c.done)) node.done = true;
      if (node.done && propagated.Add(node)) {
        foreach (var p in node.parents) pending.Push(p);
      }
    }
  }

  // Folds reward into the running mean of n's canonical state and increments its observation count.
  private void UpdateV(TreeNode n, double reward) {
    var canonicalStr = problem.CanonicalRepresentation(n.ident);
    double stateV;

    if (!v.TryGetValue(canonicalStr, out stateV)) {
      v.Add(canonicalStr, reward);
      tries.Add(canonicalStr, 1);
    } else {
      v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
      tries[canonicalStr]++;
    }
  }

  // Mean reward of n's canonical state. It is +Infinity for states that have never been
  // rewarded, so that untried children win the greedy comparison.
  private double V(TreeNode n) {
    var canonicalStr = problem.CanonicalRepresentation(n.ident);
    if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
    double stateV;
    return v.TryGetValue(canonicalStr, out stateV) ? stateV : 0.0;
  }

  // Index of a uniformly chosen child that is not done.
  private int SelectRandom(Random random, TreeNode[] children) {
    return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
  }

  // With probability eps picks a random unfinished child. Otherwise it picks, uniformly among
  // ties, an unfinished child with the highest V.
  private int SelectEpsGreedy(Random random, TreeNode[] children) {
    const double eps = 0.2;
    if (random.NextDouble() < eps) {
      return SelectRandom(random, children);
    }

    var bestQ = double.NegativeInfinity;
    var bestChildIdx = new List<int>();
    for (int i = 0; i < children.Length; i++) {
      if (children[i].done) continue;
      var q = V(children[i]);
      if (q > bestQ) {
        bestQ = q;
        bestChildIdx.Clear();
        bestChildIdx.Add(i);
      } else if (q == bestQ) {
        bestChildIdx.Add(i);
      }
    }
    Debug.Assert(bestChildIdx.Any());
    return bestChildIdx.SelectRandom(random);
  }

  // UCB1 selection over unfinished children (currently unused; eps-greedy is used instead).
  // Returns the first untried child immediately, otherwise the child maximizing V + sqrt(2 ln N / n).
  private int SelectActionUCB1(Random random, TreeNode[] children) {
    int bestAction = -1;
    double bestQ = double.NegativeInfinity;
    int totalTries = children.Sum(ch => ch.tries);

    for (int a = 0; a < children.Length; a++) {
      var ch = children[a];
      if (ch.done) continue;
      if (ch.tries == 0) return a;
      var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
      if (q > bestQ) {
        bestQ = q;
        bestAction = a;
      }
    }
    Debug.Assert(bestAction > -1);
    return bestAction;
  }

  private void RaiseSolutionEvaluated(string sentence, double quality) {
    var handler = SolutionEvaluated;
    if (handler != null) handler(sentence, quality);
  }

  private void RaiseFoundNewBestSolution(string sentence, double quality) {
    var handler = FoundNewBestSolution;
    if (handler != null) handler(sentence, quality);
  }
}
|
---|
| 358 | }
|
---|