[12050] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
[12815] | 3 | using System.Diagnostics;
|
---|
[12762] | 4 | using System.Drawing;
|
---|
[12050] | 5 | using System.Linq;
|
---|
[12762] | 6 | using System.Net.Mail;
|
---|
| 7 | using System.Text;
|
---|
| 8 | using System.Threading;
|
---|
| 9 | using GraphVizWrapper;
|
---|
| 10 | using GraphVizWrapper.Commands;
|
---|
| 11 | using GraphVizWrapper.Queries;
|
---|
[12050] | 12 | using HeuristicLab.Algorithms.Bandits;
|
---|
[12098] | 13 | using HeuristicLab.Algorithms.GrammaticalOptimization;
|
---|
| 14 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
|
---|
| 15 | using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
|
---|
| 16 | using HeuristicLab.Common;
|
---|
[12050] | 17 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
| 18 |
|
---|
namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
{
    /// <summary>
    /// Monte Carlo tree search over the phrases (sentential forms) of a grammar.
    /// Each iteration selects a path with the behaviour policy, expands the reached leaf once,
    /// simulates a sentence from the selected node and propagates the obtained quality back to the root.
    /// </summary>
    public class MonteCarloTreeSearch : SolverBase
    {
        // Trees with this many nodes or more are not rendered by GenerateSvg.
        private const int MaxSvgNodes = 1000;

        protected readonly int maxLen;
        protected readonly ISymbolicExpressionTreeProblem problem;
        protected readonly IGrammar grammar;
        protected readonly Random random;
        protected readonly IBanditPolicy behaviourPolicy;
        protected readonly ISimulation simulation;
        protected TreeNode rootNode;
        protected int goodSelections;

        /// <summary>
        /// Creates a new search and calls <c>problem.GenerateProblemSolutions(maxLen)</c>
        /// (presumably precomputing the solutions used by <see cref="CheckSelection"/>).
        /// </summary>
        /// <param name="problem">Problem providing the grammar and the known solutions; must not be null.</param>
        /// <param name="maxLen">Maximum phrase length; longer phrases are never added to the tree.</param>
        /// <param name="random">Random source handed to the behaviour policy.</param>
        /// <param name="behaviourPolicy">Bandit policy used for child selection and action-info creation.</param>
        /// <param name="simulationPolicy">Simulation used to complete and evaluate a phrase.</param>
        /// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
            ISimulation simulationPolicy)
        {
            if (problem == null)
            {
                throw new ArgumentNullException("problem");
            }

            this.problem = problem;
            this.grammar = problem.Grammar;
            this.maxLen = maxLen;
            this.random = random;
            this.behaviourPolicy = behaviourPolicy;
            this.simulation = simulationPolicy;

            this.problem.GenerateProblemSolutions(maxLen);
        }

        /// <summary>
        /// Raised after every evaluated iteration with (good selections so far, total selections so far).
        /// </summary>
        public event Action<int, int> IterationFinished;

        protected virtual void OnIterationFinishedChanged(int value, int selections)
        {
            RaiseIterationFinishedChanged(value, selections);
        }

        private void RaiseIterationFinishedChanged(int value, int selections)
        {
            // copy to a local so a concurrent unsubscribe cannot null the delegate between check and call
            var handler = IterationFinished;
            if (handler != null) handler(value, selections);
        }

        /// <summary>Stops <see cref="Run"/> before the next iteration when set; reset to false by each run.</summary>
        public bool StopRequested { get; set; }

        /// <summary>
        /// Number of selections in the current run that reached a node for which
        /// <c>problem.IsParentOfProblemSolution</c> returned true.
        /// </summary>
        public int GoodSelections { get { return this.goodSelections; } }

        /// <summary>
        /// Counts <paramref name="node"/> as a good selection if its phrase is a parent of a problem solution.
        /// </summary>
        /// <param name="node">The node that was just selected.</param>
        /// <param name="selections">Currently unused; kept for signature compatibility.</param>
        protected void CheckSelection(TreeNode node, int selections)
        {
            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
            {
                goodSelections++;
            }
        }

        /// <summary>
        /// Runs at most <paramref name="maxIterations"/> iterations of selection, expansion,
        /// simulation and back-propagation on a freshly reset tree.
        /// </summary>
        public override void Run(int maxIterations)
        {
            Reset();
            int selections = 0;

            for (int i = 0; !StopRequested && i < maxIterations; i++)
            {
                // selection: descend with the behaviour policy until a leaf is reached
                TreeNode currentNode = rootNode;
                while (!currentNode.IsLeaf())
                {
                    currentNode = SelectChild(currentNode, ref selections);
                }

                // expansion: grow the leaf once and step into one of its new children
                string phrase = currentNode.phrase;
                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
                {
                    ExpandTreeNode(currentNode);
                    // NOTE(review): if every alternative exceeds maxLen, ExpandTreeNode leaves an empty
                    // children array and the policy is asked to choose among zero actions — confirm how
                    // SelectAction / TreeNode.IsLeaf handle that case.
                    currentNode = SelectChild(currentNode, ref selections);
                }

                // simulation and back-propagation
                if (currentNode.phrase.Length <= maxLen)
                {
                    string simulatedPhrase;
                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
                    OnSolutionEvaluated(simulatedPhrase, quality);

                    OnIterationFinishedChanged(goodSelections, selections);
                    Propagate(currentNode, quality);
                }
            }
        }

        // Lets the behaviour policy pick a child of parent, counts the selection and checks whether it was a good one.
        private TreeNode SelectChild(TreeNode parent, ref int selections)
        {
            int actionIndex = behaviourPolicy.SelectAction(random, parent.GetChildActionInfos());
            TreeNode child = parent.children[actionIndex];
            selections++;
            CheckSelection(child, selections);
            return child;
        }

        /// <summary>
        /// Creates the children of <paramref name="treeNode"/> on its first visit; does nothing if it
        /// already has a children array. For every non-terminal symbol in the phrase and every alternative
        /// of that symbol, one child with the symbol replaced is created, unless it would exceed maxLen.
        /// </summary>
        protected void ExpandTreeNode(TreeNode treeNode)
        {
            if (treeNode.children != null)
            {
                return;
            }

            List<TreeNode> newChildren = new List<TreeNode>();
            Sequence phrase = new Sequence(treeNode.phrase);
            ushort childLevel = (ushort)(treeNode.level + 1);

            for (int i = 0; i < phrase.Length; i++)
            {
                char symbol = phrase[i];
                if (!grammar.IsNonTerminal(symbol))
                {
                    continue;
                }

                foreach (Sequence alternative in grammar.GetAlternatives(symbol))
                {
                    Sequence newSequence = new Sequence(phrase);
                    newSequence.ReplaceAt(i, 1, alternative);
                    if (newSequence.Length <= maxLen)
                    {
                        newChildren.Add(new TreeNode(treeNode, newSequence.ToString(),
                            behaviourPolicy.CreateActionInfo(), childLevel));
                    }
                }
            }

            treeNode.children = newChildren.ToArray();
        }

        /// <summary>
        /// Clears the per-run state (good selections, stop flag, inherited bestQuality) and
        /// replaces the tree by a single root node holding the grammar's sentence symbol.
        /// </summary>
        protected void Reset()
        {
            goodSelections = 0;
            StopRequested = false;
            bestQuality = 0.0;
            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
        }

        /// <summary>
        /// Adds <paramref name="quality"/> as reward to <paramref name="node"/> and to every ancestor up to the root.
        /// </summary>
        protected void Propagate(TreeNode node, double quality)
        {
            var currentNode = node;
            do
            {
                currentNode.actionInfo.UpdateReward(quality);
                currentNode = currentNode.parent;
            } while (currentNode != null);
        }

        // Accumulates statistics for children and their subtrees: every child counts towards TotalNodes;
        // expanded children count as ExpandedNodes and raise DeepestLevel to child.level + 1;
        // unexpanded children count as LeaveNodes (terminal phrase) or UnexpandedNodes.
        private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children)
        {
            treeInfos.TotalNodes += children.Length;
            foreach (TreeNode child in children)
            {
                if (child.children != null)
                {
                    treeInfos.ExpandedNodes++;
                    if (treeInfos.DeepestLevel <= child.level)
                    {
                        treeInfos.DeepestLevel = child.level + 1;
                    }
                    GetTreeInfosRek(treeInfos, child.children);
                }
                else if (grammar.IsTerminal(child.phrase))
                {
                    treeInfos.LeaveNodes++;
                }
                else
                {
                    treeInfos.UnexpandedNodes++;
                }
            }
        }

        /// <summary>
        /// Collects node counts and the deepest level of the current tree; all zero when there is no tree.
        /// </summary>
        public TreeInfos GetTreeInfos()
        {
            TreeInfos treeInfos = new TreeInfos();

            if (rootNode != null)
            {
                treeInfos.TotalNodes++;
                if (rootNode.children != null)
                {
                    treeInfos.ExpandedNodes++;
                    treeInfos.DeepestLevel = rootNode.level + 1;
                    GetTreeInfosRek(treeInfos, rootNode.children);
                }
                else
                {
                    treeInfos.DeepestLevel = rootNode.level;
                    if (grammar.IsTerminal(rootNode.phrase))
                    {
                        treeInfos.LeaveNodes++;
                    }
                    else
                    {
                        treeInfos.UnexpandedNodes++;
                    }
                }
            }
            return treeInfos;
        }

        /// <summary>
        /// Renders the current tree as SVG via Graphviz, expected at the relative path
        /// ../../../Graphviz2.38/bin. Nodes are colored by their value relative to the best known quality
        /// and nodes on the same level are put on the same rank.
        /// </summary>
        /// <returns>The SVG bytes, or null when the tree has <c>MaxSvgNodes</c> nodes or more.</returns>
        public byte[] GenerateSvg()
        {
            if (GetTreeInfos().TotalNodes >= MaxSvgNodes)
            {
                return null;
            }

            IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
            IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
            IRegisterLayoutPluginCommand registerLayoutPluginCommand =
                new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);

            GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
                getProcessStartInfoQuery,
                registerLayoutPluginCommand);
            wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";

            StringBuilder dotFile = new StringBuilder("digraph {");
            dotFile.AppendLine();
            dotFile.AppendLine("splines=ortho;");
            dotFile.AppendLine("concentrate=true;");
            dotFile.AppendLine("ranksep=1.2;");

            // Node ids are assigned sequentially during the traversal. GetHashCode() is not unique,
            // and colliding ids would merge unrelated nodes in the rendered graph.
            Queue<KeyValuePair<TreeNode, int>> toDoNodes = new Queue<KeyValuePair<TreeNode, int>>();
            // level -> ids of the nodes on that level, to put them on the same rank
            Dictionary<int, List<int>> levelMap = new Dictionary<int, List<int>>();
            int nextId = 0;

            if (rootNode != null)
            {
                int rootId = nextId++;
                AppendNodeDeclaration(dotFile, rootId, rootNode);
                toDoNodes.Enqueue(new KeyValuePair<TreeNode, int>(rootNode, rootId));
            }

            while (toDoNodes.Count > 0)
            {
                KeyValuePair<TreeNode, int> current = toDoNodes.Dequeue();
                TreeNode currentNode = current.Key;
                int currentId = current.Value;

                List<int> sameLevelIds;
                if (!levelMap.TryGetValue(currentNode.level, out sameLevelIds))
                {
                    sameLevelIds = new List<int>();
                    levelMap.Add(currentNode.level, sameLevelIds);
                }
                sameLevelIds.Add(currentId);

                if (currentNode.children == null)
                {
                    continue;
                }

                // declare every child and draw an edge from the current node to it
                foreach (TreeNode childNode in currentNode.children)
                {
                    int childId = nextId++;
                    AppendNodeDeclaration(dotFile, childId, childNode);
                    dotFile.AppendLine(string.Format("{0} -> {1}", currentId, childId));
                    toDoNodes.Enqueue(new KeyValuePair<TreeNode, int>(childNode, childId));
                }
            }

            foreach (KeyValuePair<int, List<int>> entry in levelMap)
            {
                dotFile.Append("{rank = same;");
                foreach (int id in entry.Value)
                {
                    dotFile.Append(string.Format(" {0};", id));
                }
                dotFile.AppendLine("}");
            }

            dotFile.Append(" }");
            return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
        }

        // Appends a DOT node statement labelled "phrase\nvalue/tries", filled according to the node's value.
        private void AppendNodeDeclaration(StringBuilder dotFile, int id, TreeNode node)
        {
            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
            dotFile.AppendLine(
                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
                    id, EscapeDotString(node.phrase), node.actionInfo.Value, node.actionInfo.Tries, hexColor));
        }

        // Escapes backslashes and double quotes so that text can be embedded in a quoted DOT string.
        private static string EscapeDotString(string text)
        {
            if (text == null)
            {
                return string.Empty;
            }
            return text.Replace("\\", "\\\\").Replace("\"", "\\\"");
        }

        /// <summary>Formats <paramref name="c"/> as "#RRGGBB" (alpha is ignored).</summary>
        protected String HexConverter(Color c)
        {
            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
        }

        /// <summary>
        /// Linearly interpolates between <paramref name="weakColor"/> and <paramref name="strongColor"/>
        /// by the ratio of <paramref name="quality"/> to the problem's best known quality, and returns it as "#RRGGBB".
        /// The ratio is clamped to [0, 1]. Values above the best known quality, negative values,
        /// and a best known quality of 0 previously made Convert.ToByte throw.
        /// </summary>
        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
        {
            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
            double q = quality / bestKnownQuality;

            if (double.IsNaN(q) || q < 0.0)
            {
                q = 0.0;
            }
            else if (q > 1.0)
            {
                q = 1.0;
            }

            byte newR = InterpolateChannel(weakColor.R, strongColor.R, q);
            byte newG = InterpolateChannel(weakColor.G, strongColor.G, q);
            byte newB = InterpolateChannel(weakColor.B, strongColor.B, q);

            return HexConverter(Color.FromArgb(newR, newG, newB));
        }

        // Moves a color channel from weak towards strong by fraction q (expected in [0, 1]).
        // Math.Round is symmetric around zero, so this matches rounding the absolute step and adding or subtracting it.
        private static byte InterpolateChannel(byte weak, byte strong, double q)
        {
            return Convert.ToByte(weak + Math.Round((strong - weak) * q));
        }

        /// <summary>
        /// Drops the reference to the search tree and forces a garbage collection
        /// (the forced collection is kept for callers that rely on it).
        /// </summary>
        public void FreeAll()
        {
            rootNode = null;
            GC.Collect();
        }
    }
}
|
---|