Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12747

Last change on this file since 12747 was 12503, checked in by aballeit, 10 years ago

#2283 added GUI and charts; fixed MCTS

File size: 6.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using HeuristicLab.Algorithms.Bandits;
5using HeuristicLab.Algorithms.GrammaticalOptimization;
6using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
7using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
8using HeuristicLab.Common;
9using HeuristicLab.Problems.GrammaticalOptimization;
10
11namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
12{
13    public class MonteCarloTreeSearch : SolverBase
14    {
15        private readonly int maxLen;
16        private readonly IProblem problem;
17        private readonly IGrammar grammar;
18        private readonly Random random;
19        private readonly IBanditPolicy behaviourPolicy;
20        private readonly ISimulation simulation;
21        private TreeNode rootNode;
22
23        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
24        {
25            this.problem = problem;
26            this.grammar = problem.Grammar;
27            this.maxLen = maxLen;
28            this.random = random;
29            this.behaviourPolicy = behaviourPolicy;
30            this.simulation = simulationPolicy;
31        }
32
33        public bool StopRequested
34        {
35            get;
36            set;
37        }
38
39        public override void Run(int maxIterations)
40        {
41            Reset();
42            for (int i = 0; !StopRequested && i < maxIterations; i++)
43            {
44                TreeNode currentNode = rootNode;
45
46                while (!currentNode.IsLeaf())
47                {
48                    int currentActionIndex = behaviourPolicy.SelectAction(random,
49                        currentNode.GetChildActionInfos());
50                    currentNode = currentNode.children[currentActionIndex];
51                }
52
53                string phrase = currentNode.phrase;
54
55                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
56                {
57                    ExpandTreeNode(currentNode);
58
59                    currentNode =
60                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())];
61                }
62                if (currentNode.phrase.Length <= maxLen)
63                {
64                    double quality = simulation.Simulate(currentNode);
65                    OnSolutionEvaluated(phrase, quality);
66
67                    Propagate(currentNode, quality);
68                }
69            }
70        }
71
72        private void ExpandTreeNode(TreeNode treeNode)
73        {
74            // create children on the first visit
75            if (treeNode.children == null)
76            {
77                treeNode.children = new List<TreeNode>();
78
79                var phrase = new Sequence(treeNode.phrase);
80                // create subnodes for each nt-symbol in phrase
81                for (int i = 0; i < phrase.Length; i++)
82                {
83                    char symbol = phrase[i];
84                    if (grammar.IsNonTerminal(symbol))
85                    {
86                        // create subnode for each alternative of symbol
87                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
88                        {
89                            Sequence newSequence = new Sequence(phrase);
90                            newSequence.ReplaceAt(i, 1, alternative);
91                            if (newSequence.Length <= maxLen)
92                            {
93                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(), behaviourPolicy.CreateActionInfo());
94                                treeNode.children.Add(childNode);
95                            }
96                        }
97                    }
98                }
99            }
100        }
101
102        private void Reset()
103        {
104            StopRequested = false;
105            bestQuality = 0.0;
106            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo());
107        }
108
109        private void Propagate(TreeNode node, double quality)
110        {
111            var currentNode = node;
112            do
113            {
114                currentNode.actionInfo.UpdateReward(quality);
115                currentNode = currentNode.parent;
116            } while (currentNode != null);
117        }
118
119        public void PrintStats()
120        {
121            //Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);
122
123            //// use behaviour strategy to generate the currently prefered sentence
124            //var policy = behaviourPolicy;
125
126            //var n = rootNode;
127
128            //while (n != null)
129            //{
130            //    var phrase = n.phrase;
131            //    Console.ForegroundColor = ConsoleColor.White;
132            //    Console.WriteLine("{0,-30}", phrase);
133            //    var children = n.children;
134            //    if (children == null || !children.Any()) break;
135            //    var values = children.Select(ch => policy.GetValue(ch.phrase));
136            //    var maxValue = values.Max();
137            //    if (maxValue == 0) maxValue = 1.0;
138
139            //    // write phrases
140            //    foreach (var ch in children)
141            //    {
142            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
143            //        Console.Write(" {0,-4}", ch.phrase.Substring(Math.Max(0, ch.phrase.Length - 3), Math.Min(3, ch.phrase.Length)));
144            //    }
145            //    Console.WriteLine();
146
147            //    // write values
148            //    foreach (var ch in children)
149            //    {
150            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
151            //        Console.Write(" {0:F2}", policy.GetValue(ch.phrase) * 10.0);
152            //    }
153            //    Console.WriteLine();
154
155            //    // write tries
156            //    foreach (var ch in children)
157            //    {
158            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
159            //        Console.Write(" {0,4}", policy.GetTries(ch.phrase));
160            //    }
161            //    Console.WriteLine();
162            //    int selectedChildIdx;
163            //    if (!policy.TrySelect(random, phrase, children.Select(ch => ch.phrase), out selectedChildIdx))
164            //    {
165            //        break;
166            //    }
167            //    n = n.children[selectedChildIdx];
168            //}
169
170            //Console.ForegroundColor = ConsoleColor.White;
171            //Console.WriteLine("-------------------");
172        }
173
174        private void SetColorForValue(double v)
175        {
176            Console.ForegroundColor = ConsoleEx.ColorForValue(v);
177        }
178    }
179}
Note: See TracBrowser for help on using the repository browser.