[12050] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[12762] | 3 | using System.Drawing;
|
---|
[12050] | 4 | using System.Linq;
|
---|
[12762] | 5 | using System.Net.Mail;
|
---|
| 6 | using System.Text;
|
---|
| 7 | using System.Threading;
|
---|
| 8 | using GraphVizWrapper;
|
---|
| 9 | using GraphVizWrapper.Commands;
|
---|
| 10 | using GraphVizWrapper.Queries;
|
---|
[12050] | 11 | using HeuristicLab.Algorithms.Bandits;
|
---|
[12098] | 12 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
| 13 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
| 14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
| 15 | using HeuristicLab.Common;
|
---|
[12050] | 16 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
| 17 |
|
---|
[12098] | 18 | namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
|
---|
[12050] | 19 | {
|
---|
| 20 | public class MonteCarloTreeSearch : SolverBase
|
---|
| 21 | {
|
---|
// Upper bound on phrase length: longer phrases are neither added as children nor simulated.
protected readonly int maxLen;

// Problem instance; supplies the grammar and the best known quality used for node coloring.
protected readonly IProblem problem;

// Grammar of the problem, cached from problem.Grammar in the constructor.
protected readonly IGrammar grammar;

// Random source handed to the bandit policy when it selects an action.
protected readonly Random random;

// Bandit policy that selects child nodes and creates the per-node action infos.
protected readonly IBanditPolicy behaviourPolicy;

// Simulation strategy used to evaluate a tree node.
protected readonly ISimulation simulation;

// Root of the search tree; (re)created by Reset().
protected TreeNode rootNode;

// True while the search loop is asked to wait; only written while holding pauseLock.
protected bool isPaused = false;

// Monitor object guarding isPaused. Read-only so that every thread always locks the same instance.
protected readonly object pauseLock = new object();
[12050] | 31 |
|
---|
/// <summary>
/// Creates a Monte Carlo tree search solver for the given problem.
/// </summary>
/// <param name="problem">Problem to solve; its grammar is used to expand phrases.</param>
/// <param name="maxLen">Maximum phrase length considered by the search.</param>
/// <param name="random">Random source passed to the bandit policy.</param>
/// <param name="behaviourPolicy">Bandit policy used to select children and to create action infos.</param>
/// <param name="simulationPolicy">Simulation strategy used to evaluate tree nodes.</param>
/// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
    ISimulation simulationPolicy)
{
    // Fail with a descriptive exception instead of a NullReferenceException on problem.Grammar below.
    if (problem == null)
    {
        throw new ArgumentNullException("problem");
    }

    this.problem = problem;
    this.grammar = problem.Grammar;
    this.maxLen = maxLen;
    this.random = random;
    this.behaviourPolicy = behaviourPolicy;
    this.simulation = simulationPolicy;
}
| 42 |
|
---|
/// <summary>
/// When set to true, the loop in Run() ends before starting its next iteration.
/// Reset() clears the flag again.
/// </summary>
public bool StopRequested { get; set; }

/// <summary>
/// Gets whether the search is currently flagged as paused (toggled by PauseContinue()).
/// </summary>
public bool IsPaused
{
    get { return isPaused; }
}
| 49 |
|
---|
/// <summary>
/// Resets the solver and runs the tree search until <paramref name="maxIterations"/> iterations
/// have been performed or StopRequested is set.
/// Each iteration descends from the root via the bandit policy, expands the reached node if its
/// phrase is not terminal and not longer than maxLen, simulates the selected node, reports the
/// result through OnSolutionEvaluated and propagates the quality back up to the root.
/// </summary>
/// <param name="maxIterations">Upper bound on the number of iterations.</param>
public override void Run(int maxIterations)
{
    Reset();
    for (int i = 0; !StopRequested && i < maxIterations; i++)
    {
        lock (pauseLock)
        {
            // Re-check the flag after every wake-up: the search may have been paused again
            // between the Pulse and this thread re-acquiring the lock.
            while (isPaused)
            {
                Monitor.Wait(pauseLock);
            }
        }

        // Selection: follow the bandit policy until IsLeaf() reports a leaf node.
        TreeNode currentNode = rootNode;
        while (!currentNode.IsLeaf())
        {
            int currentActionIndex = behaviourPolicy.SelectAction(random,
                currentNode.GetChildActionInfos());
            currentNode = currentNode.children[currentActionIndex];
        }

        string phrase = currentNode.phrase;

        // Expansion: create the children of the reached node and move on to one of them.
        if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
        {
            ExpandTreeNode(currentNode);
            int expandedActionIndex = behaviourPolicy.SelectAction(random,
                currentNode.GetChildActionInfos());
            currentNode = currentNode.children[expandedActionIndex];
        }

        // Simulation and back-propagation; phrases longer than maxLen are skipped.
        if (currentNode.phrase.Length <= maxLen)
        {
            string simulatedPhrase;
            double quality = simulation.Simulate(currentNode, out simulatedPhrase);
            OnSolutionEvaluated(simulatedPhrase, quality);

            Propagate(currentNode, quality);
        }
    }
}
| 90 |
|
---|
/// <summary>
/// Creates the child nodes of <paramref name="treeNode"/> on its first visit.
/// For every non-terminal symbol in the node's phrase and every alternative the grammar offers
/// for that symbol, one child is added whose phrase has that single symbol replaced, provided
/// the resulting phrase is not longer than maxLen. Nodes that already have a children list are
/// left untouched.
/// </summary>
/// <param name="treeNode">Node to expand.</param>
protected void ExpandTreeNode(TreeNode treeNode)
{
    // children are only created once
    if (treeNode.children != null)
    {
        return;
    }

    treeNode.children = new List<TreeNode>();
    var parentPhrase = new Sequence(treeNode.phrase);

    for (int position = 0; position < parentPhrase.Length; position++)
    {
        char symbol = parentPhrase[position];
        if (!grammar.IsNonTerminal(symbol))
        {
            continue;
        }

        // one candidate child per alternative of the non-terminal at this position
        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
        {
            Sequence derivedPhrase = new Sequence(parentPhrase);
            derivedPhrase.ReplaceAt(position, 1, alternative);
            if (derivedPhrase.Length > maxLen)
            {
                continue;
            }

            treeNode.children.Add(new TreeNode(treeNode, derivedPhrase.ToString(),
                behaviourPolicy.CreateActionInfo(), treeNode.level + 1));
        }
    }
}
| 121 |
|
---|
/// <summary>
/// Puts the solver back into its initial state: clears the stop flag, sets the best quality
/// to zero and replaces the tree by a single level-0 root node holding the grammar's
/// sentence symbol.
/// </summary>
protected void Reset()
{
    StopRequested = false;
    bestQuality = 0.0;

    string startPhrase = grammar.SentenceSymbol.ToString();
    rootNode = new TreeNode(null, startPhrase, behaviourPolicy.CreateActionInfo(), 0);
}
| 128 |
|
---|
/// <summary>
/// Reports <paramref name="quality"/> as reward to <paramref name="node"/> and to each of its
/// ancestors, walking the parent links until a node without parent has been updated.
/// </summary>
/// <param name="node">Node the walk starts at; must not be null.</param>
/// <param name="quality">Reward to pass to every visited node's action info.</param>
protected void Propagate(TreeNode node, double quality)
{
    TreeNode visited = node;
    while (true)
    {
        visited.actionInfo.UpdateReward(quality);
        visited = visited.parent;
        if (visited == null)
        {
            break;
        }
    }
}
| 138 |
|
---|
/// <summary>
/// Recursively accumulates statistics for <paramref name="children"/> and all their descendants
/// into <paramref name="treeInfos"/>: total node count, expanded nodes, deepest level, and for
/// nodes without a children list whether their phrase is terminal (LeaveNodes) or not
/// (UnexpandedNodes).
/// </summary>
/// <param name="treeInfos">Accumulator that is updated in place.</param>
/// <param name="children">Child list of an already counted node.</param>
private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
{
    treeInfos.TotalNodes += children.Count;

    foreach (TreeNode child in children)
    {
        if (child.children == null)
        {
            // never expanded: either a finished sentence or a node still waiting for expansion
            if (grammar.IsTerminal(child.phrase))
            {
                treeInfos.LeaveNodes++;
            }
            else
            {
                treeInfos.UnexpandedNodes++;
            }
            continue;
        }

        treeInfos.ExpandedNodes++;
        // an expanded node implies that the level below it exists
        if (child.level >= treeInfos.DeepestLevel)
        {
            treeInfos.DeepestLevel = child.level + 1;
        }
        GetTreeInfosRek(treeInfos, child.children);
    }
}
| 166 |
|
---|
/// <summary>
/// Collects statistics about the current search tree.
/// </summary>
/// <returns>
/// A new TreeInfos instance; it is left at its initial values when no root node exists yet.
/// </returns>
public TreeInfos GetTreeInfos()
{
    TreeInfos result = new TreeInfos();

    if (rootNode == null)
    {
        return result;
    }

    result.TotalNodes++;

    if (rootNode.children == null)
    {
        // the tree consists of the root only
        result.DeepestLevel = rootNode.level;
        if (grammar.IsTerminal(rootNode.phrase))
        {
            result.LeaveNodes++;
        }
        else
        {
            result.UnexpandedNodes++;
        }
        return result;
    }

    result.ExpandedNodes++;
    result.DeepestLevel = rootNode.level + 1;
    GetTreeInfosRek(result, rootNode.children);
    return result;
}
[12098] | 195 |
|
---|
/// <summary>
/// Renders the current search tree as SVG by building a DOT graph and passing it to the
/// GraphViz wrapper. Every node is labelled with its phrase, value and number of tries and is
/// filled with a color derived from its value; nodes of the same tree level share one rank.
/// </summary>
/// <returns>
/// The output of the GraphViz wrapper, or null when the tree has 1000 or more nodes.
/// </returns>
public byte[] GenerateSvg()
{
    // only trees below the size limit are rendered
    if (GetTreeInfos().TotalNodes >= 1000)
    {
        return null;
    }

    IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
    IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
    IRegisterLayoutPluginCommand registerLayoutPluginCommand =
        new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

    GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
        getProcessStartInfoQuery,
        registerLayoutPluginCommand);
    wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";

    StringBuilder dotFile = new StringBuilder("digraph {");
    dotFile.AppendLine();
    dotFile.AppendLine("splines=ortho;");
    dotFile.AppendLine("concentrate=true;");
    dotFile.AppendLine("ranksep=1.2;");

    // Nodes get sequential ids. Hash codes are not guaranteed to be unique, so using them as
    // DOT identifiers could silently merge two different tree nodes into one graph node.
    int nextNodeId = 0;

    // breadth-first work list of (node, DOT id) pairs
    Queue<KeyValuePair<TreeNode, int>> toDoNodes = new Queue<KeyValuePair<TreeNode, int>>();
    if (rootNode != null)
    {
        int rootId = nextNodeId++;
        toDoNodes.Enqueue(new KeyValuePair<TreeNode, int>(rootNode, rootId));
        AppendDotNode(dotFile, rootId, rootNode);
    }

    // ids grouped by tree level, used to put nodes of one level on the same rank
    Dictionary<int, List<int>> levelMap = new Dictionary<int, List<int>>();

    while (toDoNodes.Count > 0)
    {
        KeyValuePair<TreeNode, int> current = toDoNodes.Dequeue();
        TreeNode currentNode = current.Key;
        int currentId = current.Value;

        List<int> sameLevelIds;
        if (!levelMap.TryGetValue(currentNode.level, out sameLevelIds))
        {
            sameLevelIds = new List<int>();
            levelMap.Add(currentNode.level, sameLevelIds);
        }
        sameLevelIds.Add(currentId);

        if (currentNode.children == null)
        {
            continue;
        }

        // declare every child and draw the edge from the current node to it
        foreach (TreeNode childNode in currentNode.children)
        {
            int childId = nextNodeId++;
            toDoNodes.Enqueue(new KeyValuePair<TreeNode, int>(childNode, childId));
            AppendDotNode(dotFile, childId, childNode);
            dotFile.AppendLine(string.Format("{0} -> {1}", currentId, childId));
        }
    }

    // set same level ranks
    foreach (KeyValuePair<int, List<int>> entry in levelMap)
    {
        dotFile.Append("{rank = same;");
        foreach (int nodeId in entry.Value)
        {
            dotFile.Append(string.Format(" {0};", nodeId));
        }
        dotFile.AppendLine("}");
    }

    dotFile.Append(" }");
    return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
}

// Appends the DOT declaration (label and fill color) of a single tree node.
private void AppendDotNode(StringBuilder dotFile, int nodeId, TreeNode node)
{
    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
    dotFile.AppendLine(
        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
            nodeId,
            node.phrase, node.actionInfo.Value, node.actionInfo.Tries, hexColor));
}
| 286 |
|
---|
/// <summary>
/// Formats the red, green and blue components of <paramref name="c"/> as "#RRGGBB" using
/// two upper-case hexadecimal digits per component. The alpha component is not included.
/// </summary>
/// <param name="c">Color to convert.</param>
/// <returns>The color as hexadecimal string with a leading '#'.</returns>
protected String HexConverter(Color c)
{
    return string.Format("#{0:X2}{1:X2}{2:X2}", c.R, c.G, c.B);
}
[12762] | 291 |
|
---|
/// <summary>
/// Maps <paramref name="quality"/> to a color between <paramref name="weakColor"/> and
/// <paramref name="strongColor"/>. The quality is divided by the problem's best known quality
/// for maxLen; a ratio of 0 yields the weak color and a ratio of 1 the strong color.
/// </summary>
/// <param name="weakColor">Color used for a ratio of 0.</param>
/// <param name="strongColor">Color used for a ratio of 1.</param>
/// <param name="quality">Quality value to visualize.</param>
/// <returns>The blended color as "#RRGGBB" string.</returns>
protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
{
    // convert quality to value between 0 and 1
    double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
    double q = quality / bestKnownQuality;

    // Keep the ratio inside [0, 1]. Without this a quality above the best known value, a
    // negative quality or a best known quality of 0 (NaN/Infinity) pushes a channel out of
    // the byte range and Convert.ToByte throws an OverflowException.
    if (double.IsNaN(q))
    {
        q = 0.0;
    }
    q = Math.Max(0.0, Math.Min(1.0, q));

    byte newR = BlendChannel(weakColor.R, strongColor.R, q);
    byte newG = BlendChannel(weakColor.G, strongColor.G, q);
    byte newB = BlendChannel(weakColor.B, strongColor.B, q);

    return HexConverter(Color.FromArgb(newR, newG, newB));
}

// Moves one color channel from its weak value towards its strong value by the fraction q (0..1).
private static byte BlendChannel(byte weak, byte strong, double q)
{
    double offset = Math.Round(Math.Abs(weak - strong) * q);
    return Convert.ToByte(weak > strong ? weak - offset : weak + offset);
}
| 317 |
|
---|
/// <summary>
/// Toggles the pause state. Switching from paused to running also pulses the pause lock so
/// that a thread waiting on it inside Run() can continue.
/// </summary>
public void PauseContinue()
{
    lock (pauseLock)
    {
        isPaused = !isPaused;

        // the flag was just cleared: wake up the waiting search loop
        if (!isPaused)
        {
            Monitor.Pulse(pauseLock);
        }
    }
}
[12050] | 333 | }
|
---|
| 334 | }
|
---|