Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12547

Last change on this file since 12547 was 12503, checked in by aballeit, 10 years ago

#2283 added GUI and charts; fixed MCTS

File size: 6.7 KB
RevLine 
[12050]1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Algorithms.Bandits;
[12098]5using HeuristicLab.Algorithms.GrammaticalOptimization;
6using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
7using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
8using HeuristicLab.Common;
[12050]9using HeuristicLab.Problems.GrammaticalOptimization;
10
[12098]11namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
[12050]12{
13    public class MonteCarloTreeSearch : SolverBase
14    {
15        private readonly int maxLen;
16        private readonly IProblem problem;
[12098]17        private readonly IGrammar grammar;
[12050]18        private readonly Random random;
19        private readonly IBanditPolicy behaviourPolicy;
[12098]20        private readonly ISimulation simulation;
[12050]21        private TreeNode rootNode;
22
[12098]23        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
[12050]24        {
25            this.problem = problem;
[12098]26            this.grammar = problem.Grammar;
[12050]27            this.maxLen = maxLen;
28            this.random = random;
29            this.behaviourPolicy = behaviourPolicy;
[12098]30            this.simulation = simulationPolicy;
[12050]31        }
32
33        public bool StopRequested
34        {
35            get;
36            set;
37        }
38
39        public override void Run(int maxIterations)
40        {
41            Reset();
[12098]42            for (int i = 0; !StopRequested && i < maxIterations; i++)
[12050]43            {
[12098]44                TreeNode currentNode = rootNode;
45
46                while (!currentNode.IsLeaf())
[12050]47                {
[12098]48                    int currentActionIndex = behaviourPolicy.SelectAction(random,
49                        currentNode.GetChildActionInfos());
50                    currentNode = currentNode.children[currentActionIndex];
51                }
[12050]52
[12098]53                string phrase = currentNode.phrase;
54
[12503]55                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
[12098]56                {
57                    ExpandTreeNode(currentNode);
58
59                    currentNode =
60                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())];
61                }
[12503]62                if (currentNode.phrase.Length <= maxLen)
63                {
64                    double quality = simulation.Simulate(currentNode);
65                    OnSolutionEvaluated(phrase, quality);
[12098]66
[12503]67                    Propagate(currentNode, quality);
68                }
[12050]69            }
70        }
71
[12098]72        private void ExpandTreeNode(TreeNode treeNode)
73        {
74            // create children on the first visit
75            if (treeNode.children == null)
76            {
77                treeNode.children = new List<TreeNode>();
78
79                var phrase = new Sequence(treeNode.phrase);
80                // create subnodes for each nt-symbol in phrase
81                for (int i = 0; i < phrase.Length; i++)
82                {
83                    char symbol = phrase[i];
84                    if (grammar.IsNonTerminal(symbol))
85                    {
86                        // create subnode for each alternative of symbol
87                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
88                        {
89                            Sequence newSequence = new Sequence(phrase);
90                            newSequence.ReplaceAt(i, 1, alternative);
91                            if (newSequence.Length <= maxLen)
92                            {
[12503]93                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(), behaviourPolicy.CreateActionInfo());
[12098]94                                treeNode.children.Add(childNode);
95                            }
96                        }
97                    }
98                }
99            }
100        }
101
[12050]102        private void Reset()
103        {
104            StopRequested = false;
105            bestQuality = 0.0;
[12503]106            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo());
[12050]107        }
108
[12098]109        private void Propagate(TreeNode node, double quality)
[12050]110        {
[12098]111            var currentNode = node;
112            do
113            {
114                currentNode.actionInfo.UpdateReward(quality);
115                currentNode = currentNode.parent;
116            } while (currentNode != null);
[12050]117        }
118
[12098]119        public void PrintStats()
[12050]120        {
[12098]121            //Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);
122
123            //// use behaviour strategy to generate the currently prefered sentence
124            //var policy = behaviourPolicy;
125
126            //var n = rootNode;
127
128            //while (n != null)
129            //{
130            //    var phrase = n.phrase;
131            //    Console.ForegroundColor = ConsoleColor.White;
132            //    Console.WriteLine("{0,-30}", phrase);
133            //    var children = n.children;
134            //    if (children == null || !children.Any()) break;
135            //    var values = children.Select(ch => policy.GetValue(ch.phrase));
136            //    var maxValue = values.Max();
137            //    if (maxValue == 0) maxValue = 1.0;
138
139            //    // write phrases
140            //    foreach (var ch in children)
141            //    {
142            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
143            //        Console.Write(" {0,-4}", ch.phrase.Substring(Math.Max(0, ch.phrase.Length - 3), Math.Min(3, ch.phrase.Length)));
144            //    }
145            //    Console.WriteLine();
146
147            //    // write values
148            //    foreach (var ch in children)
149            //    {
150            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
151            //        Console.Write(" {0:F2}", policy.GetValue(ch.phrase) * 10.0);
152            //    }
153            //    Console.WriteLine();
154
155            //    // write tries
156            //    foreach (var ch in children)
157            //    {
158            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
159            //        Console.Write(" {0,4}", policy.GetTries(ch.phrase));
160            //    }
161            //    Console.WriteLine();
162            //    int selectedChildIdx;
163            //    if (!policy.TrySelect(random, phrase, children.Select(ch => ch.phrase), out selectedChildIdx))
164            //    {
165            //        break;
166            //    }
167            //    n = n.children[selectedChildIdx];
168            //}
169
170            //Console.ForegroundColor = ConsoleColor.White;
171            //Console.WriteLine("-------------------");
[12050]172        }
173
[12098]174        private void SetColorForValue(double v)
[12050]175        {
[12098]176            Console.ForegroundColor = ConsoleEx.ColorForValue(v);
[12050]177        }
178    }
179}
Note: See TracBrowser for help on using the repository browser.