1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Text;
|
---|
6 | using HeuristicLab.Algorithms.Bandits;
|
---|
7 | using HeuristicLab.Common;
|
---|
8 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
9 |
|
---|
10 | namespace HeuristicLab.Algorithms.GrammaticalOptimization {
|
---|
/// <summary>
/// Monte-Carlo tree search sampler for grammatical optimization.
/// A sentence is built by repeatedly replacing the first non-terminal of the current phrase with one of its
/// alternatives. Tree nodes are identified by the problem's canonical representation of the phrase prefix up to
/// (and including) the first non-terminal, so equivalent prefixes share one node and a node can have several
/// parents (the "tree" is really a DAG).
/// </summary>
public class MctsContextualSampler {
  private class TreeNode {
    public string ident;            // canonical id of the phrase prefix represented by this node
    public ReadonlySequence alt;    // the alternative passed in when this node was first created
    public int randomTries;         // number of random completions started from this node
    public int tries;               // number of rewards back-propagated through this node
    public List<TreeNode> parents;  // every node that lists this node among its children
    public TreeNode[] children;     // null until the node has been expanded
    public bool done = false;       // set on a leaf once its sentence was generated, on an inner node once all children are done

    public TreeNode(string id, ReadonlySequence alt) {
      this.ident = id;
      this.alt = alt;
      this.parents = new List<TreeNode>();
    }

    public override string ToString() {
      return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
    }
  }

  // probability that SelectEpsGreedy picks a random open child instead of the greedy one
  private const double Epsilon = 0.2;

  public event Action<string, double> FoundNewBestSolution;
  public event Action<string, double> SolutionEvaluated;

  private readonly int maxLen;
  private readonly IProblem problem;
  private readonly Random random;
  private readonly int randomTries;

  // all nodes of the current run, keyed by canonical phrase representation
  private readonly Dictionary<string, TreeNode> treeNodes;
  // running mean reward (v) and number of rewards (tries), keyed by canonical phrase representation
  private readonly Dictionary<string, double> v;
  private readonly Dictionary<string, int> tries;

  // (node, parent) pairs visited while sampling the most recent sentence
  private List<Tuple<TreeNode, TreeNode>> updateChain;
  private TreeNode rootNode;

  public int treeDepth;
  public int treeSize;
  private double bestQuality;

  /// <summary>
  /// Creates a sampler.
  /// </summary>
  /// <param name="problem">Problem that supplies the grammar, the evaluation and the canonical representation.</param>
  /// <param name="maxLen">Maximum sentence length.</param>
  /// <param name="random">Random number generator used for random completions and eps-greedy selection.</param>
  /// <param name="randomTries">Number of random completions started from a node before the node is expanded.</param>
  public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    this.maxLen = maxLen;
    this.problem = problem;
    this.random = random;
    this.randomTries = randomTries;
    this.v = new Dictionary<string, double>(1000000);
    this.tries = new Dictionary<string, int>(1000000);
    this.treeNodes = new Dictionary<string, TreeNode>();
  }

  // Returns the node for the canonical form of id, creating it with the given alternative if it does not exist yet.
  private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
    var canonicalId = problem.CanonicalRepresentation(id);
    TreeNode n;
    if (!treeNodes.TryGetValue(canonicalId, out n)) {
      n = new TreeNode(canonicalId, alt);
      // start from the recorded visit count if statistics for this state already exist (0 otherwise)
      tries.TryGetValue(canonicalId, out n.tries);
      treeNodes[canonicalId] = n;
    }
    return n;
  }

  /// <summary>
  /// Samples and evaluates up to <paramref name="maxIterations"/> sentences, stopping early once the root node is done.
  /// Raises <see cref="SolutionEvaluated"/> for every evaluated sentence and <see cref="FoundNewBestSolution"/>
  /// whenever the normalized quality exceeds the best one seen in this run.
  /// The tree and all statistics are discarded when the run finishes.
  /// </summary>
  public void Run(int maxIterations) {
    bestQuality = double.MinValue;
    InitPolicies(problem.Grammar);
    for (int i = 0; !rootNode.done && i < maxIterations; i++) {
      bool success;
      var sentence = SampleSentence(problem.Grammar, out success).ToString();
      if (!success) continue;

      // quality is normalized by the best known quality for sentences of length maxLen
      var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
      Debug.Assert(quality >= 0 && quality <= 1.0);
      DistributeReward(quality);

      RaiseSolutionEvaluated(sentence, quality);

      if (quality > bestQuality) {
        bestQuality = quality;
        RaiseFoundNewBestSolution(sentence, quality);
      }
    }

    // clean up: drop the tree and all statistics of this run so that the memory can be reclaimed
    InitPolicies(problem.Grammar);
    GC.Collect();
  }

  /// <summary>
  /// Prints summary statistics and then, level by level, the children of the most visited path of open nodes.
  /// Each child is shown with its alternative, its value (capped at 1, times 10) and its try count ("X" when done),
  /// coloured relative to the best value in its row. Prints nothing if <see cref="Run"/> has never been called.
  /// </summary>
  public void PrintStats() {
    var n = rootNode;
    if (n == null) return;
    Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
    while (n != null && n.children != null && n.children.Length > 0) {
      Console.WriteLine("{0,-30}", n.ident);
      double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max();
      if (maxVForRow == 0) maxVForRow = 1.0;

      PrintRow(n.children, maxVForRow, "{0,5}", ch => ch.alt);
      Console.WriteLine();
      PrintRow(n.children, maxVForRow, "{0,5:F2}", ch => Math.Min(1.0, V(ch)) * 10);
      Console.WriteLine();
      PrintRow(n.children, maxVForRow, "{0,5}", ch => ch.done ? "X" : ch.tries.ToString());
      Console.ForegroundColor = ConsoleColor.White;
      Console.WriteLine();

      // descend into the most visited open child; stop when every child is done (First() used to throw here)
      n = n.children.Where(ch => !ch.done).OrderByDescending(ch => ch.tries).FirstOrDefault();
    }
  }

  // Writes one formatted cell per child, each coloured by the child's value relative to maxVForRow.
  private void PrintRow(TreeNode[] children, double maxVForRow, string format, Func<TreeNode, object> cell) {
    foreach (var ch in children) {
      Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
      Console.Write(format, cell(ch));
    }
  }

  // Resets all per-run state and creates a fresh root node for the grammar's sentence symbol.
  private void InitPolicies(IGrammar grammar) {
    this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
    this.v.Clear();
    this.tries.Clear();
    // the node cache must be reset together with the statistics; otherwise the next run starts from the
    // previous run's (possibly already done) root and the old tree is never released
    this.treeNodes.Clear();

    rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    treeDepth = 0;
    treeSize = 0;
  }

  // Samples one sentence starting from the grammar's sentence symbol; records the visited nodes in updateChain.
  private Sequence SampleSentence(IGrammar grammar, out bool success) {
    updateChain.Clear();
    var startPhrase = new Sequence(grammar.SentenceSymbol);
    return CompleteSentence(grammar, startPhrase, out success);
  }

  // Walks down from the root, replacing the first non-terminal of phrase at each step.
  // Returns the completed sentence with success == true, or the partial phrase with success == false when the
  // reached node is found to be done right after its expansion. Every visited (node, parent) pair is appended
  // to updateChain.
  private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
    if (phrase.Length > maxLen) throw new ArgumentException();
    if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    TreeNode parent = null;
    TreeNode n = rootNode;
    var curDepth = 0;
    while (!phrase.IsTerminal) {
      updateChain.Add(Tuple.Create(n, parent));

      if (n.randomTries < randomTries) {
        // the node has not been sampled often enough yet: finish the sentence randomly
        n.randomTries++;
        treeDepth = Math.Max(treeDepth, curDepth);
        success = true;
        return g.CompleteSentenceRandomly(random, phrase, maxLen);
      }

      char nt = phrase.FirstNonTerminal;

      // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
      int maxLenOfReplacement = maxLen - (phrase.Length - 1);
      Debug.Assert(maxLenOfReplacement > 0);

      // materialized once because it is counted, enumerated and indexed below
      var alts = g.GetAlternatives(nt).Where(alt => g.MinPhraseLength(alt) <= maxLenOfReplacement).ToArray();

      if (n.randomTries == randomTries && n.children == null) {
        // create (or look up) a node for each alternative
        n.children = new TreeNode[alts.Length];
        for (int i = 0; i < alts.Length; i++) {
          var newPhrase = new Sequence(phrase);
          newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alts[i]);
          // nodes are identified by the prefix up to and including the first non-terminal
          if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
          var child = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alts[i]));
          child.parents.Add(n);
          n.children[i] = child;
        }
        treeSize += n.children.Length;
        UpdateDone(n);

        // it could happen that all variations starting from this branch are already finished
        if (n.done) {
          success = false;
          return phrase;
        }
      }

      // NOTE(review): children are built on the first visit of a (possibly shared) node; this assumes that every
      // later visit computes the same alternatives list so that child index and alternative index agree — confirm
      // for grammars where different derivations reach the same canonical prefix with different phrase lengths.
      // Alternative strategies: SelectRandom(random, n.children) or SelectActionUCB1(random, n.children).
      int selectedAltIdx = SelectEpsGreedy(random, n.children);
      Sequence selectedAlt = alts[selectedAltIdx];

      // replace nt with the selected alternative
      phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
      curDepth++;

      // prepare for next iteration
      parent = n;
      n = n.children[selectedAltIdx];
    }

    updateChain.Add(Tuple.Create(n, parent));

    // the last node is a leaf node (sentence is done), so we never need to visit this node again
    n.done = true;

    treeDepth = Math.Max(treeDepth, curDepth);
    success = true;
    return phrase;
  }

  // Updates done-flags and propagates the reward from the last visited node of updateChain to all its ancestors.
  private void DistributeReward(double reward) {
    var last = updateChain.Last().Item1;
    UpdateDone(last);
    BackPropReward(last, reward);
  }

  // Records the reward on n and recursively on every parent.
  // NOTE: because nodes are shared, an ancestor reachable over several paths receives the reward once per path.
  private void BackPropReward(TreeNode n, double reward) {
    n.tries++;
    UpdateV(n, reward);
    foreach (var p in n.parents) BackPropReward(p, reward);
  }

  // Marks n as done when all of its children are done, and re-checks all parents whenever n is done.
  private void UpdateDone(TreeNode n) {
    if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
    if (n.done) foreach (var p in n.parents) UpdateDone(p);
  }

  // Folds reward into the running mean value of n's canonical state.
  private void UpdateV(TreeNode n, double reward) {
    var canonicalStr = problem.CanonicalRepresentation(n.ident);
    double stateV;

    if (!v.TryGetValue(canonicalStr, out stateV)) {
      v.Add(canonicalStr, reward);
      tries.Add(canonicalStr, 1);
    } else {
      // incremental mean: divide by the sample count INCLUDING the new reward
      // (dividing by the previous count made the second reward overwrite the first)
      int count = tries[canonicalStr] + 1;
      tries[canonicalStr] = count;
      v[canonicalStr] = stateV + (reward - stateV) / count;
    }
  }

  // Mean reward of n's canonical state; +infinity for states that have never received a reward.
  private double V(TreeNode n) {
    var canonicalStr = problem.CanonicalRepresentation(n.ident);
    if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
    double stateV;
    return v.TryGetValue(canonicalStr, out stateV) ? stateV : 0.0;
  }

  // Index of a uniformly chosen child that is not done.
  private int SelectRandom(Random random, TreeNode[] children) {
    return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
  }

  // With probability Epsilon a random open child, otherwise a random one among the open children with maximal V.
  private int SelectEpsGreedy(Random random, TreeNode[] children) {
    if (random.NextDouble() < Epsilon) {
      return SelectRandom(random, children);
    }

    var bestQ = double.NegativeInfinity;
    var bestChildIdx = new List<int>();
    for (int i = 0; i < children.Length; i++) {
      if (children[i].done) continue;
      var q = V(children[i]);
      if (q > bestQ) {
        bestQ = q;
        bestChildIdx.Clear();
        bestChildIdx.Add(i);
      } else if (q == bestQ) {
        bestChildIdx.Add(i);
      }
    }
    Debug.Assert(bestChildIdx.Any());
    return bestChildIdx.SelectRandom(random);
  }

  // UCB1 selection over the open children (untried children first). Not used by CompleteSentence;
  // kept as an alternative to SelectEpsGreedy.
  private int SelectActionUCB1(Random random, TreeNode[] children) {
    int bestAction = -1;
    double bestQ = double.NegativeInfinity;
    int totalTries = children.Sum(ch => ch.tries);

    for (int a = 0; a < children.Length; a++) {
      var ch = children[a];
      if (ch.done) continue;
      if (ch.tries == 0) return a;
      var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
      if (q > bestQ) {
        bestQ = q;
        bestAction = a;
      }
    }
    Debug.Assert(bestAction > -1);
    return bestAction;
  }

  private void RaiseSolutionEvaluated(string sentence, double quality) {
    var handler = SolutionEvaluated;
    if (handler != null) handler(sentence, quality);
  }

  private void RaiseFoundNewBestSolution(string sentence, double quality) {
    var handler = FoundNewBestSolution;
    if (handler != null) handler(sentence, quality);
  }
}
|
---|
358 | }
|
---|