using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Algorithms.Bandits;
using HeuristicLab.Algorithms.GrammaticalOptimization;
using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
using HeuristicLab.Common;
using HeuristicLab.Problems.GrammaticalOptimization;

namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
{
    /// <summary>
    /// Monte Carlo tree search over the phrases of a grammar. Each tree node holds a phrase; its
    /// children are the phrases obtained by replacing a single non-terminal symbol with one of that
    /// symbol's alternatives, keeping only phrases of at most <c>maxLen</c> characters.
    /// </summary>
    public class MonteCarloTreeSearch : SolverBase
    {
        private readonly int maxLen;
        private readonly IProblem problem;
        private readonly IGrammar grammar;
        private readonly Random random;
        private readonly IBanditPolicy behaviourPolicy;
        private readonly ISimulation simulation;
        private TreeNode rootNode;

        /// <summary>
        /// Creates a new search instance.
        /// </summary>
        /// <param name="problem">Problem whose <c>Grammar</c> defines the search space.</param>
        /// <param name="maxLen">Maximum phrase length; longer phrases are never added to the tree.</param>
        /// <param name="random">Random source handed to the behaviour policy.</param>
        /// <param name="behaviourPolicy">Bandit policy used to pick a child at every tree level.</param>
        /// <param name="simulationPolicy">Simulation used to estimate the quality of a tree node.</param>
        /// <exception cref="ArgumentNullException">If any reference argument is null.</exception>
        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
        {
            if (problem == null)
            {
                throw new ArgumentNullException("problem");
            }
            if (random == null)
            {
                throw new ArgumentNullException("random");
            }
            if (behaviourPolicy == null)
            {
                throw new ArgumentNullException("behaviourPolicy");
            }
            if (simulationPolicy == null)
            {
                throw new ArgumentNullException("simulationPolicy");
            }

            this.problem = problem;
            this.grammar = problem.Grammar;
            this.maxLen = maxLen;
            this.random = random;
            this.behaviourPolicy = behaviourPolicy;
            this.simulation = simulationPolicy;
        }

        /// <summary>
        /// When set to true, <see cref="Run"/> stops before starting its next iteration.
        /// Note that <see cref="Run"/> clears this flag (via <c>Reset</c>) when it starts.
        /// </summary>
        public bool StopRequested { get; set; }

        /// <summary>
        /// Rebuilds the tree from the sentence symbol and runs up to <paramref name="maxIterations"/>
        /// select / expand / simulate / propagate iterations.
        /// </summary>
        /// <param name="maxIterations">Upper bound on the number of iterations.</param>
        public override void Run(int maxIterations)
        {
            Reset();

            for (int i = 0; !StopRequested && i < maxIterations; i++)
            {
                // Selection: follow the behaviour policy down to a leaf.
                TreeNode currentNode = rootNode;
                while (!currentNode.IsLeaf())
                {
                    int currentActionIndex = behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos());
                    currentNode = currentNode.children[currentActionIndex];
                }

                // Expansion: a leaf that still contains non-terminals gets its children created,
                // and the policy picks one of them as the node to simulate from.
                string phrase = currentNode.phrase;
                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
                {
                    ExpandTreeNode(currentNode);

                    // If every one-step derivation exceeds maxLen, no children exist and selecting
                    // from an empty action list would fail; stay on this node and simulate it like
                    // any other unexpandable leaf.
                    // NOTE(review): confirm ISimulation copes with such a dead-end phrase.
                    if (currentNode.children.Count > 0)
                    {
                        currentNode = currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())];
                    }
                }

                // Simulation and backpropagation.
                if (currentNode.phrase.Length <= maxLen)
                {
                    double quality = simulation.Simulate(currentNode);
                    // Report the phrase of the node that was actually simulated (not the
                    // pre-expansion leaf), so the phrase and its quality belong together.
                    // NOTE(review): Simulate does not expose the completed sentence it evaluated,
                    // so this may still be a partial derivation.
                    OnSolutionEvaluated(currentNode.phrase, quality);

                    Propagate(currentNode, quality);
                }
            }
        }

        /// <summary>
        /// Creates the children of <paramref name="treeNode"/> on its first visit: for every
        /// non-terminal position in the phrase and every alternative of that symbol, one child whose
        /// phrase has that position replaced by the alternative, if the result fits within maxLen.
        /// Does nothing if the children list already exists.
        /// </summary>
        private void ExpandTreeNode(TreeNode treeNode)
        {
            // create children on the first visit
            if (treeNode.children == null)
            {
                treeNode.children = new List<TreeNode>();
                var phrase = new Sequence(treeNode.phrase);
                // create subnodes for each nt-symbol in phrase
                for (int i = 0; i < phrase.Length; i++)
                {
                    char symbol = phrase[i];
                    if (grammar.IsNonTerminal(symbol))
                    {
                        // create subnode for each alternative of symbol
                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
                        {
                            Sequence newSequence = new Sequence(phrase);
                            newSequence.ReplaceAt(i, 1, alternative);
                            if (newSequence.Length <= maxLen)
                            {
                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(), behaviourPolicy.CreateActionInfo());
                                treeNode.children.Add(childNode);
                            }
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Clears the stop flag, resets <c>bestQuality</c> to 0.0 and starts a fresh tree whose root
        /// holds the grammar's sentence symbol.
        /// </summary>
        private void Reset()
        {
            StopRequested = false;
            bestQuality = 0.0;
            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo());
        }

        /// <summary>
        /// Feeds <paramref name="quality"/> into the action info of <paramref name="node"/> and of
        /// every ancestor up to and including the root.
        /// </summary>
        private void Propagate(TreeNode node, double quality)
        {
            var currentNode = node;
            do
            {
                currentNode.actionInfo.UpdateReward(quality);
                currentNode = currentNode.parent;
            } while (currentNode != null);
        }

        /// <summary>
        /// Currently a no-op. The previous console dump was written against an older policy API
        /// (phrase-keyed GetValue/GetTries/TrySelect) that this class no longer uses, and it was
        /// commented out; it can be recovered from version control if a new implementation is needed.
        /// </summary>
        public void PrintStats()
        {
        }

        // Maps a normalized value to a console colour. Only used by the former PrintStats output.
        private void SetColorForValue(double v)
        {
            Console.ForegroundColor = ConsoleEx.ColorForValue(v);
        }
    }
}