1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Drawing;
|
---|
4 | using System.Linq;
|
---|
5 | using System.Net.Mail;
|
---|
6 | using System.Text;
|
---|
7 | using System.Threading;
|
---|
8 | using GraphVizWrapper;
|
---|
9 | using GraphVizWrapper.Commands;
|
---|
10 | using GraphVizWrapper.Queries;
|
---|
11 | using HeuristicLab.Algorithms.Bandits;
|
---|
12 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
13 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
15 | using HeuristicLab.Common;
|
---|
16 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
17 |
|
---|
18 | namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
|
---|
19 | {
|
---|
20 | public class MonteCarloTreeSearch : SolverBase
|
---|
21 | {
|
---|
// Upper bound on the length of the phrases the search may build.
protected readonly int maxLen;

// Problem being solved; also asked for the best known quality when colouring nodes.
protected readonly IProblem problem;

// Grammar of the problem (taken from problem.Grammar in the constructor).
protected readonly IGrammar grammar;

// Random source handed to the bandit policy whenever an action is selected.
protected readonly Random random;

// Bandit policy that selects child actions and creates the per-node action infos.
protected readonly IBanditPolicy behaviourPolicy;

// Simulation strategy used to evaluate a tree node.
protected readonly ISimulation simulation;

// Root of the search tree; (re)created by Reset().
protected TreeNode rootNode;

// True while the search is paused; toggled by PauseContinue() under pauseLock.
protected bool isPaused;

// Monitor object on which Run() waits while the search is paused.
protected object pauseLock = new object();
31 |
|
---|
/// <summary>
/// Creates a tree search solver for the given problem.
/// </summary>
/// <param name="problem">Problem to solve; its Grammar is cached in <c>grammar</c>.</param>
/// <param name="maxLen">Maximum phrase length used during expansion and simulation.</param>
/// <param name="random">Random source passed to the bandit policy.</param>
/// <param name="behaviourPolicy">Policy used to select child nodes and to create action infos.</param>
/// <param name="simulationPolicy">Simulation strategy stored in <c>simulation</c>.</param>
public MonteCarloTreeSearch(
    IProblem problem,
    int maxLen,
    Random random,
    IBanditPolicy behaviourPolicy,
    ISimulation simulationPolicy)
{
    this.problem = problem;
    grammar = problem.Grammar;
    this.maxLen = maxLen;
    this.random = random;
    this.behaviourPolicy = behaviourPolicy;
    simulation = simulationPolicy;
}
42 |
|
---|
/// <summary>
/// When set to true, Run() leaves its loop before starting the next iteration.
/// Reset() sets it back to false.
/// </summary>
public bool StopRequested { get; set; }

/// <summary>
/// Gets whether the search is currently paused (see PauseContinue()).
/// </summary>
public bool IsPaused
{
    get
    {
        return isPaused;
    }
}
49 |
|
---|
/// <summary>
/// Runs the tree search: resets the solver state, then repeats
/// selection, expansion, simulation and reward propagation until
/// <paramref name="maxIterations"/> iterations were done or
/// <see cref="StopRequested"/> is true.
/// </summary>
/// <param name="maxIterations">Maximum number of iterations to perform.</param>
public override void Run(int maxIterations)
{
    Reset();
    for (int i = 0; !StopRequested && i < maxIterations; i++)
    {
        lock (pauseLock)
        {
            // Re-check the flag after every wake-up: PauseContinue() may have
            // paused the search again before this thread re-acquired the lock,
            // in which case a single "if" would let the search run while paused.
            while (isPaused)
            {
                Monitor.Wait(pauseLock);
            }
        }

        // Selection: follow the behaviour policy down to a leaf of the tree.
        TreeNode currentNode = rootNode;
        while (!currentNode.IsLeaf())
        {
            int currentActionIndex = behaviourPolicy.SelectAction(random,
                currentNode.GetChildActionInfos());
            currentNode = currentNode.children[currentActionIndex];
        }

        string phrase = currentNode.phrase;

        // Expansion: a leaf that still contains non-terminals and is within the
        // length limit gets its children created, and one of them is selected.
        // NOTE(review): ExpandTreeNode only adds children whose phrase fits into
        // maxLen, so the child list may be empty here - confirm that SelectAction
        // copes with an empty set of action infos.
        if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
        {
            ExpandTreeNode(currentNode);
            int expandedActionIndex = behaviourPolicy.SelectAction(random,
                currentNode.GetChildActionInfos());
            currentNode = currentNode.children[expandedActionIndex];
        }

        // Simulation and back-propagation, only for phrases within the limit.
        if (currentNode.phrase.Length <= maxLen)
        {
            string simulatedPhrase;
            double quality = simulation.Simulate(currentNode, out simulatedPhrase);
            OnSolutionEvaluated(simulatedPhrase, quality);

            Propagate(currentNode, quality);
        }
    }
}
90 |
|
---|
/// <summary>
/// Creates the children of <paramref name="treeNode"/> the first time it is
/// visited: one child per alternative of every non-terminal symbol in the
/// node's phrase, skipping phrases that would be longer than <c>maxLen</c>.
/// Does nothing when the node already has a child list.
/// </summary>
/// <param name="treeNode">Node to expand.</param>
protected void ExpandTreeNode(TreeNode treeNode)
{
    // already expanded on an earlier visit
    if (treeNode.children != null)
    {
        return;
    }

    treeNode.children = new List<TreeNode>();
    var parentPhrase = new Sequence(treeNode.phrase);

    for (int position = 0; position < parentPhrase.Length; position++)
    {
        char candidate = parentPhrase[position];
        if (!grammar.IsNonTerminal(candidate))
        {
            continue;
        }

        // one child for every alternative that can replace the non-terminal
        foreach (Sequence replacement in grammar.GetAlternatives(candidate))
        {
            Sequence derived = new Sequence(parentPhrase);
            derived.ReplaceAt(position, 1, replacement);
            if (derived.Length > maxLen)
            {
                continue;
            }

            treeNode.children.Add(
                new TreeNode(treeNode, derived.ToString(), behaviourPolicy.CreateActionInfo(),
                    treeNode.level + 1));
        }
    }
}
121 |
|
---|
/// <summary>
/// Puts the solver back into its initial state: clears the stop request,
/// sets <c>bestQuality</c> to zero and replaces the tree by a fresh root
/// node (level 0, no parent) holding the grammar's sentence symbol.
/// </summary>
protected void Reset()
{
    StopRequested = false;
    bestQuality = 0.0;

    string startPhrase = grammar.SentenceSymbol.ToString();
    rootNode = new TreeNode(null, startPhrase, behaviourPolicy.CreateActionInfo(), 0);
}
128 |
|
---|
/// <summary>
/// Reports <paramref name="quality"/> as reward to <paramref name="node"/>
/// and to each of its ancestors, following the parent links up to the root.
/// </summary>
/// <param name="node">Node the simulation started from.</param>
/// <param name="quality">Reward to pass to every action info on the path.</param>
protected void Propagate(TreeNode node, double quality)
{
    TreeNode visited = node;
    while (true)
    {
        visited.actionInfo.UpdateReward(quality);

        visited = visited.parent;
        if (visited == null)
        {
            // the node just updated had no parent, so the path is complete
            break;
        }
    }
}
138 |
|
---|
/// <summary>
/// Recursively adds the statistics of <paramref name="children"/> and of
/// all nodes below them to <paramref name="treeInfos"/>.
/// </summary>
/// <param name="treeInfos">Accumulator that is updated in place.</param>
/// <param name="children">Child list of an expanded node.</param>
private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
{
    treeInfos.TotalNodes += children.Count;

    foreach (TreeNode node in children)
    {
        if (node.children == null)
        {
            // no child list: either a finished (terminal) phrase or a node not expanded yet
            if (grammar.IsTerminal(node.phrase))
            {
                treeInfos.LeaveNodes++;
            }
            else
            {
                treeInfos.UnexpandedNodes++;
            }
            continue;
        }

        treeInfos.ExpandedNodes++;
        // an expanded node accounts for one more level than its own
        if (treeInfos.DeepestLevel <= node.level)
        {
            treeInfos.DeepestLevel = node.level + 1;
        }
        GetTreeInfosRek(treeInfos, node.children);
    }
}
166 |
|
---|
/// <summary>
/// Collects statistics about the current search tree: total, expanded,
/// unexpanded and terminal node counts and the deepest level.
/// </summary>
/// <returns>
/// The collected statistics; a freshly constructed <c>TreeInfos</c> when no
/// root node exists yet.
/// </returns>
public TreeInfos GetTreeInfos()
{
    TreeInfos summary = new TreeInfos();
    if (rootNode == null)
    {
        return summary;
    }

    summary.TotalNodes++;

    if (rootNode.children == null)
    {
        // the tree consists of the root only
        summary.DeepestLevel = rootNode.level;
        if (grammar.IsTerminal(rootNode.phrase))
        {
            summary.LeaveNodes++;
        }
        else
        {
            summary.UnexpandedNodes++;
        }
        return summary;
    }

    summary.ExpandedNodes++;
    summary.DeepestLevel = rootNode.level + 1;
    GetTreeInfosRek(summary, rootNode.children);
    return summary;
}
195 |
|
---|
/// <summary>
/// Renders the current search tree as SVG: builds a Graphviz DOT description
/// (one filled node per tree node, labelled with phrase, value and tries, one
/// edge per parent/child link, nodes of the same level on the same rank) and
/// passes it to the GraphViz wrapper.
/// </summary>
/// <returns>
/// The bytes returned by the wrapper, or null when the tree has 1000 or more
/// nodes.
/// </returns>
public byte[] GenerateSvg()
{
    // trees with 1000 or more nodes are not rendered
    if (GetTreeInfos().TotalNodes >= 1000)
    {
        return null;
    }

    IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
    IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
    IRegisterLayoutPluginCommand registerLayoutPluginCommand =
        new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

    GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
        getProcessStartInfoQuery,
        registerLayoutPluginCommand);
    // NOTE(review): relative path - presumably resolved against the working
    // directory of the process; confirm for deployed builds.
    wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";

    StringBuilder dotFile = new StringBuilder("digraph {");
    dotFile.AppendLine();
    dotFile.AppendLine("splines=ortho;");
    dotFile.AppendLine("concentrate=true;");
    dotFile.AppendLine("ranksep=1.2;");

    // Every tree node gets its own sequential id. Hash codes are not
    // guaranteed to be unique, so using them as DOT node names could make two
    // different tree nodes collapse into one graph node.
    int nextNodeId = 0;

    // breadth-first work list of (id, node) pairs
    Queue<KeyValuePair<int, TreeNode>> toDoNodes = new Queue<KeyValuePair<int, TreeNode>>();

    // ids grouped by tree level, used to put nodes of one level on the same rank
    Dictionary<int, List<int>> levelMap = new Dictionary<int, List<int>>();

    if (rootNode != null)
    {
        int rootId = nextNodeId++;
        AppendNodeDeclaration(dotFile, rootId, rootNode);
        toDoNodes.Enqueue(new KeyValuePair<int, TreeNode>(rootId, rootNode));
    }

    while (toDoNodes.Count > 0)
    {
        KeyValuePair<int, TreeNode> entry = toDoNodes.Dequeue();
        int currentId = entry.Key;
        TreeNode currentNode = entry.Value;

        // put currentNode into levelMap
        List<int> sameLevelIds;
        if (!levelMap.TryGetValue(currentNode.level, out sameLevelIds))
        {
            sameLevelIds = new List<int>();
            levelMap.Add(currentNode.level, sameLevelIds);
        }
        sameLevelIds.Add(currentId);

        // declare all children and draw an edge from the current node to each
        if (currentNode.children != null)
        {
            foreach (TreeNode childNode in currentNode.children)
            {
                int childId = nextNodeId++;
                toDoNodes.Enqueue(new KeyValuePair<int, TreeNode>(childId, childNode));
                AppendNodeDeclaration(dotFile, childId, childNode);
                dotFile.AppendLine(string.Format("{0} -> {1}", currentId, childId));
            }
        }
    }

    // set same level ranks
    foreach (KeyValuePair<int, List<int>> level in levelMap)
    {
        dotFile.Append("{rank = same;");
        foreach (int nodeId in level.Value)
        {
            dotFile.Append(string.Format(" {0};", nodeId));
        }
        dotFile.AppendLine("}");
    }

    dotFile.Append(" }");
    return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
}

/// <summary>
/// Appends the DOT declaration of one tree node: label "phrase\nvalue/tries"
/// and a fill colour derived from the node's value.
/// </summary>
private void AppendNodeDeclaration(StringBuilder dotFile, int nodeId, TreeNode node)
{
    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);

    // The label is a double-quoted DOT string: escape backslashes and quotes
    // so that a phrase containing them cannot break the generated file.
    string label = node.phrase.Replace("\\", "\\\\").Replace("\"", "\\\"");

    dotFile.AppendLine(
        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
            nodeId,
            label, node.actionInfo.Value, node.actionInfo.Tries, hexColor));
}
286 |
|
---|
/// <summary>
/// Formats the red, green and blue components of <paramref name="c"/> as an
/// HTML-style colour string "#RRGGBB" (two upper-case hex digits each).
/// </summary>
/// <param name="c">Colour to convert; its alpha component is not used.</param>
/// <returns>The colour as "#RRGGBB".</returns>
protected String HexConverter(Color c)
{
    return string.Format("#{0:X2}{1:X2}{2:X2}", c.R, c.G, c.B);
}
291 |
|
---|
/// <summary>
/// Maps a quality to a colour between <paramref name="weakColor"/> (quality
/// ratio 0) and <paramref name="strongColor"/> (quality ratio 1) and returns
/// it as "#RRGGBB". The ratio is the quality divided by the best known
/// quality of the problem for <c>maxLen</c>.
/// </summary>
/// <param name="weakColor">Colour used for a ratio of 0.</param>
/// <param name="strongColor">Colour used for a ratio of 1.</param>
/// <param name="quality">Quality (node value) to visualise.</param>
/// <returns>The interpolated colour as hex string.</returns>
protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
{
    // convert quality to value between 0 and 1
    double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
    double q = quality / bestKnownQuality;

    // The raw ratio can leave [0, 1] (negative quality, quality above the best
    // known value, division by zero). Convert.ToByte would then throw an
    // OverflowException, so the ratio is clamped; NaN is treated as 0.
    if (double.IsNaN(q))
    {
        q = 0.0;
    }
    q = Math.Max(0.0, Math.Min(1.0, q));

    byte newR = BlendChannel(weakColor.R, strongColor.R, q);
    byte newB = BlendChannel(weakColor.B, strongColor.B, q);
    byte newG = BlendChannel(weakColor.G, strongColor.G, q);

    return HexConverter(Color.FromArgb(newR, newG, newB));
}

/// <summary>
/// Moves one colour channel from <paramref name="weak"/> towards
/// <paramref name="strong"/> by the fraction <paramref name="q"/> (expected in [0, 1]).
/// </summary>
private static byte BlendChannel(byte weak, byte strong, double q)
{
    double delta = Math.Round(Math.Abs(weak - strong) * q);
    return weak > strong
        ? Convert.ToByte(weak - delta)
        : Convert.ToByte(weak + delta);
}
317 |
|
---|
/// <summary>
/// Toggles the paused state. When the search was paused it is resumed and a
/// thread waiting on <c>pauseLock</c> is pulsed; otherwise the search is
/// marked as paused.
/// </summary>
public void PauseContinue()
{
    lock (pauseLock)
    {
        bool resuming = isPaused;
        isPaused = !resuming;

        if (resuming)
        {
            // wake the thread blocked in Monitor.Wait(pauseLock)
            Monitor.Pulse(pauseLock);
        }
    }
}
333 | }
|
---|
334 | }
|
---|