Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12098

Last change on this file since 12098 was 12098, checked in by aballeit, 9 years ago

#2283: implemented MCTS

File size: 6.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Algorithms.Bandits;
5using HeuristicLab.Algorithms.GrammaticalOptimization;
6using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
7using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
8using HeuristicLab.Common;
9using HeuristicLab.Problems.GrammaticalOptimization;
10
11namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
12{
13    public class MonteCarloTreeSearch : SolverBase
14    {
15        private readonly int maxLen;
16        private readonly IProblem problem;
17        private readonly IGrammar grammar;
18        private readonly Random random;
19        private readonly IBanditPolicy behaviourPolicy;
20        private readonly ISimulation simulation;
21        private TreeNode rootNode;
22
23        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
24        {
25            this.problem = problem;
26            this.grammar = problem.Grammar;
27            this.maxLen = maxLen;
28            this.random = random;
29            this.behaviourPolicy = behaviourPolicy;
30            this.simulation = simulationPolicy;
31        }
32
33        public bool StopRequested
34        {
35            get;
36            set;
37        }
38
39        public override void Run(int maxIterations)
40        {
41            Reset();
42            for (int i = 0; !StopRequested && i < maxIterations; i++)
43            {
44                TreeNode currentNode = rootNode;
45
46                while (!currentNode.IsLeaf())
47                {
48                    int currentActionIndex = behaviourPolicy.SelectAction(random,
49                        currentNode.GetChildActionInfos());
50                    currentNode = currentNode.children[currentActionIndex];
51                }
52
53                string phrase = currentNode.phrase;
54
55                if (!grammar.IsTerminal(phrase))
56                {
57                    ExpandTreeNode(currentNode);
58
59                    currentNode =
60                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())];
61                }
62                double quality = simulation.Simulate(currentNode);
63                OnSolutionEvaluated(phrase, quality);
64
65                Propagate(currentNode, quality);
66            }
67        }
68
69        private void ExpandTreeNode(TreeNode treeNode)
70        {
71            // create children on the first visit
72            if (treeNode.children == null)
73            {
74                treeNode.children = new List<TreeNode>();
75
76                var phrase = new Sequence(treeNode.phrase);
77                // create subnodes for each nt-symbol in phrase
78                for (int i = 0; i < phrase.Length; i++)
79                {
80                    char symbol = phrase[i];
81                    if (grammar.IsNonTerminal(symbol))
82                    {
83                        // create subnode for each alternative of symbol
84                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
85                        {
86                            Sequence newSequence = new Sequence(phrase);
87                            newSequence.ReplaceAt(i, 1, alternative);
88                            if (newSequence.Length <= maxLen)
89                            {
90                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString());
91                                treeNode.children.Add(childNode);
92                            }
93                        }
94                    }
95                }
96            }
97        }
98
99        private void Reset()
100        {
101            StopRequested = false;
102            bestQuality = 0.0;
103            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString());
104        }
105
106        private void Propagate(TreeNode node, double quality)
107        {
108            var currentNode = node;
109            do
110            {
111                currentNode.actionInfo.UpdateReward(quality);
112                currentNode = currentNode.parent;
113            } while (currentNode != null);
114        }
115
116        public void PrintStats()
117        {
118            //Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);
119
120            //// use behaviour strategy to generate the currently prefered sentence
121            //var policy = behaviourPolicy;
122
123            //var n = rootNode;
124
125            //while (n != null)
126            //{
127            //    var phrase = n.phrase;
128            //    Console.ForegroundColor = ConsoleColor.White;
129            //    Console.WriteLine("{0,-30}", phrase);
130            //    var children = n.children;
131            //    if (children == null || !children.Any()) break;
132            //    var values = children.Select(ch => policy.GetValue(ch.phrase));
133            //    var maxValue = values.Max();
134            //    if (maxValue == 0) maxValue = 1.0;
135
136            //    // write phrases
137            //    foreach (var ch in children)
138            //    {
139            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
140            //        Console.Write(" {0,-4}", ch.phrase.Substring(Math.Max(0, ch.phrase.Length - 3), Math.Min(3, ch.phrase.Length)));
141            //    }
142            //    Console.WriteLine();
143
144            //    // write values
145            //    foreach (var ch in children)
146            //    {
147            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
148            //        Console.Write(" {0:F2}", policy.GetValue(ch.phrase) * 10.0);
149            //    }
150            //    Console.WriteLine();
151
152            //    // write tries
153            //    foreach (var ch in children)
154            //    {
155            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
156            //        Console.Write(" {0,4}", policy.GetTries(ch.phrase));
157            //    }
158            //    Console.WriteLine();
159            //    int selectedChildIdx;
160            //    if (!policy.TrySelect(random, phrase, children.Select(ch => ch.phrase), out selectedChildIdx))
161            //    {
162            //        break;
163            //    }
164            //    n = n.children[selectedChildIdx];
165            //}
166
167            //Console.ForegroundColor = ConsoleColor.White;
168            //Console.WriteLine("-------------------");
169        }
170
171        private void SetColorForValue(double v)
172        {
173            Console.ForegroundColor = ConsoleEx.ColorForValue(v);
174        }
175    }
176}
Note: See TracBrowser for help on using the repository browser.