using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.Linq;
using System.Net.Mail;
using System.Text;
using System.Threading;
using GraphVizWrapper;
using GraphVizWrapper.Commands;
using GraphVizWrapper.Queries;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Algorithms.GrammaticalOptimization;
using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.MonteCarloTreeSearch {
    /// <summary>
    /// Monte Carlo tree search over the phrases derivable from a grammar.
    /// Each iteration walks down the tree using the behaviour policy, expands the
    /// reached leaf, evaluates it through the simulation and propagates the
    /// resulting quality back up to the root.
    /// </summary>
    public class MonteCarloTreeSearch : SolverBase {
        // Trees with this many nodes (or more) are not rendered by GenerateSvg.
        private const int MaxSvgNodes = 1000;

        // Location of the Graphviz binaries handed to the GraphViz wrapper.
        private const string GraphvizBinPath = @"../../../Graphviz2.38/bin";

        protected readonly int maxLen;
        protected readonly ISymbolicExpressionTreeProblem problem;
        protected readonly IGrammar grammar;
        protected readonly Random random;
        protected readonly IBanditPolicy behaviourPolicy;
        protected readonly ISimulation simulation;
        protected TreeNode rootNode;
        protected int goodSelections;

        /// <summary>
        /// Creates a solver for <paramref name="problem"/> and asks the problem to
        /// generate its solutions for <paramref name="maxLen"/>.
        /// </summary>
        /// <param name="problem">Problem to solve; its grammar drives the tree expansion.</param>
        /// <param name="maxLen">Maximum phrase length considered during the search.</param>
        /// <param name="random">Random source handed to the behaviour policy.</param>
        /// <param name="behaviourPolicy">Policy that selects a child and creates the per-node action infos.</param>
        /// <param name="simulationPolicy">Simulation used to evaluate a reached node.</param>
        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random,
            IBanditPolicy behaviourPolicy, ISimulation simulationPolicy) {
            this.problem = problem;
            this.grammar = problem.Grammar;
            this.maxLen = maxLen;
            this.random = random;
            this.behaviourPolicy = behaviourPolicy;
            this.simulation = simulationPolicy;

            this.problem.GenerateProblemSolutions(maxLen);
        }

        /// <summary>
        /// Raised after every evaluated iteration with the number of good selections
        /// and the total number of selections made so far in the current run.
        /// </summary>
        public event Action<int, int> IterationFinished;

        /// <summary>Raises <see cref="IterationFinished"/>; override to intercept the notification.</summary>
        protected virtual void OnIterationFinishedChanged(int value, int selections) {
            RaiseIterationFinishedChanged(value, selections);
        }

        private void RaiseIterationFinishedChanged(int value, int selections) {
            // Copy the delegate first so a concurrent unsubscribe cannot null it between check and call.
            var handler = IterationFinished;
            if (handler != null) {
                handler(value, selections);
            }
        }

        /// <summary>Set to true to make <see cref="Run"/> stop before its next iteration.</summary>
        public bool StopRequested { get; set; }

        /// <summary>Number of selections counted as good by <see cref="CheckSelection"/> since the last reset.</summary>
        public int GoodSelections {
            get { return this.goodSelections; }
        }

        /// <summary>
        /// Counts the selection as good when the problem reports the node's phrase as
        /// the parent of a problem solution.
        /// </summary>
        /// <param name="node">Node that has just been selected.</param>
        /// <param name="selections">Total selections so far (currently not used by this method).</param>
        protected void CheckSelection(TreeNode node, int selections) {
            if (problem.IsParentOfProblemSolution(node.phrase, maxLen)) {
                goodSelections++;
            }
        }

        /// <summary>
        /// Resets the solver and runs up to <paramref name="maxIterations"/> search
        /// iterations, stopping early when <see cref="StopRequested"/> is set.
        /// </summary>
        /// <param name="maxIterations">Upper bound on the number of iterations.</param>
        public override void Run(int maxIterations) {
            Reset();

            int selections = 0;
            for (int i = 0; !StopRequested && i < maxIterations; i++) {
                // Selection: descend until a leaf is reached.
                TreeNode currentNode = rootNode;
                while (!currentNode.IsLeaf()) {
                    currentNode = SelectChild(currentNode, ref selections);
                }

                // Expansion: grow the leaf (if it can still be derived) and step into one child.
                string phrase = currentNode.phrase;
                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen) {
                    ExpandTreeNode(currentNode);
                    // NOTE(review): if every alternative exceeded maxLen the children array is
                    // empty and the selection below cannot succeed - confirm this cannot happen.
                    currentNode = SelectChild(currentNode, ref selections);
                }

                // Simulation and back-propagation.
                if (currentNode.phrase.Length <= maxLen) {
                    string simulatedPhrase;
                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
                    OnSolutionEvaluated(simulatedPhrase, quality);
                    OnIterationFinishedChanged(goodSelections, selections);
                    Propagate(currentNode, quality);
                }
            }
        }

        // Lets the behaviour policy pick one child of the node, counts the selection and checks it.
        private TreeNode SelectChild(TreeNode node, ref int selections) {
            int actionIndex = behaviourPolicy.SelectAction(random, node.GetChildActionInfos());
            TreeNode selected = node.children[actionIndex];
            selections++;
            CheckSelection(selected, selections);
            return selected;
        }

        /// <summary>
        /// Creates the children of <paramref name="treeNode"/> on its first visit: one child
        /// for every alternative of every non-terminal symbol in the node's phrase, as long
        /// as the resulting phrase is not longer than the maximum length.
        /// Does nothing when the node already has a children array.
        /// </summary>
        /// <param name="treeNode">Node to expand.</param>
        protected void ExpandTreeNode(TreeNode treeNode) {
            // create children on the first visit only
            if (treeNode.children != null) {
                return;
            }

            List<TreeNode> newChildren = new List<TreeNode>();
            var phrase = new Sequence(treeNode.phrase);

            // create subnodes for each nt-symbol in phrase
            for (int i = 0; i < phrase.Length; i++) {
                char symbol = phrase[i];
                if (!grammar.IsNonTerminal(symbol)) {
                    continue;
                }

                // create subnode for each alternative of symbol
                foreach (Sequence alternative in grammar.GetAlternatives(symbol)) {
                    Sequence newSequence = new Sequence(phrase);
                    newSequence.ReplaceAt(i, 1, alternative);
                    if (newSequence.Length <= maxLen) {
                        TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(),
                            behaviourPolicy.CreateActionInfo(), (ushort)(treeNode.level + 1));
                        newChildren.Add(childNode);
                    }
                }
            }

            treeNode.children = newChildren.ToArray();
        }

        /// <summary>
        /// Clears the counters and the stop flag, resets the best quality and replaces the
        /// tree by a fresh root holding the grammar's sentence symbol.
        /// </summary>
        protected void Reset() {
            goodSelections = 0;
            StopRequested = false;
            bestQuality = 0.0;
            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
        }

        /// <summary>
        /// Updates the reward of <paramref name="node"/> and of all its ancestors up to the
        /// root with <paramref name="quality"/>.
        /// </summary>
        /// <param name="node">Node the simulation started from; must not be null.</param>
        /// <param name="quality">Quality returned by the simulation.</param>
        protected void Propagate(TreeNode node, double quality) {
            var currentNode = node;
            do {
                currentNode.actionInfo.UpdateReward(quality);
                currentNode = currentNode.parent;
            } while (currentNode != null);
        }

        // Counts a node without children array as leaf (terminal phrase) or as unexpanded node.
        private void CountChildlessNode(TreeInfos treeInfos, TreeNode node) {
            if (grammar.IsTerminal(node.phrase)) {
                treeInfos.LeaveNodes++;
            } else {
                treeInfos.UnexpandedNodes++;
            }
        }

        // Recursively accumulates the statistics of the given children and their subtrees.
        private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children) {
            treeInfos.TotalNodes += children.Length;
            foreach (TreeNode child in children) {
                if (child.children != null) {
                    treeInfos.ExpandedNodes++;
                    if (treeInfos.DeepestLevel <= child.level) {
                        treeInfos.DeepestLevel = child.level + 1;
                    }
                    GetTreeInfosRek(treeInfos, child.children);
                } else {
                    CountChildlessNode(treeInfos, child);
                }
            }
        }

        /// <summary>
        /// Collects statistics about the current search tree (total, expanded, unexpanded
        /// and leaf node counts plus the deepest level).
        /// </summary>
        /// <returns>The collected statistics; a freshly constructed instance when there is no tree.</returns>
        public TreeInfos GetTreeInfos() {
            TreeInfos treeInfos = new TreeInfos();
            if (rootNode == null) {
                return treeInfos;
            }

            treeInfos.TotalNodes++;
            if (rootNode.children != null) {
                treeInfos.ExpandedNodes++;
                treeInfos.DeepestLevel = rootNode.level + 1;
                GetTreeInfosRek(treeInfos, rootNode.children);
            } else {
                treeInfos.DeepestLevel = rootNode.level;
                CountChildlessNode(treeInfos, rootNode);
            }
            return treeInfos;
        }

        // Appends the dot declaration (label with phrase, value and tries; quality based fill colour).
        // NOTE(review): GetHashCode() serves as node id; hash codes are not guaranteed to be
        // unique, so two nodes could in principle share an id - confirm this is acceptable.
        private void AppendNodeDeclaration(StringBuilder dotFile, TreeNode node) {
            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, node.actionInfo.Value);
            dotFile.AppendLine(
                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
                    node.GetHashCode(), node.phrase, node.actionInfo.Value, node.actionInfo.Tries, hexColor));
        }

        /// <summary>
        /// Renders the current search tree as SVG through Graphviz.
        /// </summary>
        /// <returns>
        /// The bytes returned by the Graphviz wrapper, or null when the tree has 1000 or more nodes.
        /// </returns>
        public byte[] GenerateSvg() {
            if (GetTreeInfos().TotalNodes >= MaxSvgNodes) {
                return null;
            }

            IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
            IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
            IRegisterLayoutPluginCommand registerLayoutPluginCommand =
                new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
            GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery, getProcessStartInfoQuery,
                registerLayoutPluginCommand);
            wrapper.GraphvizPath = GraphvizBinPath;

            StringBuilder dotFile = new StringBuilder("digraph {");
            dotFile.AppendLine();
            dotFile.AppendLine("splines=ortho;");
            dotFile.AppendLine("concentrate=true;");
            dotFile.AppendLine("ranksep=1.2;");

            // Breadth-first work list; a queue dequeues in O(1) whereas List.RemoveAt(0) shifts all items.
            Queue<TreeNode> toDoNodes = new Queue<TreeNode>();
            if (rootNode != null) {
                toDoNodes.Enqueue(rootNode);
                AppendNodeDeclaration(dotFile, rootNode);
            }

            // to put nodes on the same level in graph
            Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
            while (toDoNodes.Count > 0) {
                TreeNode currentNode = toDoNodes.Dequeue();

                // put currentNode into levelMap
                List<TreeNode> sameLevelNodes;
                if (!levelMap.TryGetValue(currentNode.level, out sameLevelNodes)) {
                    sameLevelNodes = new List<TreeNode>();
                    levelMap.Add(currentNode.level, sameLevelNodes);
                }
                sameLevelNodes.Add(currentNode);

                // draw line from current node to all its children
                if (currentNode.children == null) {
                    continue;
                }
                foreach (TreeNode childNode in currentNode.children) {
                    toDoNodes.Enqueue(childNode);
                    AppendNodeDeclaration(dotFile, childNode);
                    // add edge
                    dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(), childNode.GetHashCode()));
                }
            }

            // set same level ranks..
            foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap) {
                dotFile.Append("{rank = same;");
                foreach (TreeNode node in entry.Value) {
                    dotFile.Append(string.Format(" {0};", node.GetHashCode()));
                }
                dotFile.AppendLine("}");
            }
            dotFile.Append(" }");

            return wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
        }

        /// <summary>Formats the colour as "#RRGGBB" using upper-case hexadecimal digits.</summary>
        protected String HexConverter(Color c) {
            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
        }

        // Moves one colour channel from its weak value towards its strong value by the fraction q.
        // q must be within [0, 1] so that the result stays inside the byte range.
        private static byte BlendChannel(byte weak, byte strong, double q) {
            double offset = Math.Round(Math.Abs(weak - strong) * q);
            return Convert.ToByte(weak > strong ? weak - offset : weak + offset);
        }

        /// <summary>
        /// Interpolates between <paramref name="weakColor"/> and <paramref name="strongColor"/>
        /// according to <paramref name="quality"/> relative to the best known quality of the
        /// problem and returns the result as "#RRGGBB".
        /// </summary>
        /// <param name="weakColor">Colour used for a relative quality of 0.</param>
        /// <param name="strongColor">Colour used for a relative quality of 1.</param>
        /// <param name="quality">Quality of the node to colour.</param>
        /// <returns>Hexadecimal colour string.</returns>
        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality) {
            // convert quality to value between 0 and 1
            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
            double q = quality / bestKnownQuality;

            // The ratio is not bounded by itself (quality above the best known value, negative
            // values, division by zero). Without clamping Convert.ToByte would throw an
            // OverflowException, so force the ratio into [0, 1] and treat NaN as 0.
            if (double.IsNaN(q)) {
                q = 0.0;
            }
            q = Math.Max(0.0, Math.Min(1.0, q));

            byte newR = BlendChannel(weakColor.R, strongColor.R, q);
            byte newG = BlendChannel(weakColor.G, strongColor.G, q);
            byte newB = BlendChannel(weakColor.B, strongColor.B, q);

            return HexConverter(Color.FromArgb(newR, newG, newB));
        }

        /// <summary>Drops the reference to the search tree and requests a garbage collection.</summary>
        public void FreeAll() {
            rootNode = null;
            // Explicit collection kept from the original implementation.
            GC.Collect();
        }
    }
}