Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12803

Last change on this file since 12803 was 12781, checked in by aballeit, 9 years ago

#2283 stable GUI; ThreadPool for runs; improved TreeAnalysis

File size: 12.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Drawing;
4using System.Linq;
5using System.Net.Mail;
6using System.Text;
7using System.Threading;
8using GraphVizWrapper;
9using GraphVizWrapper.Commands;
10using GraphVizWrapper.Queries;
11using HeuristicLab.Algorithms.Bandits;
12using HeuristicLab.Algorithms.GrammaticalOptimization;
13using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
14using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
15using HeuristicLab.Common;
16using HeuristicLab.Problems.GrammaticalOptimization;
17
18namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
19{
20    public class MonteCarloTreeSearch : SolverBase
21    {
22        protected readonly int maxLen;
23        protected readonly IProblem problem;
24        protected readonly IGrammar grammar;
25        protected readonly Random random;
26        protected readonly IBanditPolicy behaviourPolicy;
27        protected readonly ISimulation simulation;
28        protected TreeNode rootNode;
29        protected bool isPaused = false;
30        protected object pauseLock = new object();
31
32        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
33            ISimulation simulationPolicy)
34        {
35            this.problem = problem;
36            this.grammar = problem.Grammar;
37            this.maxLen = maxLen;
38            this.random = random;
39            this.behaviourPolicy = behaviourPolicy;
40            this.simulation = simulationPolicy;
41        }
42
43        public bool StopRequested { get; set; }
44
45        public bool IsPaused
46        {
47            get { return this.isPaused; }
48        }
49
50        public override void Run(int maxIterations)
51        {
52            Reset();
53            for (int i = 0; !StopRequested && i < maxIterations; i++)
54            {
55                lock (pauseLock)
56                {
57                    if (isPaused)
58                    {
59                        Monitor.Wait(pauseLock);
60                    }
61                }
62                TreeNode currentNode = rootNode;
63
64                while (!currentNode.IsLeaf())
65                {
66                    int currentActionIndex = behaviourPolicy.SelectAction(random,
67                        currentNode.GetChildActionInfos());
68                    currentNode = currentNode.children[currentActionIndex];
69                }
70
71                string phrase = currentNode.phrase;
72
73                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
74                {
75                    ExpandTreeNode(currentNode);
76                    currentNode =
77                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())
78                            ];
79                }
80                if (currentNode.phrase.Length <= maxLen)
81                {
82                    string simulatedPhrase;
83                    double quality = simulation.Simulate(currentNode, out simulatedPhrase);
84                    OnSolutionEvaluated(simulatedPhrase, quality);
85
86                    Propagate(currentNode, quality);
87                }
88            }
89        }
90
91        protected void ExpandTreeNode(TreeNode treeNode)
92        {
93            // create children on the first visit
94            if (treeNode.children == null)
95            {
96                treeNode.children = new List<TreeNode>();
97
98                var phrase = new Sequence(treeNode.phrase);
99                // create subnodes for each nt-symbol in phrase
100                for (int i = 0; i < phrase.Length; i++)
101                {
102                    char symbol = phrase[i];
103                    if (grammar.IsNonTerminal(symbol))
104                    {
105                        // create subnode for each alternative of symbol
106                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
107                        {
108                            Sequence newSequence = new Sequence(phrase);
109                            newSequence.ReplaceAt(i, 1, alternative);
110                            if (newSequence.Length <= maxLen)
111                            {
112                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString(),
113                                    behaviourPolicy.CreateActionInfo(), treeNode.level + 1);
114                                treeNode.children.Add(childNode);
115                            }
116                        }
117                    }
118                }
119            }
120        }
121
122        protected void Reset()
123        {
124            StopRequested = false;
125            bestQuality = 0.0;
126            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
127        }
128
129        protected void Propagate(TreeNode node, double quality)
130        {
131            var currentNode = node;
132            do
133            {
134                currentNode.actionInfo.UpdateReward(quality);
135                currentNode = currentNode.parent;
136            } while (currentNode != null);
137        }
138
139        private void GetTreeInfosRek(TreeInfos treeInfos, List<TreeNode> children)
140        {
141            treeInfos.TotalNodes += children.Count;
142            foreach (TreeNode child in children)
143            {
144                if (child.children != null)
145                {
146                    treeInfos.ExpandedNodes++;
147                    if (treeInfos.DeepestLevel <= child.level)
148                    {
149                        treeInfos.DeepestLevel = child.level + 1;
150                    }
151                    GetTreeInfosRek(treeInfos, child.children);
152                }
153                else
154                {
155                    if (grammar.IsTerminal(child.phrase))
156                    {
157                        treeInfos.LeaveNodes++;
158                    }
159                    else
160                    {
161                        treeInfos.UnexpandedNodes++;
162                    }
163                }
164            }
165        }
166
167        public TreeInfos GetTreeInfos()
168        {
169            TreeInfos treeInfos = new TreeInfos();
170
171            if (rootNode != null)
172            {
173                treeInfos.TotalNodes++;
174                if (rootNode.children != null)
175                {
176                    treeInfos.ExpandedNodes++;
177                    treeInfos.DeepestLevel = rootNode.level + 1;
178                    GetTreeInfosRek(treeInfos, rootNode.children);
179                }
180                else
181                {
182                    treeInfos.DeepestLevel = rootNode.level;
183                    if (grammar.IsTerminal(rootNode.phrase))
184                    {
185                        treeInfos.LeaveNodes++;
186                    }
187                    else
188                    {
189                        treeInfos.UnexpandedNodes++;
190                    }
191                }
192            }
193            return treeInfos;
194        }
195
196        public byte[] GenerateSvg()
197        {
198            if (GetTreeInfos().TotalNodes < 1000)
199            {
200                IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
201                IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
202                IRegisterLayoutPluginCommand registerLayoutPluginCommand =
203                    new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
204
205                GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
206                    getProcessStartInfoQuery,
207                    registerLayoutPluginCommand);
208                wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";
209                StringBuilder dotFile = new StringBuilder("digraph {");
210                dotFile.AppendLine();
211                dotFile.AppendLine("splines=ortho;");
212                dotFile.AppendLine("concentrate=true;");
213                dotFile.AppendLine("ranksep=1.2;");
214
215                List<TreeNode> toDoNodes = new List<TreeNode>();
216                if (rootNode != null)
217                {
218                    toDoNodes.Add(rootNode);
219                    // declare node
220                    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, rootNode.actionInfo.Value);
221
222                    dotFile.AppendLine(
223                        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
224                            rootNode.GetHashCode(),
225                            rootNode.phrase, rootNode.actionInfo.Value, rootNode.actionInfo.Tries, hexColor));
226                }
227
228                // to put nodes on the same level in graph
229                Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
230
231                List<TreeNode> sameLevelNodes;
232
233                while (toDoNodes.Any())
234                {
235                    TreeNode currentNode = toDoNodes[0];
236                    toDoNodes.RemoveAt(0);
237                    // put currentNode into levelMap
238                    if (levelMap.TryGetValue(currentNode.level, out sameLevelNodes))
239                    {
240                        sameLevelNodes.Add(currentNode);
241                    }
242                    else
243                    {
244                        sameLevelNodes = new List<TreeNode>();
245                        sameLevelNodes.Add(currentNode);
246                        levelMap.Add(currentNode.level, sameLevelNodes);
247                    }
248
249                    // draw line from current node to all its children
250                    if (currentNode.children != null)
251                    {
252                        foreach (TreeNode childNode in currentNode.children)
253                        {
254                            toDoNodes.Add(childNode);
255                            // declare node
256
257                            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, childNode.actionInfo.Value);
258                            dotFile.AppendLine(
259                                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
260                                    childNode.GetHashCode(),
261                                    childNode.phrase, childNode.actionInfo.Value, childNode.actionInfo.Tries, hexColor));
262                            // add edge
263                            dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(),
264                                childNode.GetHashCode()));
265                        }
266                    }
267                }
268
269                // set same level ranks..
270                foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap)
271                {
272                    dotFile.Append("{rank = same;");
273                    foreach (TreeNode node in entry.Value)
274                    {
275                        dotFile.Append(string.Format(" {0};", node.GetHashCode()));
276                    }
277                    dotFile.AppendLine("}");
278                }
279
280                dotFile.Append(" }");
281                byte[] output = wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
282                return output;
283            }
284            return null;
285        }
286
287        protected String HexConverter(Color c)
288        {
289            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
290        }
291
292        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
293        {
294            // convert quality to value between 0 and 1
295            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
296            double q = quality / bestKnownQuality;
297
298            // calculate difference between colors
299            byte rDiff = (byte)Math.Abs(weakColor.R - strongColor.R);
300            byte bDiff = (byte)Math.Abs(weakColor.B - strongColor.B);
301            byte gDiff = (byte)Math.Abs(weakColor.G - strongColor.G);
302
303            byte newR = weakColor.R > strongColor.R
304                ? Convert.ToByte(weakColor.R - Math.Round(rDiff * q))
305                : Convert.ToByte(weakColor.R + Math.Round(rDiff * q));
306
307            byte newB = weakColor.B > strongColor.B
308                ? Convert.ToByte(weakColor.B - Math.Round(bDiff * q))
309                : Convert.ToByte(weakColor.B + Math.Round(bDiff * q));
310
311            byte newG = weakColor.G > strongColor.G
312                ? Convert.ToByte(weakColor.G - Math.Round(gDiff * q))
313                : Convert.ToByte(weakColor.G + Math.Round(gDiff * q));
314
315            return HexConverter(Color.FromArgb(newR, newG, newB));
316        }
317
318        public void PauseContinue()
319        {
320            lock (pauseLock)
321            {
322                if (isPaused)
323                {
324                    isPaused = false;
325                    Monitor.Pulse(pauseLock);
326                }
327                else
328                {
329                    isPaused = true;
330                }
331            }
332        }
333    }
334}
Note: See TracBrowser for help on using the repository browser.