1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Text;
|
---|
6 | using HeuristicLab.Grammars;
|
---|
7 |
|
---|
namespace CodeGenerator {
  /// <summary>
  /// Generates the C# source code of a Monte-Carlo tree search solver for a grammar-defined problem.
  /// The solver source is produced from <c>solverTemplate</c> by replacing the ?MAXIMIZATION?,
  /// ?SAMPLEALTERNATIVECODE? and ?CREATETERMINALNODECODE? placeholders.
  /// NOTE(review): ?IDENT? and ?PROBLEMNAME? are not replaced here; presumably the caller
  /// replaces them in <c>problemSourceCode</c> after <see cref="Generate"/> appends to it - confirm.
  /// </summary>
  public class MonteCarloTreeSearchCodeGen {

    private readonly string solverTemplate = @"
namespace ?PROBLEMNAME? {
  public class SearchTreeNode {
    public int tries;
    public double sumQuality = 0.0;
    public double bestQuality = double.NegativeInfinity;
    public bool done;
    public SearchTreeNode[] children;
    // only for debugging
    public double[] Ucb {
      get {
        return (from c in children
                select ?IDENT?MonteCarloTreeSearchSolver.UCB(this, c)
                ).ToArray();
      }
    }
    public SearchTreeNode() {
    }
  }

  public sealed class ?IDENT?MonteCarloTreeSearchSolver {
    private int maxDepth = 20;
    // upper bound for the depth of purely random (sub-)trees
    private const int MAX_RANDOM_DEPTH = 5;

    private readonly ?IDENT?Problem problem;
    private readonly Random random;
    private SearchTreeNode searchTree = new SearchTreeNode();

    private Tree SampleTree(int maxDepth) {
      var extensionsStack = new Stack<Tuple<Tree, int, int, int>>(); // the unfinished tree, the state, the index of the extension point and the maximal depth of a tree inserted at that point
      var t = new Tree(-1, new Tree[1]);
      extensionsStack.Push(Tuple.Create(t, 0, 0, maxDepth));
      SampleTree(searchTree, extensionsStack);
      return t.subtrees[0];
    }

    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) {
      const int RANDOM_TRIES = 1000;
      if(extensionPoints.Count == 0) {
        searchTree.done = true;
        return; // nothing to do
      }
      var extensionPoint = extensionPoints.Pop();
      Tree parent = extensionPoint.Item1;
      int state = extensionPoint.Item2;
      int subtreeIdx = extensionPoint.Item3;
      int maxDepth = extensionPoint.Item4;
      Debug.Assert(maxDepth >= 1);
      Tree t = null;
      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
        // random trees must not exceed the depth budget that is left at this extension point
        t = SampleTreeRandom(state, Math.Min(maxDepth, MAX_RANDOM_DEPTH));
        if(Grammar.subtreeCount[state] == 0) {
          // when we produced a terminal continue filling up all other empty points
          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
          if(searchTree.children == null)
            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
          SampleTree(searchTree.children[0], extensionPoints);
          if(searchTree.children[0].done) searchTree.done = true;
        } else {
          // fill up all remaining slots randomly (each within its own depth budget)
          foreach(var p in extensionPoints) {
            var pParent = p.Item1;
            var pState = p.Item2;
            var pIdx = p.Item3;
            var pMaxDepth = p.Item4;
            pParent.subtrees[pIdx] = SampleTreeRandom(pState, Math.Min(pMaxDepth, MAX_RANDOM_DEPTH));
          }
        }
      } else {
        if(Grammar.subtreeCount[state] == 1) {
          if(searchTree.children == null) {
            int nChildren = Grammar.transition[state].Length;
            searchTree.children = new SearchTreeNode[nChildren];
          }
          Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
          Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
          var altIdx = SelectAlternative(searchTree);
          t = new Tree(altIdx, new Tree[1]);
          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
          SampleTree(searchTree.children[altIdx], extensionPoints);
        } else {
          // multiple subtrees
          var subtrees = new Tree[Grammar.subtreeCount[state]];
          t = new Tree(-1, subtrees);
          for(int i = subtrees.Length - 1; i >= 0; i--) {
            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
          }
          SampleTree(searchTree, extensionPoints);
        }
      }
      Debug.Assert(parent.subtrees[subtreeIdx] == null);
      parent.subtrees[subtreeIdx] = t;
    }

    private int SelectAlternative(SearchTreeNode searchTree) {
      // any alternative not yet explored?
      var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
      if(altIdx >= 0) {
        searchTree.children[altIdx] = new SearchTreeNode();
        return altIdx;
      }
      // sample every unfinished alternative a minimum number of times before relying on UCB
      altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000);
      if(altIdx >= 0) return altIdx;

      // select the unfinished alternative with the largest UCB value;
      // finished sub-trees are skipped because sampling them cannot produce new solutions
      altIdx = -1;
      double bestUcb = double.NegativeInfinity;
      for(int idx = 0; idx < searchTree.children.Length; idx++) {
        var child = searchTree.children[idx];
        if(child.done) continue;
        var ucb = UCB(searchTree, child);
        if(altIdx < 0 || ucb > bestUcb) {
          altIdx = idx;
          bestUcb = ucb;
        }
      }

      searchTree.done = searchTree.children.All(c=>c.done);
      // all alternatives finished: fall back to the first one
      return altIdx >= 0 ? altIdx : 0;
    }

    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
      Debug.Assert(parent.tries >= n.tries);
      Debug.Assert(n.tries > 0);
      return n.sumQuality / n.tries + Math.Sqrt((10 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
    }

    private void UpdateSearchTree(Tree t, double quality) {
      var trees = new Stack<Tree>();
      trees.Push(t);
      UpdateSearchTree(searchTree, trees, quality);
    }

    private void UpdateSearchTree(SearchTreeNode searchTree, Stack<Tree> trees, double quality) {
      if(trees.Count == 0 || searchTree == null) return;
      var t = trees.Pop();
      if(t.altIdx == -1) {
        // for trees with multiple sub-trees
        for(int idx = t.subtrees.Length - 1 ; idx >= 0; idx--) {
          trees.Push(t.subtrees[idx]);
        }
        UpdateSearchTree(searchTree, trees, quality);
      } else {
        searchTree.sumQuality += quality;
        searchTree.tries++;
        if(quality > searchTree.bestQuality)
          searchTree.bestQuality = quality;
        if(t.subtrees != null) {
          Debug.Assert(t.subtrees.Length == 1);
          if(searchTree.children != null) {
            trees.Push(t.subtrees[0]);
            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
          }
        } else {
          if(searchTree.children != null) {
            Debug.Assert(searchTree.children.Length == 1);
            UpdateSearchTree(searchTree.children[0], trees, quality);
          }
        }
      }
    }

    // same as in random search solver (could reuse random search)
    private Tree SampleTreeRandom(int state, int maxDepth) {
      Tree t = null;

      // terminals
      if(Grammar.subtreeCount[state] == 0) {
        t = CreateTerminalNode(state, random, problem);
      } else {
        // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
        if(Grammar.subtreeCount[state] == 1) {
          var targetStates = Grammar.transition[state];
          var altIdx = SampleAlternative(random, state, maxDepth);
          var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
          t = new Tree(altIdx, new Tree[] { alternative });
        } else {
          // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
          Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
          for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
            subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
          }
          t = new Tree(-1, subtrees); // alternative index is ignored
        }
      }
      return t;
    }

    private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
      switch(state) {
        ?CREATETERMINALNODECODE?
        default: { throw new ArgumentException(""Unknown state index"" + state); }
      }
    }

    private int SampleAlternative(Random random, int state, int maxDepth) {
      switch(state) {

?SAMPLEALTERNATIVECODE?

        default: throw new InvalidOperationException();
      }
    }

    public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) {
      if(args.Length > 0 ) throw new ArgumentException(""Arguments ""+args.Aggregate("""", (str, s) => str+ "" "" + s)+"" are not supported "");
      this.problem = problem;
      this.random = new Random();
    }

    public void Start() {
      Console.ReadLine();
      var bestF = ?MAXIMIZATION? ? double.NegativeInfinity : double.PositiveInfinity;
      int n = 0;
      long sumDepth = 0;
      long sumSize = 0;
      var sumF = 0.0;
      var sw = new System.Diagnostics.Stopwatch();
      sw.Start();
      while (!searchTree.done) {

        int steps, depth;
        var _t = SampleTree(maxDepth);
        // _t.PrintTree(0); Console.WriteLine();

        // inefficient but don't care for now
        steps = _t.GetSize();
        depth = _t.GetDepth();
        Debug.Assert(depth <= maxDepth);
        var f = problem.Evaluate(_t);
        if(?MAXIMIZATION?)
          UpdateSearchTree(_t, f);
        else
          UpdateSearchTree(_t, -f);
        n++;
        sumSize += steps;
        sumDepth += depth;
        sumF += f;
        if (problem.IsBetter(f, bestF)) {
          bestF = f;
          _t.PrintTree(0); Console.WriteLine();
          Console.WriteLine(""{0}\t{1}\t(size={2}, depth={3})"", n, bestF, steps, depth);
        }
        if (n % 1000 == 0) {
          sw.Stop();
          Console.WriteLine(""{0}\tbest: {1:0.000}\t(avg: {2:0.000})\t(avg size: {3:0.0})\t(avg. depth: {4:0.0})\t({5:0.00} sols/ms)"", n, bestF, sumF/1000.0, sumSize/1000.0, sumDepth/1000.0, 1000.0 / sw.ElapsedMilliseconds);
          sumSize = 0;
          sumDepth = 0;
          sumF = 0.0;
          sw.Restart();
        }
      }
    }
  }
}";

    /// <summary>
    /// Instantiates the solver template for the given grammar and appends the resulting
    /// solver source to <paramref name="problemSourceCode"/>.
    /// </summary>
    /// <param name="grammar">Grammar whose symbols define the solver states (start symbol first).</param>
    /// <param name="terminals">Terminal definitions; exactly one must match each terminal symbol name.</param>
    /// <param name="maximization">Substituted as the lower-case literal for ?MAXIMIZATION?.</param>
    /// <param name="problemSourceCode">Builder the generated solver source is appended to.</param>
    public void Generate(IGrammar grammar, IEnumerable<TerminalNode> terminals, bool maximization, SourceBuilder problemSourceCode) {
      var solverSourceCode = new SourceBuilder();
      solverSourceCode.Append(solverTemplate)
        .Replace("?MAXIMIZATION?", maximization.ToString().ToLowerInvariant())
        .Replace("?SAMPLEALTERNATIVECODE?", GenerateSampleAlternativeSource(grammar))
        .Replace("?CREATETERMINALNODECODE?", GenerateCreateTerminalCode(grammar, terminals))
        ;

      problemSourceCode.Append(solverSourceCode.ToString());
    }

    /// <summary>
    /// Generates the switch cases of the solver's SampleAlternative method: one case per grammar
    /// symbol, numbered in the order of <c>grammar.Symbols</c>.
    /// </summary>
    private string GenerateSampleAlternativeSource(IGrammar grammar) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      int stateCount = 0;
      foreach (var s in grammar.Symbols) {
        sb.AppendFormat("case {0}: ", stateCount++);
        if (grammar.IsTerminal(s)) {
          // terminals have no alternatives: the empty label is stacked onto the next switch section
          continue;
        }
        // materialize once; the index lists are enumerated several times below
        var alternatives = grammar.GetAlternatives(s).ToArray();
        var terminalAltIndexes = Enumerable.Range(0, alternatives.Length)
          .Where(idx => alternatives[idx].All(symb => grammar.IsTerminal(symb)))
          .ToArray();
        var nonTerminalAltIndexes = Enumerable.Range(0, alternatives.Length)
          .Where(idx => alternatives[idx].Any(symb => grammar.IsNonTerminal(symb)))
          .ToArray();
        if (terminalAltIndexes.Length > 0 && nonTerminalAltIndexes.Length > 0) {
          // at the depth limit only alternatives consisting solely of terminals may be chosen
          sb.Append("if(maxDepth <= 1) {").BeginBlock();
          GenerateReturnStatement(terminalAltIndexes, sb);
          sb.Append("} else {");
          GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
          sb.Append("}").EndBlock();
        } else {
          GenerateReturnStatement(grammar.NumberOfAlternatives(s), sb);
        }
      }
      return sb.ToString();
    }

    /// <summary>
    /// Generates the switch cases of the solver's CreateTerminalNode method: one case per terminal
    /// symbol, using the symbol's position in <c>grammar.Symbols</c> as the state index.
    /// </summary>
    /// <exception cref="InvalidOperationException">No or more than one terminal matches a terminal symbol name.</exception>
    /// <exception cref="NotImplementedException">A terminal has a SET constraint (attributes are not supported yet).</exception>
    /// <exception cref="NotSupportedException">A terminal has any other (RANGE) constraint.</exception>
    private string GenerateCreateTerminalCode(IGrammar grammar, IEnumerable<TerminalNode> terminals) {
      Debug.Assert(grammar.Symbols.First().Equals(grammar.StartSymbol));
      var sb = new SourceBuilder();
      var allSymbols = grammar.Symbols.ToList();
      for (int state = 0; state < allSymbols.Count; state++) {
        var s = allSymbols[state];
        if (!grammar.IsTerminal(s)) continue;

        sb.AppendFormat("case {0}: {{", state).BeginBlock();
        sb.AppendFormat("var t = new {0}Tree();", s.Name).AppendLine();
        var terminal = terminals.Single(t => t.Ident == s.Name);
        foreach (var constr in terminal.Constraints) {
          if (constr.Type == ConstraintNodeType.Set) {
            // supporting SET constraints would require sampling the attribute value from
            // problem.GetAllowed<Terminal>_<Attribute>() in the generated code
            throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
          } else {
            throw new NotSupportedException("The MCTS solver does not support RANGE constraints.");
          }
        }
        sb.AppendLine("return t;").EndBlock();
        sb.Append("}");
      }
      return sb.ToString();
    }

    /// <summary>
    /// Emits a statement returning one of the given alternative indexes, chosen uniformly at random
    /// when there is more than one. Emits a throw statement when <paramref name="idxs"/> is empty.
    /// </summary>
    private void GenerateReturnStatement(IEnumerable<int> idxs, SourceBuilder sb) {
      var idxArr = idxs.ToArray(); // enumerate the (possibly lazy) sequence only once
      if (idxArr.Length == 0) {
        // an empty index array would fail at run time with an IndexOutOfRangeException
        sb.AppendLine("throw new InvalidProgramException();");
      } else if (idxArr.Length == 1) {
        sb.AppendFormat("return {0};", idxArr[0]).AppendLine();
      } else {
        sb.AppendFormat("return new int[] {{ {0} }}[random.Next({1})]; ", string.Join(", ", idxArr), idxArr.Length).AppendLine();
      }
    }

    /// <summary>
    /// Emits a statement returning a uniformly random alternative index in [0, nAlts),
    /// or a throw statement when there are no alternatives.
    /// </summary>
    private void GenerateReturnStatement(int nAlts, SourceBuilder sb) {
      if (nAlts > 1) {
        sb.AppendFormat("return random.Next({0});", nAlts).AppendLine();
      } else if (nAlts == 1) {
        sb.AppendLine("return 0; ");
      } else {
        sb.AppendLine("throw new InvalidProgramException();");
      }
    }
  }
}
|
---|