1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Drawing;
|
---|
5 | using System.Linq;
|
---|
6 | using System.Net.Mail;
|
---|
7 | using System.Text;
|
---|
8 | using System.Threading;
|
---|
9 | using GraphVizWrapper;
|
---|
10 | using GraphVizWrapper.Commands;
|
---|
11 | using GraphVizWrapper.Queries;
|
---|
12 | using HeuristicLab.Algorithms.Bandits;
|
---|
13 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
15 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
16 | using HeuristicLab.Common;
|
---|
17 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
18 |
|
---|
19 | namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
|
---|
20 | {
|
---|
/// <summary>
/// Monte Carlo tree search solver for an <see cref="ISymbolicExpressionTreeProblem"/>.
/// Each tree node holds a phrase of the problem's grammar; children are created by replacing
/// one non-terminal symbol of the parent phrase with one of its alternatives. In every
/// iteration a node is selected with the bandit policy, expanded on its first visit,
/// evaluated through the simulation, and the resulting quality is propagated back to the root.
/// </summary>
public class MonteCarloTreeSearch : SolverBase
{
    // Trees with this many nodes or more are not rendered by GenerateSvg.
    private const int MaxSvgNodes = 1000;

    // Relative location of the Graphviz binaries handed to the GraphViz wrapper.
    // NOTE(review): presumably resolved against the working directory - confirm for deployments.
    private const string GraphvizBinPath = @"../../../Graphviz2.38/bin";

    // DOT declaration of one node: id, label "phrase\nvalue/tries", and fill colour.
    private const string DotNodeFormat =
        "{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]";

    protected readonly int maxLen;
    protected readonly ISymbolicExpressionTreeProblem problem;
    protected readonly IGrammar grammar;
    protected readonly Random random;
    protected readonly IBanditPolicy behaviourPolicy;
    protected readonly ISimulation simulation;
    protected TreeNode rootNode;
    protected int goodSelections;

    /// <summary>
    /// Creates the solver and asks the problem to generate its solutions for <paramref name="maxLen"/>.
    /// </summary>
    /// <param name="problem">Problem to solve; its grammar drives the tree expansion.</param>
    /// <param name="maxLen">Maximum phrase length considered during expansion and simulation.</param>
    /// <param name="random">Random source handed to the bandit policy.</param>
    /// <param name="behaviourPolicy">Policy used to select child nodes and to create their action infos.</param>
    /// <param name="simulationPolicy">Simulation used to evaluate a selected node.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
    public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random,
        IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
    {
        // Fail with a descriptive exception instead of a NullReferenceException on problem.Grammar.
        if (problem == null)
        {
            throw new ArgumentNullException("problem");
        }

        this.problem = problem;
        this.grammar = problem.Grammar;
        this.maxLen = maxLen;
        this.random = random;
        this.behaviourPolicy = behaviourPolicy;
        this.simulation = simulationPolicy;

        this.problem.GenerateProblemSolutions(maxLen);
    }

    /// <summary>
    /// Raised once per evaluated iteration of <see cref="Run"/>. The first argument is the current
    /// <see cref="GoodSelections"/> count, the second the total number of selections made so far.
    /// </summary>
    public event Action<int, int> IterationFinished;

    /// <summary>
    /// Raises <see cref="IterationFinished"/>; overridable hook for derived solvers.
    /// </summary>
    /// <param name="value">Number of good selections so far.</param>
    /// <param name="selections">Total number of selections so far.</param>
    protected virtual void OnIterationFinishedChanged(int value, int selections)
    {
        RaiseIterationFinishedChanged(value, selections);
    }

    // Copies the delegate to a local first so a concurrent unsubscribe cannot null it between
    // the check and the invocation.
    private void RaiseIterationFinishedChanged(int value, int selections)
    {
        var handler = IterationFinished;
        if (handler != null)
        {
            handler(value, selections);
        }
    }

    /// <summary>
    /// When set to true, <see cref="Run"/> stops before starting its next iteration.
    /// <see cref="Run"/> clears the flag when it starts, so setting it beforehand has no effect.
    /// </summary>
    public bool StopRequested { get; set; }

    /// <summary>
    /// Number of selections for which the problem reported the selected phrase as a parent of a
    /// problem solution (see <see cref="CheckSelection"/>).
    /// </summary>
    public int GoodSelections { get { return this.goodSelections; } }

    /// <summary>
    /// Increments <see cref="goodSelections"/> when the problem reports the phrase of
    /// <paramref name="node"/> as a parent of a problem solution for the configured maximum length.
    /// </summary>
    /// <param name="node">Node that has just been selected.</param>
    /// <param name="selections">Total number of selections so far; currently unused.</param>
    protected void CheckSelection(TreeNode node, int selections)
    {
        if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
        {
            goodSelections++;
        }
    }

    /// <summary>
    /// Resets the solver and performs up to <paramref name="maxIterations"/> search iterations,
    /// stopping early when <see cref="StopRequested"/> becomes true.
    /// </summary>
    /// <param name="maxIterations">Upper bound on the number of iterations.</param>
    public override void Run(int maxIterations)
    {
        Reset();
        int selections = 0;

        for (int i = 0; !StopRequested && i < maxIterations; i++)
        {
            TreeNode currentNode = rootNode;

            // Selection: descend with the bandit policy until a leaf is reached. The extra
            // has-children check stops at expanded nodes that produced no child, where the
            // policy would otherwise be asked to choose from an empty set.
            while (!currentNode.IsLeaf() && HasChildren(currentNode))
            {
                currentNode = SelectChild(currentNode);
                selections++;
                CheckSelection(currentNode, selections);
            }

            // Expansion: a non-terminal phrase within the length limit gets its children
            // created (on the first visit) and one of them is selected.
            string phrase = currentNode.phrase;
            if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
            {
                ExpandTreeNode(currentNode);

                // Every alternative may exceed maxLen, leaving the node without children.
                if (HasChildren(currentNode))
                {
                    currentNode = SelectChild(currentNode);
                    selections++;
                    CheckSelection(currentNode, selections);
                }
            }

            // Simulation and back-propagation; phrases above the length limit are skipped.
            if (currentNode.phrase.Length <= maxLen)
            {
                string simulatedPhrase;
                double quality = simulation.Simulate(currentNode, out simulatedPhrase);
                OnSolutionEvaluated(simulatedPhrase, quality);

                OnIterationFinishedChanged(goodSelections, selections);
                Propagate(currentNode, quality);
            }
        }
    }

    // True when the node has been expanded and at least one child was created.
    private static bool HasChildren(TreeNode node)
    {
        return node.children != null && node.children.Length > 0;
    }

    // Lets the bandit policy choose among the children of the node and returns the chosen child.
    private TreeNode SelectChild(TreeNode node)
    {
        int actionIndex = behaviourPolicy.SelectAction(random, node.GetChildActionInfos());
        return node.children[actionIndex];
    }

    /// <summary>
    /// Creates the children of <paramref name="treeNode"/> on its first visit: one child for every
    /// alternative of every non-terminal symbol in the node's phrase, provided the resulting phrase
    /// is not longer than the maximum length. Does nothing when the node already has a children array.
    /// </summary>
    /// <param name="treeNode">Node to expand.</param>
    protected void ExpandTreeNode(TreeNode treeNode)
    {
        if (treeNode.children != null)
        {
            return; // already expanded
        }

        var phrase = new Sequence(treeNode.phrase);
        var newChildren = new List<TreeNode>();
        ushort childLevel = (ushort)(treeNode.level + 1);

        for (int i = 0; i < phrase.Length; i++)
        {
            char symbol = phrase[i];
            if (!grammar.IsNonTerminal(symbol))
            {
                continue;
            }

            foreach (Sequence alternative in grammar.GetAlternatives(symbol))
            {
                // Replace the single symbol at position i with the alternative.
                var candidate = new Sequence(phrase);
                candidate.ReplaceAt(i, 1, alternative);
                if (candidate.Length > maxLen)
                {
                    continue;
                }

                newChildren.Add(new TreeNode(treeNode, candidate.ToString(),
                    behaviourPolicy.CreateActionInfo(), childLevel));
            }
        }

        treeNode.children = newChildren.ToArray();
    }

    /// <summary>
    /// Clears the counters and the stop flag, resets the best quality to zero and replaces the
    /// tree with a fresh root node holding the grammar's sentence symbol.
    /// </summary>
    protected void Reset()
    {
        goodSelections = 0;
        StopRequested = false;
        bestQuality = 0.0;
        rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
    }

    /// <summary>
    /// Updates the action info of <paramref name="node"/> and of each of its ancestors, up to and
    /// including the root, with <paramref name="quality"/>.
    /// </summary>
    /// <param name="node">Node the simulation started from.</param>
    /// <param name="quality">Quality returned by the simulation.</param>
    protected void Propagate(TreeNode node, double quality)
    {
        for (TreeNode current = node; current != null; current = current.parent)
        {
            current.actionInfo.UpdateReward(quality);
        }
    }

    // Recursively accumulates the statistics of the given children (and their descendants)
    // into treeInfos.
    private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children)
    {
        treeInfos.TotalNodes += children.Length;
        foreach (TreeNode child in children)
        {
            if (child.children != null)
            {
                treeInfos.ExpandedNodes++;
                // An expanded node implies a level below it.
                if (treeInfos.DeepestLevel <= child.level)
                {
                    treeInfos.DeepestLevel = child.level + 1;
                }
                GetTreeInfosRek(treeInfos, child.children);
            }
            else if (grammar.IsTerminal(child.phrase))
            {
                treeInfos.LeaveNodes++;
            }
            else
            {
                treeInfos.UnexpandedNodes++;
            }
        }
    }

    /// <summary>
    /// Collects statistics about the current search tree: total, expanded, unexpanded and
    /// terminal (leave) node counts and the deepest level.
    /// </summary>
    /// <returns>The collected statistics; all counters stay untouched when there is no tree.</returns>
    public TreeInfos GetTreeInfos()
    {
        TreeInfos treeInfos = new TreeInfos();
        if (rootNode == null)
        {
            return treeInfos;
        }

        treeInfos.TotalNodes++;
        if (rootNode.children != null)
        {
            treeInfos.ExpandedNodes++;
            treeInfos.DeepestLevel = rootNode.level + 1;
            GetTreeInfosRek(treeInfos, rootNode.children);
        }
        else
        {
            treeInfos.DeepestLevel = rootNode.level;
            if (grammar.IsTerminal(rootNode.phrase))
            {
                treeInfos.LeaveNodes++;
            }
            else
            {
                treeInfos.UnexpandedNodes++;
            }
        }
        return treeInfos;
    }

    /// <summary>
    /// Renders the current search tree as SVG through the GraphViz wrapper.
    /// </summary>
    /// <returns>
    /// The bytes returned by the wrapper, or null when the tree has 1000 nodes or more.
    /// </returns>
    public byte[] GenerateSvg()
    {
        if (GetTreeInfos().TotalNodes >= MaxSvgNodes)
        {
            return null;
        }

        IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
        IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
        IRegisterLayoutPluginCommand registerLayoutPluginCommand =
            new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

        GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
            getProcessStartInfoQuery,
            registerLayoutPluginCommand);
        wrapper.GraphvizPath = GraphvizBinPath;

        return wrapper.GenerateGraph(BuildDotGraph(), Enums.GraphReturnType.Svg);
    }

    // Builds the DOT description of the tree with a breadth-first walk. Nodes get sequential
    // ids assigned during the walk; hash codes are not guaranteed to be unique and could
    // merge distinct nodes in the graph.
    private string BuildDotGraph()
    {
        StringBuilder dot = new StringBuilder("digraph {");
        dot.AppendLine();
        dot.AppendLine("splines=ortho;");
        dot.AppendLine("concentrate=true;");
        dot.AppendLine("ranksep=1.2;");

        // Node ids per tree level, used to put nodes of one level on the same rank.
        Dictionary<int, List<int>> idsByLevel = new Dictionary<int, List<int>>();
        Queue<KeyValuePair<TreeNode, int>> pending = new Queue<KeyValuePair<TreeNode, int>>();
        int nextId = 0;

        if (rootNode != null)
        {
            int rootId = nextId++;
            AppendDotNode(dot, rootId, rootNode);
            pending.Enqueue(new KeyValuePair<TreeNode, int>(rootNode, rootId));
        }

        while (pending.Count > 0)
        {
            KeyValuePair<TreeNode, int> entry = pending.Dequeue();
            TreeNode current = entry.Key;
            int currentId = entry.Value;

            List<int> sameLevelIds;
            if (!idsByLevel.TryGetValue(current.level, out sameLevelIds))
            {
                sameLevelIds = new List<int>();
                idsByLevel.Add(current.level, sameLevelIds);
            }
            sameLevelIds.Add(currentId);

            if (current.children == null)
            {
                continue;
            }

            // Declare every child and draw the edge from the current node to it.
            foreach (TreeNode child in current.children)
            {
                int childId = nextId++;
                AppendDotNode(dot, childId, child);
                dot.AppendLine(string.Format("{0} -> {1}", currentId, childId));
                pending.Enqueue(new KeyValuePair<TreeNode, int>(child, childId));
            }
        }

        foreach (KeyValuePair<int, List<int>> level in idsByLevel)
        {
            dot.Append("{rank = same;");
            foreach (int id in level.Value)
            {
                dot.Append(string.Format(" {0};", id));
            }
            dot.AppendLine("}");
        }

        dot.Append(" }");
        return dot.ToString();
    }

    // Appends the DOT declaration of one node: label with phrase, value and tries, filled with
    // a colour derived from the node's value.
    private void AppendDotNode(StringBuilder dot, int nodeId, TreeNode node)
    {
        string fillColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
        dot.AppendLine(string.Format(DotNodeFormat, nodeId, EscapeDotLabel(node.phrase),
            node.actionInfo.Value, node.actionInfo.Tries, fillColor));
    }

    // Escapes backslashes and double quotes so the text cannot terminate the quoted DOT label.
    private static string EscapeDotLabel(string text)
    {
        if (string.IsNullOrEmpty(text))
        {
            return string.Empty;
        }
        return text.Replace("\\", "\\\\").Replace("\"", "\\\"");
    }

    /// <summary>
    /// Formats the red, green and blue components of <paramref name="c"/> as "#RRGGBB"
    /// with upper-case hexadecimal digits.
    /// </summary>
    protected String HexConverter(Color c)
    {
        return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
    }

    /// <summary>
    /// Interpolates between <paramref name="weakColor"/> and <paramref name="strongColor"/>
    /// according to <paramref name="quality"/> relative to the problem's best known quality and
    /// returns the result as "#RRGGBB".
    /// </summary>
    /// <param name="weakColor">Colour used for a relative quality of 0.</param>
    /// <param name="strongColor">Colour used for a relative quality of 1.</param>
    /// <param name="quality">Quality to visualise.</param>
    /// <returns>Hex colour string for the interpolated colour.</returns>
    protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
    {
        double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
        double q = quality / bestKnownQuality;

        // Keep the ratio inside [0, 1]: values outside (or NaN from 0/0) would push a channel
        // out of the byte range and make Convert.ToByte throw.
        if (double.IsNaN(q) || q < 0.0)
        {
            q = 0.0;
        }
        else if (q > 1.0)
        {
            q = 1.0;
        }

        byte newR = InterpolateChannel(weakColor.R, strongColor.R, q);
        byte newG = InterpolateChannel(weakColor.G, strongColor.G, q);
        byte newB = InterpolateChannel(weakColor.B, strongColor.B, q);

        return HexConverter(Color.FromArgb(newR, newG, newB));
    }

    // Moves one colour channel from weak towards strong by the fraction q (expected in [0, 1]).
    private static byte InterpolateChannel(byte weak, byte strong, double q)
    {
        double delta = Math.Round(Math.Abs(weak - strong) * q);
        return weak > strong
            ? Convert.ToByte(weak - delta)
            : Convert.ToByte(weak + delta);
    }

    /// <summary>
    /// Drops the reference to the search tree and forces a garbage collection.
    /// </summary>
    public void FreeAll()
    {
        rootNode = null;
        // Explicit collection kept from the original implementation; callers may rely on the
        // tree's memory being reclaimed right away.
        GC.Collect();
    }
}
|
---|
351 | }
|
---|