source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12827

Last change on this file since 12827 was 12827, checked in by aballeit, 6 years ago

#2283 added TreeInfos to MCTS excel export

File size: 14.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Drawing;
5using System.Linq;
6using System.Net.Mail;
7using System.Text;
8using System.Threading;
9using GraphVizWrapper;
10using GraphVizWrapper.Commands;
11using GraphVizWrapper.Queries;
12using HeuristicLab.Algorithms.Bandits;
13using HeuristicLab.Algorithms.GrammaticalOptimization;
14using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
15using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
16using HeuristicLab.Common;
17using HeuristicLab.Problems.GrammaticalOptimization;
18
19namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
20{
21    public class MonteCarloTreeSearch : SolverBase
22    {
23        protected readonly int maxLen;
24        protected readonly ISymbolicExpressionTreeProblem problem;
25        protected readonly IGrammar grammar;
26        protected readonly Random random;
27        protected readonly IBanditPolicy behaviourPolicy;
28        protected readonly ISimulation simulation;
29        protected TreeNode rootNode;
30        protected bool isPaused = false;
31        protected object pauseLock = new object();
32        protected int goodSelections;
33
34        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
35            ISimulation simulationPolicy)
36        {
37            this.problem = problem;
38            this.grammar = problem.Grammar;
39            this.maxLen = maxLen;
40            this.random = random;
41            this.behaviourPolicy = behaviourPolicy;
42            this.simulation = simulationPolicy;
43
44            this.problem.GenerateProblemSolutions(maxLen);
45        }
46
47        public event Action<int, int> NodeSelected;
48
49        protected virtual void OnNodeSelectedChanged(int value, int selections)
50        {
51            RaiseNodeSelectedChanged(value, selections);
52        }
53
54        private void RaiseNodeSelectedChanged(int value, int selections)
55        {
56            var handler = NodeSelected;
57            if (handler != null) handler(value, selections);
58        }
59
60        public bool StopRequested { get; set; }
61
62        public bool IsPaused
63        {
64            get { return this.isPaused; }
65        }
66
67        public int GoodSelections { get { return this.goodSelections; } }
68
69        protected void CheckSelection(TreeNode node, int selections)
70        {
71            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
72            {
73                goodSelections++;
74            }
75
76            OnNodeSelectedChanged(goodSelections, selections);
77        }
78
79        public override void Run(int maxIterations)
80        {
81            Reset();
82            int selections = 0;
83            for (int i = 0; !StopRequested && i < maxIterations; i++)
84            {
85                lock (pauseLock)
86                {
87                    if (isPaused)
88                    {
89                        Monitor.Wait(pauseLock);
90                    }
91                }
92                TreeNode currentNode = rootNode;
93
94                while (!currentNode.IsLeaf())
95                {
96                    int currentActionIndex = behaviourPolicy.SelectAction(random,
97                        currentNode.GetChildActionInfos());
98                    currentNode = currentNode.children[currentActionIndex];
99                    //selections++;
100                    //CheckSelection(currentNode, selections);
101                }
102
103                string phrase = currentNode.phrase;
104
105                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
106                {
107                    ExpandTreeNode(currentNode);
108                    currentNode =
109                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())
110                            ];
111                    //selections++;
112                    //CheckSelection(currentNode, selections);
113                }
114                if (currentNode.phrase.Length <= maxLen)
115                {
116                    string simulatedPhrase;
117                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
118                    OnSolutionEvaluated(simulatedPhrase, quality);
119
120                    Propagate(currentNode, quality);
121                }
122            }
123        }
124
125        protected void ExpandTreeNode(TreeNode treeNode)
126        {
127            // create children on the first visit
128            if (treeNode.children == null)
129            {
130                treeNode.children = new List<TreeNode>();
131
132                var phrase = new Sequence(treeNode.phrase);
133                // create subnodes for each nt-symbol in phrase
134                for (int i = 0; i < phrase.Length; i++)
135                {
136                    char symbol = phrase[i];
137                    if (grammar.IsNonTerminal(symbol))
138                    {
139                        // create subnode for each alternative of symbol
140                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
141                        {
142                            Sequence newSequence = new Sequence(phrase);
143                            newSequence.ReplaceAt(i, 1, alternative);
144                            if (newSequence.Length <= maxLen)
145                            {
146                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(),
147                                    behaviourPolicy.CreateActionInfo(), treeNode.level + 1);
148                                treeNode.children.Add(childNode);
149                            }
150                        }
151                    }
152                }
153            }
154        }
155
156        protected void Reset()
157        {
158            goodSelections = 0;
159            StopRequested = false;
160            bestQuality = 0.0;
161            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
162        }
163
164        protected void Propagate(TreeNode node, double quality)
165        {
166            var currentNode = node;
167            do
168            {
169                currentNode.actionInfo.UpdateReward(quality);
170                currentNode = currentNode.parent;
171            } while (currentNode != null);
172        }
173
174        private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
175        {
176            treeInfos.TotalNodes += children.Count;
177            foreach (TreeNode child in children)
178            {
179                if (child.children != null)
180                {
181                    treeInfos.ExpandedNodes++;
182                    if (treeInfos.DeepestLevel <= child.level)
183                    {
184                        treeInfos.DeepestLevel = child.level + 1;
185                    }
186                    GetTreeInfosRek(treeInfos, child.children);
187                }
188                else
189                {
190                    if (grammar.IsTerminal(child.phrase))
191                    {
192                        treeInfos.LeaveNodes++;
193                    }
194                    else
195                    {
196                        treeInfos.UnexpandedNodes++;
197                    }
198                }
199            }
200        }
201
202        public TreeInfos GetTreeInfos()
203        {
204            TreeInfos treeInfos = new TreeInfos();
205
206            if (rootNode != null)
207            {
208                treeInfos.TotalNodes++;
209                if (rootNode.children != null)
210                {
211                    treeInfos.ExpandedNodes++;
212                    treeInfos.DeepestLevel = rootNode.level + 1;
213                    GetTreeInfosRek(treeInfos, rootNode.children);
214                }
215                else
216                {
217                    treeInfos.DeepestLevel = rootNode.level;
218                    if (grammar.IsTerminal(rootNode.phrase))
219                    {
220                        treeInfos.LeaveNodes++;
221                    }
222                    else
223                    {
224                        treeInfos.UnexpandedNodes++;
225                    }
226                }
227            }
228            return treeInfos;
229        }
230
231        public byte[] GenerateSvg()
232        {
233            if (GetTreeInfos().TotalNodes < 1000)
234            {
235                IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
236                IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
237                IRegisterLayoutPluginCommand registerLayoutPluginCommand =
238                    new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
239
240                GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
241                    getProcessStartInfoQuery,
242                    registerLayoutPluginCommand);
243                wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";
244                StringBuilder dotFile = new StringBuilder("digraph {");
245                dotFile.AppendLine();
246                dotFile.AppendLine("splines=ortho;");
247                dotFile.AppendLine("concentrate=true;");
248                dotFile.AppendLine("ranksep=1.2;");
249
250                List<TreeNode> toDoNodes = new List<TreeNode>();
251                if (rootNode != null)
252                {
253                    toDoNodes.Add(rootNode);
254                    // declare node
255                    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, rootNode.actionInfo.Value);
256
257                    dotFile.AppendLine(
258                        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
259                            rootNode.GetHashCode(),
260                            rootNode.phrase, rootNode.actionInfo.Value, rootNode.actionInfo.Tries, hexColor));
261                }
262
263                // to put nodes on the same level in graph
264                Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
265
266                List<TreeNode> sameLevelNodes;
267
268                while (toDoNodes.Any())
269                {
270                    TreeNode currentNode = toDoNodes[0];
271                    toDoNodes.RemoveAt(0);
272                    // put currentNode into levelMap
273                    if (levelMap.TryGetValue(currentNode.level, out sameLevelNodes))
274                    {
275                        sameLevelNodes.Add(currentNode);
276                    }
277                    else
278                    {
279                        sameLevelNodes = new List<TreeNode>();
280                        sameLevelNodes.Add(currentNode);
281                        levelMap.Add(currentNode.level, sameLevelNodes);
282                    }
283
284                    // draw line from current node to all its children
285                    if (currentNode.children != null)
286                    {
287                        foreach (TreeNode childNode in currentNode.children)
288                        {
289                            toDoNodes.Add(childNode);
290                            // declare node
291
292                            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, childNode.actionInfo.Value);
293                            dotFile.AppendLine(
294                                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
295                                    childNode.GetHashCode(),
296                                    childNode.phrase, childNode.actionInfo.Value, childNode.actionInfo.Tries, hexColor));
297                            // add edge
298                            dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(),
299                                childNode.GetHashCode()));
300                        }
301                    }
302                }
303
304                // set same level ranks..
305                foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap)
306                {
307                    dotFile.Append("{rank = same;");
308                    foreach (TreeNode node in entry.Value)
309                    {
310                        dotFile.Append(string.Format(" {0};", node.GetHashCode()));
311                    }
312                    dotFile.AppendLine("}");
313                }
314
315                dotFile.Append(" }");
316                byte[] output = wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
317                return output;
318            }
319            return null;
320        }
321
322        protected String HexConverter(Color c)
323        {
324            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
325        }
326
327        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
328        {
329            // convert quality to value between 0 and 1
330            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
331            double q = quality / bestKnownQuality;
332
333            // calculate difference between colors
334            byte rDiff = (byte)Math.Abs(weakColor.R - strongColor.R);
335            byte bDiff = (byte)Math.Abs(weakColor.B - strongColor.B);
336            byte gDiff = (byte)Math.Abs(weakColor.G - strongColor.G);
337
338            byte newR = weakColor.R > strongColor.R
339                ? Convert.ToByte(weakColor.R - Math.Round(rDiff * q))
340                : Convert.ToByte(weakColor.R + Math.Round(rDiff * q));
341
342            byte newB = weakColor.B > strongColor.B
343                ? Convert.ToByte(weakColor.B - Math.Round(bDiff * q))
344                : Convert.ToByte(weakColor.B + Math.Round(bDiff * q));
345
346            byte newG = weakColor.G > strongColor.G
347                ? Convert.ToByte(weakColor.G - Math.Round(gDiff * q))
348                : Convert.ToByte(weakColor.G + Math.Round(gDiff * q));
349
350            return HexConverter(Color.FromArgb(newR, newG, newB));
351        }
352
353        public void PauseContinue()
354        {
355            lock (pauseLock)
356            {
357                if (isPaused)
358                {
359                    isPaused = false;
360                    Monitor.Pulse(pauseLock);
361                }
362                else
363                {
364                    isPaused = true;
365                }
366            }
367        }
368    }
369}
Note: See TracBrowser for help on using the repository browser.