Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12827

Last change on this file since 12827 was 12827, checked in by aballeit, 9 years ago

#2283 added TreeInfos to MCTS excel export

File size: 14.1 KB
RevLine 
[12050]1using System;
2using System.Collections.Generic;
[12815]3using System.Diagnostics;
[12762]4using System.Drawing;
[12050]5using System.Linq;
[12762]6using System.Net.Mail;
7using System.Text;
8using System.Threading;
9using GraphVizWrapper;
10using GraphVizWrapper.Commands;
11using GraphVizWrapper.Queries;
[12050]12using HeuristicLab.Algorithms.Bandits;
[12098]13using HeuristicLab.Algorithms.GrammaticalOptimization;
14using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
15using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
16using HeuristicLab.Common;
[12050]17using HeuristicLab.Problems.GrammaticalOptimization;
18
[12098]19namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
[12050]20{
21    public class MonteCarloTreeSearch : SolverBase
22    {
[12781]23        protected readonly int maxLen;
[12815]24        protected readonly ISymbolicExpressionTreeProblem problem;
[12781]25        protected readonly IGrammar grammar;
26        protected readonly Random random;
27        protected readonly IBanditPolicy behaviourPolicy;
28        protected readonly ISimulation simulation;
29        protected TreeNode rootNode;
30        protected bool isPaused = false;
31        protected object pauseLock = new object();
[12815]32        protected int goodSelections;
[12050]33
[12815]34        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
[12762]35            ISimulation simulationPolicy)
[12050]36        {
37            this.problem = problem;
[12098]38            this.grammar = problem.Grammar;
[12050]39            this.maxLen = maxLen;
40            this.random = random;
41            this.behaviourPolicy = behaviourPolicy;
[12098]42            this.simulation = simulationPolicy;
[12815]43
44            this.problem.GenerateProblemSolutions(maxLen);
[12050]45        }
46
[12815]47        public event Action<int, int> NodeSelected;
48
49        protected virtual void OnNodeSelectedChanged(int value, int selections)
50        {
51            RaiseNodeSelectedChanged(value, selections);
52        }
53
54        private void RaiseNodeSelectedChanged(int value, int selections)
55        {
56            var handler = NodeSelected;
57            if (handler != null) handler(value, selections);
58        }
59
[12762]60        public bool StopRequested { get; set; }
61
62        public bool IsPaused
[12050]63        {
[12762]64            get { return this.isPaused; }
[12050]65        }
66
[12815]67        public int GoodSelections { get { return this.goodSelections; } }
68
69        protected void CheckSelection(TreeNode node, int selections)
70        {
71            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
72            {
73                goodSelections++;
74            }
75
76            OnNodeSelectedChanged(goodSelections, selections);
77        }
78
[12050]79        public override void Run(int maxIterations)
80        {
81            Reset();
[12815]82            int selections = 0;
[12098]83            for (int i = 0; !StopRequested && i < maxIterations; i++)
[12050]84            {
[12762]85                lock (pauseLock)
86                {
87                    if (isPaused)
88                    {
89                        Monitor.Wait(pauseLock);
90                    }
91                }
[12098]92                TreeNode currentNode = rootNode;
93
94                while (!currentNode.IsLeaf())
[12050]95                {
[12098]96                    int currentActionIndex = behaviourPolicy.SelectAction(random,
97                        currentNode.GetChildActionInfos());
98                    currentNode = currentNode.children[currentActionIndex];
[12827]99                    //selections++;
100                    //CheckSelection(currentNode, selections);
[12098]101                }
[12050]102
[12098]103                string phrase = currentNode.phrase;
104
[12781]105                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
[12098]106                {
107                    ExpandTreeNode(currentNode);
[12781]108                    currentNode =
109                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())
110                            ];
[12827]111                    //selections++;
112                    //CheckSelection(currentNode, selections);
[12098]113                }
[12503]114                if (currentNode.phrase.Length <= maxLen)
115                {
[12762]116                    string simulatedPhrase;
117                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
118                    OnSolutionEvaluated(simulatedPhrase, quality);
[12098]119
[12503]120                    Propagate(currentNode, quality);
121                }
[12050]122            }
123        }
124
[12781]125        protected void ExpandTreeNode(TreeNode treeNode)
[12762]126        {
[12098]127            // create children on the first visit
128            if (treeNode.children == null)
129            {
130                treeNode.children = new List<TreeNode>();
131
132                var phrase = new Sequence(treeNode.phrase);
133                // create subnodes for each nt-symbol in phrase
134                for (int i = 0; i < phrase.Length; i++)
135                {
136                    char symbol = phrase[i];
137                    if (grammar.IsNonTerminal(symbol))
138                    {
139                        // create subnode for each alternative of symbol
140                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
141                        {
142                            Sequence newSequence = new Sequence(phrase);
143                            newSequence.ReplaceAt(i, 1, alternative);
144                            if (newSequence.Length <= maxLen)
145                            {
[12762]146                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(),
147                                    behaviourPolicy.CreateActionInfo(), treeNode.level + 1);
[12098]148                                treeNode.children.Add(childNode);
149                            }
150                        }
151                    }
152                }
153            }
154        }
155
[12781]156        protected void Reset()
[12050]157        {
[12815]158            goodSelections = 0;
[12050]159            StopRequested = false;
160            bestQuality = 0.0;
[12762]161            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
[12050]162        }
163
[12781]164        protected void Propagate(TreeNode node, double quality)
[12050]165        {
[12098]166            var currentNode = node;
167            do
168            {
169                currentNode.actionInfo.UpdateReward(quality);
170                currentNode = currentNode.parent;
171            } while (currentNode != null);
[12050]172        }
173
[12781]174        private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
175        {
176            treeInfos.TotalNodes += children.Count;
177            foreach (TreeNode child in children)
178            {
179                if (child.children != null)
180                {
181                    treeInfos.ExpandedNodes++;
182                    if (treeInfos.DeepestLevel <= child.level)
183                    {
184                        treeInfos.DeepestLevel = child.level + 1;
185                    }
186                    GetTreeInfosRek(treeInfos, child.children);
187                }
188                else
189                {
190                    if (grammar.IsTerminal(child.phrase))
191                    {
192                        treeInfos.LeaveNodes++;
193                    }
194                    else
195                    {
196                        treeInfos.UnexpandedNodes++;
197                    }
198                }
199            }
200        }
201
[12762]202        public TreeInfos GetTreeInfos()
[12050]203        {
[12781]204            TreeInfos treeInfos = new TreeInfos();
[12098]205
[12762]206            if (rootNode != null)
207            {
[12781]208                treeInfos.TotalNodes++;
209                if (rootNode.children != null)
[12762]210                {
[12781]211                    treeInfos.ExpandedNodes++;
212                    treeInfos.DeepestLevel = rootNode.level + 1;
213                    GetTreeInfosRek(treeInfos, rootNode.children);
[12762]214                }
215                else
216                {
[12781]217                    treeInfos.DeepestLevel = rootNode.level;
218                    if (grammar.IsTerminal(rootNode.phrase))
[12762]219                    {
[12781]220                        treeInfos.LeaveNodes++;
[12762]221                    }
222                    else
223                    {
[12781]224                        treeInfos.UnexpandedNodes++;
[12762]225                    }
226                }
227            }
[12781]228            return treeInfos;
[12762]229        }
[12098]230
[12762]231        public byte[] GenerateSvg()
232        {
[12781]233            if (GetTreeInfos().TotalNodes < 1000)
234            {
235                IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
236                IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
237                IRegisterLayoutPluginCommand registerLayoutPluginCommand =
238                    new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
[12098]239
[12781]240                GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
241                    getProcessStartInfoQuery,
242                    registerLayoutPluginCommand);
243                wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";
244                StringBuilder dotFile = new StringBuilder("digraph {");
245                dotFile.AppendLine();
246                dotFile.AppendLine("splines=ortho;");
247                dotFile.AppendLine("concentrate=true;");
248                dotFile.AppendLine("ranksep=1.2;");
[12098]249
[12781]250                List<TreeNode> toDoNodes = new List<TreeNode>();
251                if (rootNode != null)
252                {
253                    toDoNodes.Add(rootNode);
254                    // declare node
255                    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, rootNode.actionInfo.Value);
[12098]256
[12781]257                    dotFile.AppendLine(
258                        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
259                            rootNode.GetHashCode(),
260                            rootNode.phrase, rootNode.actionInfo.Value, rootNode.actionInfo.Tries, hexColor));
261                }
[12098]262
[12781]263                // to put nodes on the same level in graph
264                Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
[12762]265
[12781]266                List<TreeNode> sameLevelNodes;
[12762]267
[12781]268                while (toDoNodes.Any())
[12762]269                {
[12781]270                    TreeNode currentNode = toDoNodes[0];
271                    toDoNodes.RemoveAt(0);
272                    // put currentNode into levelMap
273                    if (levelMap.TryGetValue(currentNode.level, out sameLevelNodes))
274                    {
275                        sameLevelNodes.Add(currentNode);
276                    }
277                    else
278                    {
279                        sameLevelNodes = new List<TreeNode>();
280                        sameLevelNodes.Add(currentNode);
281                        levelMap.Add(currentNode.level, sameLevelNodes);
282                    }
[12762]283
[12781]284                    // draw line from current node to all its children
285                    if (currentNode.children != null)
[12762]286                    {
[12781]287                        foreach (TreeNode childNode in currentNode.children)
288                        {
289                            toDoNodes.Add(childNode);
290                            // declare node
[12762]291
[12781]292                            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, childNode.actionInfo.Value);
293                            dotFile.AppendLine(
294                                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
295                                    childNode.GetHashCode(),
296                                    childNode.phrase, childNode.actionInfo.Value, childNode.actionInfo.Tries, hexColor));
297                            // add edge
298                            dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(),
299                                childNode.GetHashCode()));
300                        }
[12762]301                    }
302                }
303
[12781]304                // set same level ranks..
305                foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap)
[12762]306                {
[12781]307                    dotFile.Append("{rank = same;");
308                    foreach (TreeNode node in entry.Value)
309                    {
310                        dotFile.Append(string.Format(" {0};", node.GetHashCode()));
311                    }
312                    dotFile.AppendLine("}");
[12762]313                }
[12781]314
315                dotFile.Append(" }");
316                byte[] output = wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
317                return output;
[12762]318            }
[12781]319            return null;
[12050]320        }
321
[12781]322        protected String HexConverter(Color c)
[12050]323        {
[12762]324            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
[12050]325        }
[12762]326
[12781]327        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
[12762]328        {
329            // convert quality to value between 0 and 1
330            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
331            double q = quality / bestKnownQuality;
332
333            // calculate difference between colors
334            byte rDiff = (byte)Math.Abs(weakColor.R - strongColor.R);
335            byte bDiff = (byte)Math.Abs(weakColor.B - strongColor.B);
336            byte gDiff = (byte)Math.Abs(weakColor.G - strongColor.G);
337
338            byte newR = weakColor.R > strongColor.R
339                ? Convert.ToByte(weakColor.R - Math.Round(rDiff * q))
340                : Convert.ToByte(weakColor.R + Math.Round(rDiff * q));
341
342            byte newB = weakColor.B > strongColor.B
343                ? Convert.ToByte(weakColor.B - Math.Round(bDiff * q))
344                : Convert.ToByte(weakColor.B + Math.Round(bDiff * q));
345
346            byte newG = weakColor.G > strongColor.G
347                ? Convert.ToByte(weakColor.G - Math.Round(gDiff * q))
348                : Convert.ToByte(weakColor.G + Math.Round(gDiff * q));
349
350            return HexConverter(Color.FromArgb(newR, newG, newB));
351        }
352
353        public void PauseContinue()
354        {
355            lock (pauseLock)
356            {
357                if (isPaused)
358                {
359                    isPaused = false;
360                    Monitor.Pulse(pauseLock);
361                }
362                else
363                {
364                    isPaused = true;
365                }
366            }
367        }
[12050]368    }
369}
Note: See TracBrowser for help on using the repository browser.