Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 13825

Last change on this file since 13825 was 13492, checked in by aballeit, 9 years ago

#2283 UCT parameter c

File size: 13.9 KB
RevLine 
[12050]1using System;
2using System.Collections.Generic;
[12815]3using System.Diagnostics;
[12762]4using System.Drawing;
[12050]5using System.Linq;
[12762]6using System.Net.Mail;
7using System.Text;
8using System.Threading;
9using GraphVizWrapper;
10using GraphVizWrapper.Commands;
11using GraphVizWrapper.Queries;
[12050]12using HeuristicLab.Algorithms.Bandits;
[12098]13using HeuristicLab.Algorithms.GrammaticalOptimization;
14using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
15using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
16using HeuristicLab.Common;
[12050]17using HeuristicLab.Problems.GrammaticalOptimization;
18
[12098]19namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
[12050]20{
21    public class MonteCarloTreeSearch : SolverBase
22    {
[12781]23        protected readonly int maxLen;
[12815]24        protected readonly ISymbolicExpressionTreeProblem problem;
[12781]25        protected readonly IGrammar grammar;
26        protected readonly Random random;
27        protected readonly IBanditPolicy behaviourPolicy;
28        protected readonly ISimulation simulation;
29        protected TreeNode rootNode;
[12815]30        protected int goodSelections;
[12050]31
[12815]32        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
[12762]33            ISimulation simulationPolicy)
[12050]34        {
35            this.problem = problem;
[12098]36            this.grammar = problem.Grammar;
[12050]37            this.maxLen = maxLen;
38            this.random = random;
39            this.behaviourPolicy = behaviourPolicy;
[12098]40            this.simulation = simulationPolicy;
[12815]41
42            this.problem.GenerateProblemSolutions(maxLen);
[12050]43        }
44
[12840]45        public event Action<int, int> IterationFinished;
[12815]46
[12840]47        protected virtual void OnIterationFinishedChanged(int value, int selections)
[12815]48        {
[12840]49            RaiseIterationFinishedChanged(value, selections);
[12815]50        }
51
[12840]52        private void RaiseIterationFinishedChanged(int value, int selections)
[12815]53        {
[12840]54            var handler = IterationFinished;
[12815]55            if (handler != null) handler(value, selections);
56        }
57
[12762]58        public bool StopRequested { get; set; }
[12829]59       
[12815]60        public int GoodSelections { get { return this.goodSelections; } }
61
62        protected void CheckSelection(TreeNode node, int selections)
63        {
64            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
65            {
66                goodSelections++;
67            }
68        }
69
[12050]70        public override void Run(int maxIterations)
71        {
72            Reset();
[12815]73            int selections = 0;
[12829]74            TreeNode currentNode;
75            string phrase;
76            string simulatedPhrase;
77            double quality;
78
[12098]79            for (int i = 0; !StopRequested && i < maxIterations; i++)
[12050]80            {
[12829]81                currentNode = rootNode;
[12098]82
83                while (!currentNode.IsLeaf())
[12050]84                {
[12098]85                    int currentActionIndex = behaviourPolicy.SelectAction(random,
86                        currentNode.GetChildActionInfos());
87                    currentNode = currentNode.children[currentActionIndex];
[13492]88                    //selections++;
89                    //CheckSelection(currentNode, selections);
[12098]90                }
[12050]91
[12829]92                phrase = currentNode.phrase;
[12098]93
[12781]94                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
[12098]95                {
96                    ExpandTreeNode(currentNode);
[12781]97                    currentNode =
98                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())
99                            ];
[13492]100                    //selections++;
101                    //CheckSelection(currentNode, selections);
[12098]102                }
[12503]103                if (currentNode.phrase.Length <= maxLen)
104                {
[12829]105                    quality = simulation.Simulate(currentNode, out simulatedPhrase);
[12762]106                    OnSolutionEvaluated(simulatedPhrase, quality);
[12098]107
[12840]108                    OnIterationFinishedChanged(goodSelections, selections);
[12503]109                    Propagate(currentNode, quality);
110                }
[12050]111            }
112        }
113
[12781]114        protected void ExpandTreeNode(TreeNode treeNode)
[12762]115        {
[12098]116            // create children on the first visit
[12829]117            Sequence newSequence;
118            TreeNode childNode;
[12098]119            if (treeNode.children == null)
120            {
[12832]121                List<TreeNode> newChildren = new List<TreeNode>();
[12098]122
123                var phrase = new Sequence(treeNode.phrase);
124                // create subnodes for each nt-symbol in phrase
125                for (int i = 0; i < phrase.Length; i++)
126                {
127                    char symbol = phrase[i];
128                    if (grammar.IsNonTerminal(symbol))
129                    {
130                        // create subnode for each alternative of symbol
131                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
132                        {
[12829]133                            newSequence = new Sequence(phrase);
[12098]134                            newSequence.ReplaceAt(i, 1, alternative);
135                            if (newSequence.Length <= maxLen)
136                            {
[12829]137                                childNode = new TreeNode(treeNode, newSequence.ToString(),
[12832]138                                    behaviourPolicy.CreateActionInfo(), (ushort) (treeNode.level + 1));
139                                newChildren.Add(childNode);
[12098]140                            }
141                        }
142                    }
143                }
[12832]144                treeNode.children = newChildren.ToArray();
[12098]145            }
146        }
147
[12781]148        protected void Reset()
[12050]149        {
[12815]150            goodSelections = 0;
[12050]151            StopRequested = false;
152            bestQuality = 0.0;
[12762]153            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
[12050]154        }
155
[12781]156        protected void Propagate(TreeNode node, double quality)
[12050]157        {
[12098]158            var currentNode = node;
159            do
160            {
161                currentNode.actionInfo.UpdateReward(quality);
162                currentNode = currentNode.parent;
163            } while (currentNode != null);
[12050]164        }
165
[12832]166        private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children)
[12781]167        {
[12832]168            treeInfos.TotalNodes += children.Length;
[12781]169            foreach (TreeNode child in children)
170            {
171                if (child.children != null)
172                {
173                    treeInfos.ExpandedNodes++;
174                    if (treeInfos.DeepestLevel <= child.level)
175                    {
176                        treeInfos.DeepestLevel = child.level + 1;
177                    }
178                    GetTreeInfosRek(treeInfos, child.children);
179                }
180                else
181                {
182                    if (grammar.IsTerminal(child.phrase))
183                    {
184                        treeInfos.LeaveNodes++;
185                    }
186                    else
187                    {
188                        treeInfos.UnexpandedNodes++;
189                    }
190                }
191            }
192        }
193
[12762]194        public TreeInfos GetTreeInfos()
[12050]195        {
[12781]196            TreeInfos treeInfos = new TreeInfos();
[12098]197
[12762]198            if (rootNode != null)
199            {
[12781]200                treeInfos.TotalNodes++;
201                if (rootNode.children != null)
[12762]202                {
[12781]203                    treeInfos.ExpandedNodes++;
204                    treeInfos.DeepestLevel = rootNode.level + 1;
205                    GetTreeInfosRek(treeInfos, rootNode.children);
[12762]206                }
207                else
208                {
[12781]209                    treeInfos.DeepestLevel = rootNode.level;
210                    if (grammar.IsTerminal(rootNode.phrase))
[12762]211                    {
[12781]212                        treeInfos.LeaveNodes++;
[12762]213                    }
214                    else
215                    {
[12781]216                        treeInfos.UnexpandedNodes++;
[12762]217                    }
218                }
219            }
[12781]220            return treeInfos;
[12762]221        }
[12098]222
[12762]223        public byte[] GenerateSvg()
224        {
[13492]225            if (GetTreeInfos().TotalNodes < 6000)
[12781]226            {
227                IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
228                IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
229                IRegisterLayoutPluginCommand registerLayoutPluginCommand =
230                    new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
[12098]231
[12781]232                GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
233                    getProcessStartInfoQuery,
234                    registerLayoutPluginCommand);
[13492]235                wrapper.GraphvizPath = @"../../../../Graphviz2.38/bin";
[12781]236                StringBuilder dotFile = new StringBuilder("digraph {");
237                dotFile.AppendLine();
238                dotFile.AppendLine("splines=ortho;");
239                dotFile.AppendLine("concentrate=true;");
240                dotFile.AppendLine("ranksep=1.2;");
[12098]241
[12781]242                List<TreeNode> toDoNodes = new List<TreeNode>();
243                if (rootNode != null)
244                {
245                    toDoNodes.Add(rootNode);
246                    // declare node
247                    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, rootNode.actionInfo.Value);
[12098]248
[12781]249                    dotFile.AppendLine(
250                        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
251                            rootNode.GetHashCode(),
252                            rootNode.phrase, rootNode.actionInfo.Value, rootNode.actionInfo.Tries, hexColor));
253                }
[12098]254
[12781]255                // to put nodes on the same level in graph
256                Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
[12762]257
[12781]258                List<TreeNode> sameLevelNodes;
[12762]259
[12781]260                while (toDoNodes.Any())
[12762]261                {
[12781]262                    TreeNode currentNode = toDoNodes[0];
263                    toDoNodes.RemoveAt(0);
264                    // put currentNode into levelMap
265                    if (levelMap.TryGetValue(currentNode.level, out sameLevelNodes))
266                    {
267                        sameLevelNodes.Add(currentNode);
268                    }
269                    else
270                    {
271                        sameLevelNodes = new List<TreeNode>();
272                        sameLevelNodes.Add(currentNode);
273                        levelMap.Add(currentNode.level, sameLevelNodes);
274                    }
[12762]275
[12781]276                    // draw line from current node to all its children
277                    if (currentNode.children != null)
[12762]278                    {
[12781]279                        foreach (TreeNode childNode in currentNode.children)
280                        {
[13492]281                            if (childNode.actionInfo.Tries > 1)
282                            {
283                                toDoNodes.Add(childNode);
284                                // declare node
[12762]285
[13492]286                                string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed,
287                                    childNode.actionInfo.Value);
288                                dotFile.AppendLine(
289                                    string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
290                                        childNode.GetHashCode(),
291                                        childNode.phrase, childNode.actionInfo.Value, childNode.actionInfo.Tries,
292                                        hexColor));
293                                // add edge
294                                dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(),
295                                    childNode.GetHashCode()));
296                            }
[12781]297                        }
[12762]298                    }
299                }
300
[12781]301                // set same level ranks..
302                foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap)
[12762]303                {
[12781]304                    dotFile.Append("{rank = same;");
305                    foreach (TreeNode node in entry.Value)
306                    {
307                        dotFile.Append(string.Format(" {0};", node.GetHashCode()));
308                    }
309                    dotFile.AppendLine("}");
[12762]310                }
[12781]311
312                dotFile.Append(" }");
313                byte[] output = wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
314                return output;
[12762]315            }
[12781]316            return null;
[12050]317        }
318
[12781]319        protected String HexConverter(Color c)
[12050]320        {
[12762]321            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
[12050]322        }
[12762]323
[12781]324        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
[12762]325        {
326            // convert quality to value between 0 and 1
327            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
328            double q = quality / bestKnownQuality;
329
330            // calculate difference between colors
331            byte rDiff = (byte)Math.Abs(weakColor.R - strongColor.R);
332            byte bDiff = (byte)Math.Abs(weakColor.B - strongColor.B);
333            byte gDiff = (byte)Math.Abs(weakColor.G - strongColor.G);
334
335            byte newR = weakColor.R > strongColor.R
336                ? Convert.ToByte(weakColor.R - Math.Round(rDiff * q))
337                : Convert.ToByte(weakColor.R + Math.Round(rDiff * q));
338
339            byte newB = weakColor.B > strongColor.B
340                ? Convert.ToByte(weakColor.B - Math.Round(bDiff * q))
341                : Convert.ToByte(weakColor.B + Math.Round(bDiff * q));
342
343            byte newG = weakColor.G > strongColor.G
344                ? Convert.ToByte(weakColor.G - Math.Round(gDiff * q))
345                : Convert.ToByte(weakColor.G + Math.Round(gDiff * q));
346
347            return HexConverter(Color.FromArgb(newR, newG, newB));
348        }
[12832]349
350        public void FreeAll()
351        {
352            rootNode = null;
353            GC.Collect();
354        }
[12050]355    }
356}
Note: See TracBrowser for help on using the repository browser.