1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Diagnostics;
|
---|
4 | using System.Drawing;
|
---|
5 | using System.Linq;
|
---|
6 | using System.Net.Mail;
|
---|
7 | using System.Text;
|
---|
8 | using System.Threading;
|
---|
9 | using GraphVizWrapper;
|
---|
10 | using GraphVizWrapper.Commands;
|
---|
11 | using GraphVizWrapper.Queries;
|
---|
12 | using HeuristicLab.Algorithms.Bandits;
|
---|
13 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
15 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
16 | using HeuristicLab.Common;
|
---|
17 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
18 |
|
---|
namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
{
    /// <summary>
    /// Monte Carlo tree search over phrases of a grammar. Each tree node holds a phrase; a node's children
    /// are created by replacing one non-terminal of its phrase with one of that symbol's alternatives
    /// (only phrases no longer than maxLen are kept). Each iteration selects a leaf with the behaviour
    /// policy, expands it, evaluates the reached node with an <see cref="ISimulation"/> and propagates
    /// the resulting quality back up to the root.
    /// </summary>
    public class MonteCarloTreeSearch : SolverBase
    {
        // Trees with at least this many nodes are not rendered by GenerateSvg (it returns null instead).
        private const int MaxSvgNodes = 1000;

        protected readonly int maxLen;
        protected readonly ISymbolicExpressionTreeProblem problem;
        protected readonly IGrammar grammar;
        protected readonly Random random;
        protected readonly IBanditPolicy behaviourPolicy;
        protected readonly ISimulation simulation;
        protected TreeNode rootNode;
        protected int goodSelections;

        // Directory of the Graphviz binaries used by GenerateSvg; default kept from the original hard-coded value.
        private string graphvizPath = @"../../../Graphviz2.38/bin";

        /// <summary>
        /// Creates the search and asks <paramref name="problem"/> to generate its problem solutions
        /// up to length <paramref name="maxLen"/>.
        /// </summary>
        /// <param name="problem">Problem that supplies the grammar, solution checks and best known quality.</param>
        /// <param name="maxLen">Maximum phrase length; derivations longer than this are never added to the tree.</param>
        /// <param name="random">Random source handed to the behaviour policy when selecting actions.</param>
        /// <param name="behaviourPolicy">Bandit policy used to pick a child during selection and to create action infos.</param>
        /// <param name="simulationPolicy">Simulation used to evaluate the node reached in each iteration.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="problem"/>, <paramref name="behaviourPolicy"/> or <paramref name="simulationPolicy"/> is null.
        /// </exception>
        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
            ISimulation simulationPolicy)
        {
            if (problem == null) throw new ArgumentNullException("problem");
            if (behaviourPolicy == null) throw new ArgumentNullException("behaviourPolicy");
            if (simulationPolicy == null) throw new ArgumentNullException("simulationPolicy");

            this.problem = problem;
            this.grammar = problem.Grammar;
            this.maxLen = maxLen;
            this.random = random;
            this.behaviourPolicy = behaviourPolicy;
            this.simulation = simulationPolicy;

            this.problem.GenerateProblemSolutions(maxLen);
        }

        /// <summary>
        /// Raised by <see cref="CheckSelection"/> with the current good-selection count and the
        /// selection count passed to it.
        /// </summary>
        public event Action<int, int> NodeSelected;

        /// <summary>Hook for subclasses; by default forwards to the <see cref="NodeSelected"/> event.</summary>
        protected virtual void OnNodeSelectedChanged(int value, int selections)
        {
            RaiseNodeSelectedChanged(value, selections);
        }

        private void RaiseNodeSelectedChanged(int value, int selections)
        {
            // Copy to a local so the null check and the invocation see the same delegate.
            var handler = NodeSelected;
            if (handler != null) handler(value, selections);
        }

        /// <summary>When set to true, <see cref="Run"/> stops before starting its next iteration.</summary>
        public bool StopRequested { get; set; }

        /// <summary>Number of selections counted as good by <see cref="CheckSelection"/> since the last reset.</summary>
        public int GoodSelections { get { return this.goodSelections; } }

        /// <summary>
        /// Directory containing the Graphviz binaries used by <see cref="GenerateSvg"/>.
        /// Defaults to "../../../Graphviz2.38/bin".
        /// </summary>
        public string GraphvizPath
        {
            get { return graphvizPath; }
            set { graphvizPath = value; }
        }

        /// <summary>
        /// Counts <paramref name="node"/> as a good selection when the problem reports its phrase as a
        /// parent of a problem solution, then raises the selection notification.
        /// </summary>
        /// <param name="node">The node that was just selected.</param>
        /// <param name="selections">Total number of selections made so far (supplied by the caller).</param>
        protected void CheckSelection(TreeNode node, int selections)
        {
            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
            {
                goodSelections++;
            }

            OnNodeSelectedChanged(goodSelections, selections);
        }

        /// <summary>
        /// Resets the tree and runs up to <paramref name="maxIterations"/> iterations of
        /// selection, expansion, simulation and back-propagation, or until <see cref="StopRequested"/> is set.
        /// </summary>
        /// <param name="maxIterations">Maximum number of iterations to perform.</param>
        public override void Run(int maxIterations)
        {
            Reset();

            for (int i = 0; !StopRequested && i < maxIterations; i++)
            {
                // Selection: descend through already expanded nodes using the behaviour policy.
                // (For diagnostics, CheckSelection(node, count) can be called after each step.)
                TreeNode currentNode = rootNode;
                while (!currentNode.IsLeaf())
                {
                    int currentActionIndex = behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos());
                    currentNode = currentNode.children[currentActionIndex];
                }

                // Expansion: a non-terminal leaf gets its children created and one of them is selected.
                string phrase = currentNode.phrase;
                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
                {
                    ExpandTreeNode(currentNode);
                    // NOTE(review): if every alternative would exceed maxLen, ExpandTreeNode leaves an empty
                    // child array here and SelectAction is asked to choose from no children — confirm that
                    // the policy / TreeNode.IsLeaf handle this case.
                    int childIndex = behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos());
                    currentNode = currentNode.children[childIndex];
                }

                // Simulation and back-propagation.
                if (currentNode.phrase.Length <= maxLen)
                {
                    string simulatedPhrase;
                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
                    OnSolutionEvaluated(simulatedPhrase, quality);

                    Propagate(currentNode, quality);
                }
            }
        }

        /// <summary>
        /// Creates the children of <paramref name="treeNode"/> on its first visit: one child per
        /// (non-terminal position, alternative) pair whose resulting phrase is at most maxLen long.
        /// Does nothing if the node already has a children array.
        /// </summary>
        protected void ExpandTreeNode(TreeNode treeNode)
        {
            if (treeNode.children != null)
            {
                return;
            }

            List<TreeNode> newChildren = new List<TreeNode>();
            var phrase = new Sequence(treeNode.phrase);

            // create subnodes for each nt-symbol in phrase
            for (int i = 0; i < phrase.Length; i++)
            {
                char symbol = phrase[i];
                if (!grammar.IsNonTerminal(symbol))
                {
                    continue;
                }

                // create subnode for each alternative of symbol
                foreach (Sequence alternative in grammar.GetAlternatives(symbol))
                {
                    Sequence newSequence = new Sequence(phrase);
                    newSequence.ReplaceAt(i, 1, alternative);
                    if (newSequence.Length <= maxLen)
                    {
                        TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(),
                            behaviourPolicy.CreateActionInfo(), (ushort)(treeNode.level + 1));
                        newChildren.Add(childNode);
                    }
                }
            }
            treeNode.children = newChildren.ToArray();
        }

        /// <summary>
        /// Clears the good-selection counter, the stop flag and the best quality, and starts a new tree
        /// whose root holds the grammar's sentence symbol.
        /// </summary>
        protected void Reset()
        {
            goodSelections = 0;
            StopRequested = false;
            bestQuality = 0.0;
            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
        }

        /// <summary>
        /// Updates the action info of <paramref name="node"/> and of every ancestor up to the root with
        /// <paramref name="quality"/>. <paramref name="node"/> must not be null.
        /// </summary>
        protected void Propagate(TreeNode node, double quality)
        {
            var currentNode = node;
            do
            {
                currentNode.actionInfo.UpdateReward(quality);
                currentNode = currentNode.parent;
            } while (currentNode != null);
        }

        // Recursively accumulates node statistics for the given children into treeInfos.
        private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children)
        {
            treeInfos.TotalNodes += children.Length;
            foreach (TreeNode child in children)
            {
                if (child.children != null)
                {
                    treeInfos.ExpandedNodes++;
                    // An expanded node's children live one level below it.
                    if (treeInfos.DeepestLevel <= child.level)
                    {
                        treeInfos.DeepestLevel = child.level + 1;
                    }
                    GetTreeInfosRek(treeInfos, child.children);
                }
                else if (grammar.IsTerminal(child.phrase))
                {
                    treeInfos.LeaveNodes++;
                }
                else
                {
                    treeInfos.UnexpandedNodes++;
                }
            }
        }

        /// <summary>
        /// Collects statistics about the current tree: total, expanded, terminal-leaf and unexpanded
        /// node counts and the deepest level. Returns empty statistics when no tree exists.
        /// </summary>
        public TreeInfos GetTreeInfos()
        {
            TreeInfos treeInfos = new TreeInfos();

            if (rootNode != null)
            {
                treeInfos.TotalNodes++;
                if (rootNode.children != null)
                {
                    treeInfos.ExpandedNodes++;
                    treeInfos.DeepestLevel = rootNode.level + 1;
                    GetTreeInfosRek(treeInfos, rootNode.children);
                }
                else
                {
                    treeInfos.DeepestLevel = rootNode.level;
                    if (grammar.IsTerminal(rootNode.phrase))
                    {
                        treeInfos.LeaveNodes++;
                    }
                    else
                    {
                        treeInfos.UnexpandedNodes++;
                    }
                }
            }
            return treeInfos;
        }

        /// <summary>
        /// Renders the current tree with Graphviz (binaries located at <see cref="GraphvizPath"/>) and
        /// returns the SVG bytes. Each node shows its phrase, value and tries and is colored from white
        /// to orange-red by value; nodes of the same level share a rank.
        /// </summary>
        /// <returns>The SVG output, or null when the tree has 1000 or more nodes.</returns>
        public byte[] GenerateSvg()
        {
            if (GetTreeInfos().TotalNodes >= MaxSvgNodes)
            {
                return null;
            }

            IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
            IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
            IRegisterLayoutPluginCommand registerLayoutPluginCommand =
                new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

            GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
                getProcessStartInfoQuery,
                registerLayoutPluginCommand);
            wrapper.GraphvizPath = GraphvizPath;

            StringBuilder dotFile = new StringBuilder("digraph {");
            dotFile.AppendLine();
            dotFile.AppendLine("splines=ortho;");
            dotFile.AppendLine("concentrate=true;");
            dotFile.AppendLine("ranksep=1.2;");

            // DOT node ids grouped by tree level, used to put nodes on the same level in graph.
            Dictionary<int, List<int>> levelMap = new Dictionary<int, List<int>>();

            if (rootNode != null)
            {
                // Sequential ids instead of GetHashCode(): hash codes are not guaranteed unique, and a
                // collision would merge two distinct tree nodes in the rendered graph.
                int nextId = 0;
                Queue<TreeNode> toDoNodes = new Queue<TreeNode>();
                Queue<int> toDoIds = new Queue<int>();

                int rootId = nextId++;
                AppendDotNode(dotFile, rootId, rootNode);
                toDoNodes.Enqueue(rootNode);
                toDoIds.Enqueue(rootId);

                while (toDoNodes.Count > 0)
                {
                    TreeNode currentNode = toDoNodes.Dequeue();
                    int currentId = toDoIds.Dequeue();

                    List<int> sameLevelIds;
                    if (!levelMap.TryGetValue(currentNode.level, out sameLevelIds))
                    {
                        sameLevelIds = new List<int>();
                        levelMap.Add(currentNode.level, sameLevelIds);
                    }
                    sameLevelIds.Add(currentId);

                    // declare each child and draw an edge from the current node to it
                    if (currentNode.children != null)
                    {
                        foreach (TreeNode childNode in currentNode.children)
                        {
                            int childId = nextId++;
                            AppendDotNode(dotFile, childId, childNode);
                            dotFile.AppendLine(string.Format("{0} -> {1}", currentId, childId));
                            toDoNodes.Enqueue(childNode);
                            toDoIds.Enqueue(childId);
                        }
                    }
                }
            }

            // set same level ranks..
            foreach (KeyValuePair<int, List<int>> entry in levelMap)
            {
                dotFile.Append("{rank = same;");
                foreach (int id in entry.Value)
                {
                    dotFile.Append(string.Format(" {0};", id));
                }
                dotFile.AppendLine("}");
            }

            dotFile.Append(" }");
            return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
        }

        // Appends a filled DOT node declaration (phrase, value/tries, value-based color) for the given tree node.
        private void AppendDotNode(StringBuilder dotFile, int id, TreeNode node)
        {
            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
            dotFile.AppendLine(
                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
                    id, EscapeDotString(node.phrase), node.actionInfo.Value, node.actionInfo.Tries, hexColor));
        }

        // Escapes backslashes and double quotes so the text can be embedded in a quoted DOT string.
        private static string EscapeDotString(string text)
        {
            return text.Replace("\\", "\\\\").Replace("\"", "\\\"");
        }

        /// <summary>Formats the RGB channels of <paramref name="c"/> as "#RRGGBB" (alpha is ignored).</summary>
        protected String HexConverter(Color c)
        {
            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
        }

        /// <summary>
        /// Interpolates between <paramref name="weakColor"/> and <paramref name="strongColor"/> by the
        /// ratio of <paramref name="quality"/> to the problem's best known quality for maxLen, and
        /// returns the result as "#RRGGBB".
        /// </summary>
        /// <remarks>
        /// The ratio is clamped to [0, 1] (NaN maps to 0). Without the clamp, a quality above the best known
        /// quality, a negative quality or a zero best known quality would make <c>Convert.ToByte</c> throw.
        /// </remarks>
        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
        {
            // convert quality to value between 0 and 1
            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
            double q = quality / bestKnownQuality;
            if (double.IsNaN(q))
            {
                q = 0.0;
            }
            q = Math.Max(0.0, Math.Min(1.0, q));

            byte newR = InterpolateChannel(weakColor.R, strongColor.R, q);
            byte newG = InterpolateChannel(weakColor.G, strongColor.G, q);
            byte newB = InterpolateChannel(weakColor.B, strongColor.B, q);

            return HexConverter(Color.FromArgb(newR, newG, newB));
        }

        // Moves a single color channel from 'weak' towards 'strong' by fraction q (expected in [0, 1]).
        private static byte InterpolateChannel(byte weak, byte strong, double q)
        {
            int diff = Math.Abs(weak - strong);
            double step = Math.Round(diff * q);
            return weak > strong
                ? Convert.ToByte(weak - step)
                : Convert.ToByte(weak + step);
        }

        /// <summary>
        /// Drops the reference to the search tree and forces a garbage collection.
        /// </summary>
        public void FreeAll()
        {
            rootNode = null;
            // Explicit collection kept from the original; presumably so callers can observe
            // freed memory right away — NOTE(review): confirm it is still needed.
            GC.Collect();
        }
    }
}
|
---|