[12050] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[12815] | 3 | using System.Diagnostics;
|
---|
[12762] | 4 | using System.Drawing;
|
---|
[12050] | 5 | using System.Linq;
|
---|
[12762] | 6 | using System.Net.Mail;
|
---|
| 7 | using System.Text;
|
---|
| 8 | using System.Threading;
|
---|
| 9 | using GraphVizWrapper;
|
---|
| 10 | using GraphVizWrapper.Commands;
|
---|
| 11 | using GraphVizWrapper.Queries;
|
---|
[12050] | 12 | using HeuristicLab.Algorithms.Bandits;
|
---|
[12098] | 13 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
| 14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
| 15 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
| 16 | using HeuristicLab.Common;
|
---|
[12050] | 17 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
| 18 |
|
---|
[12098] | 19 | namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
|
---|
[12050] | 20 | {
|
---|
| 21 | public class MonteCarloTreeSearch : SolverBase
|
---|
| 22 | {
|
---|
// --- search configuration (assigned once in the constructor) ---

// Maximum phrase length: longer phrases are neither added as children nor simulated.
protected readonly int maxLen;
// Problem being searched; also queried for the best known quality when coloring nodes.
protected readonly ISymbolicExpressionTreeProblem problem;
// Grammar taken from problem.Grammar; used to expand non-terminal symbols.
protected readonly IGrammar grammar;
// Random source handed to the bandit policy on every action selection.
protected readonly Random random;
// Bandit policy that picks a child during selection and creates the action info of new nodes.
protected readonly IBanditPolicy behaviourPolicy;
// Simulation that evaluates the node reached in an iteration.
protected readonly ISimulation simulation;

// --- mutable search state ---

// Root of the search tree; recreated by Reset().
protected TreeNode rootNode;
// Counter incremented by CheckSelection() and cleared by Reset().
protected int goodSelections;

// --- pause handling ---

// Monitor object on which Run() waits while the search is paused.
protected object pauseLock = new object();
// True while the search is paused; read and written under pauseLock.
protected bool isPaused;
|
---|
[12050] | 33 |
|
---|
/// <summary>
/// Creates a Monte Carlo tree search for <paramref name="problem"/> and calls
/// <c>problem.GenerateProblemSolutions(maxLen)</c> once during construction.
/// </summary>
/// <param name="problem">Problem to search; its <c>Grammar</c> is used to expand phrases.</param>
/// <param name="maxLen">Maximum phrase length considered during expansion and simulation.</param>
/// <param name="random">Random source passed to the bandit policy when selecting actions.</param>
/// <param name="behaviourPolicy">Policy that selects child nodes and creates their action infos.</param>
/// <param name="simulationPolicy">Simulation used to evaluate the node reached in an iteration.</param>
/// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
    ISimulation simulationPolicy)
{
    // Fail with a descriptive exception instead of a NullReferenceException on problem.Grammar.
    if (problem == null)
    {
        throw new ArgumentNullException("problem");
    }

    this.problem = problem;
    this.grammar = problem.Grammar;
    this.maxLen = maxLen;
    this.random = random;
    this.behaviourPolicy = behaviourPolicy;
    this.simulation = simulationPolicy;

    this.problem.GenerateProblemSolutions(maxLen);
}
|
---|
| 46 |
|
---|
/// <summary>
/// Raised through <see cref="OnNodeSelectedChanged"/>. The first argument is the value
/// passed by the caller (CheckSelection passes the good-selection counter), the second
/// is the selection count passed by the caller.
/// </summary>
public event Action<int, int> NodeSelected;

/// <summary>
/// Overridable notification hook; the default implementation raises <see cref="NodeSelected"/>.
/// </summary>
/// <param name="value">First event argument.</param>
/// <param name="selections">Second event argument.</param>
protected virtual void OnNodeSelectedChanged(int value, int selections)
{
    RaiseNodeSelectedChanged(value, selections);
}

/// <summary>
/// Invokes the current <see cref="NodeSelected"/> subscribers, if there are any.
/// </summary>
private void RaiseNodeSelectedChanged(int value, int selections)
{
    // Copy the delegate first so that the null check and the invocation see the same subscriber list.
    Action<int, int> subscribers = NodeSelected;
    if (subscribers == null)
    {
        return;
    }

    subscribers(value, selections);
}
|
---|
| 59 |
|
---|
/// <summary>
/// When set to true, <see cref="Run"/> leaves its loop before starting the next iteration.
/// Reset() sets it back to false.
/// </summary>
public bool StopRequested { get; set; }

/// <summary>
/// Gets whether the search is currently marked as paused (toggled by <see cref="PauseContinue"/>).
/// </summary>
public bool IsPaused
{
    get
    {
        return isPaused;
    }
}

/// <summary>
/// Gets the current value of the counter maintained by CheckSelection().
/// </summary>
public int GoodSelections
{
    get
    {
        return goodSelections;
    }
}
|
---|
| 68 |
|
---|
/// <summary>
/// Increments the good-selection counter when the problem reports the phrase of
/// <paramref name="node"/> as a parent of a problem solution (for the configured maximum
/// length), then raises the node-selected notification with the counter and
/// <paramref name="selections"/>.
/// </summary>
/// <param name="node">Node that was selected.</param>
/// <param name="selections">Selection count forwarded unchanged to the notification.</param>
protected void CheckSelection(TreeNode node, int selections)
{
    bool isParentOfSolution = problem.IsParentOfProblemSolution(node.phrase, maxLen);
    if (isParentOfSolution)
    {
        goodSelections++;
    }

    // The notification is raised for every checked selection, good or not.
    OnNodeSelectedChanged(goodSelections, selections);
}
|
---|
| 78 |
|
---|
/// <summary>
/// Resets the search and then performs up to <paramref name="maxIterations"/> iterations.
/// Each iteration descends from the root using the bandit policy, expands the reached node
/// if its phrase is not terminal and not longer than the maximum length, simulates from
/// the resulting node and propagates the obtained quality back to the root.
/// </summary>
/// <remarks>
/// The loop ends early when <see cref="StopRequested"/> is true at the start of an
/// iteration, and blocks at the start of an iteration while the search is paused.
/// Selection tracking via CheckSelection() is not invoked from this method.
/// </remarks>
/// <param name="maxIterations">Upper bound on the number of iterations.</param>
public override void Run(int maxIterations)
{
    Reset();
    for (int i = 0; !StopRequested && i < maxIterations; i++)
    {
        lock (pauseLock)
        {
            // Re-check the flag after every wake-up: the search may have been paused again
            // between the pulse and this thread re-acquiring the lock.
            while (isPaused)
            {
                Monitor.Wait(pauseLock);
            }
        }

        // Selection: follow the policy down the tree. The child-count check keeps a node
        // that was expanded into zero children from being indexed into.
        TreeNode currentNode = rootNode;
        while (!currentNode.IsLeaf() && currentNode.children.Count > 0)
        {
            int currentActionIndex = behaviourPolicy.SelectAction(random,
                currentNode.GetChildActionInfos());
            currentNode = currentNode.children[currentActionIndex];
        }

        string phrase = currentNode.phrase;

        // Expansion: create the children on the first visit and step into one of them.
        if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
        {
            ExpandTreeNode(currentNode);

            // Expansion yields no children when every alternative would exceed maxLen;
            // in that case stay on the current node instead of indexing an empty list.
            if (currentNode.children.Count > 0)
            {
                int expandedActionIndex = behaviourPolicy.SelectAction(random,
                    currentNode.GetChildActionInfos());
                currentNode = currentNode.children[expandedActionIndex];
            }
        }

        // Simulation and backpropagation.
        if (currentNode.phrase.Length <= maxLen)
        {
            string simulatedPhrase;
            double quality = simulation.Simulate(currentNode, out simulatedPhrase);
            OnSolutionEvaluated(simulatedPhrase, quality);

            Propagate(currentNode, quality);
        }
    }
}
|
---|
| 124 |
|
---|
/// <summary>
/// Creates the children of <paramref name="treeNode"/> on its first visit. For every
/// non-terminal symbol of the node's phrase and every grammar alternative of that symbol,
/// one child is added whose phrase has the symbol replaced by the alternative, provided
/// the new phrase is not longer than the maximum length. Does nothing if the node already
/// has a children list.
/// </summary>
/// <param name="treeNode">Node to expand.</param>
protected void ExpandTreeNode(TreeNode treeNode)
{
    // Children already exist from an earlier visit.
    if (treeNode.children != null)
    {
        return;
    }

    treeNode.children = new List<TreeNode>();

    var parentPhrase = new Sequence(treeNode.phrase);
    for (int position = 0; position < parentPhrase.Length; position++)
    {
        char current = parentPhrase[position];
        if (!grammar.IsNonTerminal(current))
        {
            continue;
        }

        // One candidate child per alternative of the non-terminal at this position.
        foreach (Sequence replacement in grammar.GetAlternatives(current))
        {
            Sequence candidate = new Sequence(parentPhrase);
            candidate.ReplaceAt(position, 1, replacement);
            if (candidate.Length > maxLen)
            {
                continue;
            }

            treeNode.children.Add(new TreeNode(treeNode, candidate.ToString(),
                behaviourPolicy.CreateActionInfo(), treeNode.level + 1));
        }
    }
}
|
---|
| 155 |
|
---|
/// <summary>
/// Puts the search back into its initial state: clears the stop request, the best quality
/// and the good-selection counter, and replaces the tree with a fresh root node holding
/// the grammar's sentence symbol at level 0.
/// </summary>
protected void Reset()
{
    StopRequested = false;
    bestQuality = 0.0;
    goodSelections = 0;

    string startPhrase = grammar.SentenceSymbol.ToString();
    rootNode = new TreeNode(null, startPhrase, behaviourPolicy.CreateActionInfo(), 0);
}
|
---|
| 163 |
|
---|
/// <summary>
/// Reports <paramref name="quality"/> as a reward to the action info of
/// <paramref name="node"/> and of each of its ancestors up to the root.
/// </summary>
/// <param name="node">Node from which the reward is propagated upwards.</param>
/// <param name="quality">Reward value passed to every action info on the path.</param>
protected void Propagate(TreeNode node, double quality)
{
    // The starting node is always updated, then every ancestor until the parent link ends.
    node.actionInfo.UpdateReward(quality);
    for (TreeNode ancestor = node.parent; ancestor != null; ancestor = ancestor.parent)
    {
        ancestor.actionInfo.UpdateReward(quality);
    }
}
|
---|
| 173 |
|
---|
/// <summary>
/// Recursively accumulates statistics for <paramref name="children"/> and their subtrees
/// into <paramref name="treeInfos"/>: total node count, expanded nodes (children list
/// present), leave nodes (no children list, terminal phrase), unexpanded nodes (no children
/// list, non-terminal phrase) and the deepest level.
/// </summary>
/// <param name="treeInfos">Accumulator that is updated in place.</param>
/// <param name="children">Nodes of one tree level to account for.</param>
private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
{
    treeInfos.TotalNodes += children.Count;

    foreach (TreeNode node in children)
    {
        if (node.children == null)
        {
            // Not expanded: either a finished sentence or a node that is still open.
            if (grammar.IsTerminal(node.phrase))
            {
                treeInfos.LeaveNodes++;
            }
            else
            {
                treeInfos.UnexpandedNodes++;
            }
            continue;
        }

        treeInfos.ExpandedNodes++;
        // An expanded node implies a level below it.
        if (treeInfos.DeepestLevel <= node.level)
        {
            treeInfos.DeepestLevel = node.level + 1;
        }
        GetTreeInfosRek(treeInfos, node.children);
    }
}
|
---|
| 201 |
|
---|
/// <summary>
/// Collects statistics about the current search tree.
/// </summary>
/// <returns>
/// A new <c>TreeInfos</c> instance; it is returned without any counts applied when there
/// is no root node.
/// </returns>
public TreeInfos GetTreeInfos()
{
    TreeInfos summary = new TreeInfos();
    if (rootNode == null)
    {
        return summary;
    }

    summary.TotalNodes++;

    if (rootNode.children == null)
    {
        // The tree consists of the root only.
        summary.DeepestLevel = rootNode.level;
        if (grammar.IsTerminal(rootNode.phrase))
        {
            summary.LeaveNodes++;
        }
        else
        {
            summary.UnexpandedNodes++;
        }
        return summary;
    }

    summary.ExpandedNodes++;
    summary.DeepestLevel = rootNode.level + 1;
    GetTreeInfosRek(summary, rootNode.children);
    return summary;
}
|
---|
[12098] | 230 |
|
---|
// Trees with this many nodes or more are not rendered.
private const int MaxSvgNodes = 1000;

/// <summary>
/// Renders the current search tree as an SVG image by building a DOT graph description
/// and passing it to the GraphViz wrapper (GraphViz binaries are looked up under the
/// relative path "../../../Graphviz2.38/bin").
/// </summary>
/// <remarks>
/// Nodes are visited breadth-first and identified in the DOT text by sequential numbers,
/// so two distinct tree nodes can never share an identifier. Nodes of the same tree
/// level are placed on the same rank.
/// </remarks>
/// <returns>
/// The bytes returned by the GraphViz wrapper, or null when the tree has 1000 nodes or more.
/// </returns>
public byte[] GenerateSvg()
{
    if (GetTreeInfos().TotalNodes >= MaxSvgNodes)
    {
        return null;
    }

    IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
    IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
    IRegisterLayoutPluginCommand registerLayoutPluginCommand =
        new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

    GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
        getProcessStartInfoQuery,
        registerLayoutPluginCommand);
    wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";

    StringBuilder dotFile = new StringBuilder("digraph {");
    dotFile.AppendLine();
    dotFile.AppendLine("splines=ortho;");
    dotFile.AppendLine("concentrate=true;");
    dotFile.AppendLine("ranksep=1.2;");

    // Breadth-first work list; each entry pairs a node with the identifier used for it in the DOT text.
    int nextNodeId = 0;
    Queue<KeyValuePair<int, TreeNode>> toDoNodes = new Queue<KeyValuePair<int, TreeNode>>();
    if (rootNode != null)
    {
        int rootId = nextNodeId++;
        toDoNodes.Enqueue(new KeyValuePair<int, TreeNode>(rootId, rootNode));
        AppendDotNode(dotFile, rootId, rootNode);
    }

    // Tree level -> identifiers of the nodes on that level, used to align them in the graph.
    Dictionary<int, List<int>> levelMap = new Dictionary<int, List<int>>();

    while (toDoNodes.Count > 0)
    {
        KeyValuePair<int, TreeNode> entry = toDoNodes.Dequeue();
        int currentId = entry.Key;
        TreeNode currentNode = entry.Value;

        List<int> sameLevelIds;
        if (!levelMap.TryGetValue(currentNode.level, out sameLevelIds))
        {
            sameLevelIds = new List<int>();
            levelMap.Add(currentNode.level, sameLevelIds);
        }
        sameLevelIds.Add(currentId);

        if (currentNode.children == null)
        {
            continue;
        }

        // Declare every child and draw an edge from the current node to it.
        foreach (TreeNode childNode in currentNode.children)
        {
            int childId = nextNodeId++;
            toDoNodes.Enqueue(new KeyValuePair<int, TreeNode>(childId, childNode));
            AppendDotNode(dotFile, childId, childNode);
            dotFile.AppendLine(string.Format("{0} -> {1}", currentId, childId));
        }
    }

    // Put the nodes of each level on the same rank.
    foreach (KeyValuePair<int, List<int>> level in levelMap)
    {
        dotFile.Append("{rank = same;");
        foreach (int nodeId in level.Value)
        {
            dotFile.Append(string.Format(" {0};", nodeId));
        }
        dotFile.AppendLine("}");
    }

    dotFile.Append(" }");
    return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
}

/// <summary>
/// Appends the DOT declaration of one node: a label with the phrase, the action value and
/// the number of tries, filled with a color derived from the action value.
/// </summary>
private void AppendDotNode(StringBuilder dotFile, int nodeId, TreeNode node)
{
    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
    dotFile.AppendLine(
        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
            nodeId,
            node.phrase, node.actionInfo.Value, node.actionInfo.Tries, hexColor));
}
|
---|
| 321 |
|
---|
/// <summary>
/// Formats a color as "#RRGGBB" using two upper-case hexadecimal digits per channel.
/// The alpha channel is not part of the result.
/// </summary>
/// <param name="c">Color to format.</param>
/// <returns>The hexadecimal color string.</returns>
protected String HexConverter(Color c)
{
    return string.Format("#{0:X2}{1:X2}{2:X2}", c.R, c.G, c.B);
}
|
---|
[12762] | 326 |
|
---|
/// <summary>
/// Returns the "#RRGGBB" color that lies between <paramref name="weakColor"/> and
/// <paramref name="strongColor"/> according to how close <paramref name="quality"/> is to
/// the problem's best known quality for the configured maximum length.
/// </summary>
/// <remarks>
/// The ratio quality / best known quality is clamped to [0, 1] (NaN is treated as 0), so
/// qualities above the best known value, negative qualities or a best known quality of
/// zero yield one of the two end colors instead of an overflow during byte conversion.
/// </remarks>
/// <param name="weakColor">Color used for a ratio of 0.</param>
/// <param name="strongColor">Color used for a ratio of 1.</param>
/// <param name="quality">Quality value to visualize.</param>
/// <returns>The interpolated color as a hexadecimal string.</returns>
protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
{
    // convert quality to value between 0 and 1
    double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
    double q = quality / bestKnownQuality;
    if (double.IsNaN(q))
    {
        q = 0.0;
    }
    else
    {
        q = Math.Max(0.0, Math.Min(1.0, q));
    }

    byte newR = BlendChannel(weakColor.R, strongColor.R, q);
    byte newG = BlendChannel(weakColor.G, strongColor.G, q);
    byte newB = BlendChannel(weakColor.B, strongColor.B, q);

    return HexConverter(Color.FromArgb(newR, newG, newB));
}

/// <summary>
/// Moves one color channel from <paramref name="weak"/> towards <paramref name="strong"/>
/// by the fraction <paramref name="q"/>, which must lie in [0, 1].
/// </summary>
private static byte BlendChannel(byte weak, byte strong, double q)
{
    // The signed difference handles both directions (weak above or below strong).
    return Convert.ToByte(weak + Math.Round((strong - weak) * q));
}
|
---|
| 352 |
|
---|
/// <summary>
/// Toggles the paused state. When the search was paused, it is marked as running again
/// and a thread waiting on the pause lock is signalled; otherwise the search is marked
/// as paused.
/// </summary>
public void PauseContinue()
{
    lock (pauseLock)
    {
        isPaused = !isPaused;

        // Only a transition from paused to running needs to wake the waiting thread.
        if (!isPaused)
        {
            Monitor.Pulse(pauseLock);
        }
    }
}
|
---|
[12050] | 368 | }
|
---|
| 369 | }
|
---|