source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs @ 12832

Last change on this file since 12832 was 12832, checked in by aballeit, 6 years ago

#2283 limit parallelism

File size: 13.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Drawing;
5using System.Linq;
6using System.Net.Mail;
7using System.Text;
8using System.Threading;
9using GraphVizWrapper;
10using GraphVizWrapper.Commands;
11using GraphVizWrapper.Queries;
12using HeuristicLab.Algorithms.Bandits;
13using HeuristicLab.Algorithms.GrammaticalOptimization;
14using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
15using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
16using HeuristicLab.Common;
17using HeuristicLab.Problems.GrammaticalOptimization;
18
19namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
20{
21    public class MonteCarloTreeSearch : SolverBase
22    {
23        protected readonly int maxLen;
24        protected readonly ISymbolicExpressionTreeProblem problem;
25        protected readonly IGrammar grammar;
26        protected readonly Random random;
27        protected readonly IBanditPolicy behaviourPolicy;
28        protected readonly ISimulation simulation;
29        protected TreeNode rootNode;
30        protected int goodSelections;
31
32        public MonteCarloTreeSearch(ISymbolicExpressionTreeProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
33            ISimulation simulationPolicy)
34        {
35            this.problem = problem;
36            this.grammar = problem.Grammar;
37            this.maxLen = maxLen;
38            this.random = random;
39            this.behaviourPolicy = behaviourPolicy;
40            this.simulation = simulationPolicy;
41
42            this.problem.GenerateProblemSolutions(maxLen);
43        }
44
45        public event Action<int, int> NodeSelected;
46
47        protected virtual void OnNodeSelectedChanged(int value, int selections)
48        {
49            RaiseNodeSelectedChanged(value, selections);
50        }
51
52        private void RaiseNodeSelectedChanged(int value, int selections)
53        {
54            var handler = NodeSelected;
55            if (handler != null) handler(value, selections);
56        }
57
58        public bool StopRequested { get; set; }
59       
60        public int GoodSelections { get { return this.goodSelections; } }
61
62        protected void CheckSelection(TreeNode node, int selections)
63        {
64            if (problem.IsParentOfProblemSolution(node.phrase, maxLen))
65            {
66                goodSelections++;
67            }
68
69            OnNodeSelectedChanged(goodSelections, selections);
70        }
71
72        public override void Run(int maxIterations)
73        {
74            Reset();
75            int selections = 0;
76            TreeNode currentNode;
77            string phrase;
78            string simulatedPhrase;
79            double quality;
80
81            for (int i = 0; !StopRequested && i < maxIterations; i++)
82            {
83                currentNode = rootNode;
84
85                while (!currentNode.IsLeaf())
86                {
87                    int currentActionIndex = behaviourPolicy.SelectAction(random,
88                        currentNode.GetChildActionInfos());
89                    currentNode = currentNode.children[currentActionIndex];
90                    //selections++;
91                    //CheckSelection(currentNode, selections);
92                }
93
94                phrase = currentNode.phrase;
95
96                if (!grammar.IsTerminal(phrase) && phrase.Length <= maxLen)
97                {
98                    ExpandTreeNode(currentNode);
99                    currentNode =
100                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())
101                            ];
102                    //selections++;
103                    //CheckSelection(currentNode, selections);
104                }
105                if (currentNode.phrase.Length <= maxLen)
106                {
107                    quality = simulation.Simulate(currentNode, out simulatedPhrase);
108                    OnSolutionEvaluated(simulatedPhrase, quality);
109
110                    Propagate(currentNode, quality);
111                }
112            }
113        }
114
115        protected void ExpandTreeNode(TreeNode treeNode)
116        {
117            // create children on the first visit
118            Sequence newSequence;
119            TreeNode childNode;
120            if (treeNode.children == null)
121            {
122                List<TreeNode> newChildren = new List<TreeNode>();
123
124                var phrase = new Sequence(treeNode.phrase);
125                // create subnodes for each nt-symbol in phrase
126                for (int i = 0; i < phrase.Length; i++)
127                {
128                    char symbol = phrase[i];
129                    if (grammar.IsNonTerminal(symbol))
130                    {
131                        // create subnode for each alternative of symbol
132                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
133                        {
134                            newSequence = new Sequence(phrase);
135                            newSequence.ReplaceAt(i, 1, alternative);
136                            if (newSequence.Length <= maxLen)
137                            {
138                                childNode = new TreeNode(treeNode, newSequence.ToString(),
139                                    behaviourPolicy.CreateActionInfo(), (ushort) (treeNode.level + 1));
140                                newChildren.Add(childNode);
141                            }
142                        }
143                    }
144                }
145                treeNode.children = newChildren.ToArray();
146            }
147        }
148
149        protected void Reset()
150        {
151            goodSelections = 0;
152            StopRequested = false;
153            bestQuality = 0.0;
154            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString(), behaviourPolicy.CreateActionInfo(), 0);
155        }
156
157        protected void Propagate(TreeNode node, double quality)
158        {
159            var currentNode = node;
160            do
161            {
162                currentNode.actionInfo.UpdateReward(quality);
163                currentNode = currentNode.parent;
164            } while (currentNode != null);
165        }
166
167        private void GetTreeInfosRek(TreeInfos treeInfos, TreeNode[] children)
168        {
169            treeInfos.TotalNodes += children.Length;
170            foreach (TreeNode child in children)
171            {
172                if (child.children != null)
173                {
174                    treeInfos.ExpandedNodes++;
175                    if (treeInfos.DeepestLevel <= child.level)
176                    {
177                        treeInfos.DeepestLevel = child.level + 1;
178                    }
179                    GetTreeInfosRek(treeInfos, child.children);
180                }
181                else
182                {
183                    if (grammar.IsTerminal(child.phrase))
184                    {
185                        treeInfos.LeaveNodes++;
186                    }
187                    else
188                    {
189                        treeInfos.UnexpandedNodes++;
190                    }
191                }
192            }
193        }
194
195        public TreeInfos GetTreeInfos()
196        {
197            TreeInfos treeInfos = new TreeInfos();
198
199            if (rootNode != null)
200            {
201                treeInfos.TotalNodes++;
202                if (rootNode.children != null)
203                {
204                    treeInfos.ExpandedNodes++;
205                    treeInfos.DeepestLevel = rootNode.level + 1;
206                    GetTreeInfosRek(treeInfos, rootNode.children);
207                }
208                else
209                {
210                    treeInfos.DeepestLevel = rootNode.level;
211                    if (grammar.IsTerminal(rootNode.phrase))
212                    {
213                        treeInfos.LeaveNodes++;
214                    }
215                    else
216                    {
217                        treeInfos.UnexpandedNodes++;
218                    }
219                }
220            }
221            return treeInfos;
222        }
223
224        public byte[] GenerateSvg()
225        {
226            if (GetTreeInfos().TotalNodes < 1000)
227            {
228                IGetStartProcessQuery getStartProcessQuery = new GetStartProcessQuery();
229                IGetProcessStartInfoQuery getProcessStartInfoQuery = new GetProcessStartInfoQuery();
230                IRegisterLayoutPluginCommand registerLayoutPluginCommand =
231                    new RegisterLayoutPluginCommand(getProcessStartInfoQuery, getStartProcessQuery);
232
233                GraphGeneration wrapper = new GraphGeneration(getStartProcessQuery,
234                    getProcessStartInfoQuery,
235                    registerLayoutPluginCommand);
236                wrapper.GraphvizPath = @"../../../Graphviz2.38/bin";
237                StringBuilder dotFile = new StringBuilder("digraph {");
238                dotFile.AppendLine();
239                dotFile.AppendLine("splines=ortho;");
240                dotFile.AppendLine("concentrate=true;");
241                dotFile.AppendLine("ranksep=1.2;");
242
243                List<TreeNode> toDoNodes = new List<TreeNode>();
244                if (rootNode != null)
245                {
246                    toDoNodes.Add(rootNode);
247                    // declare node
248                    string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, rootNode.actionInfo.Value);
249
250                    dotFile.AppendLine(
251                        string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
252                            rootNode.GetHashCode(),
253                            rootNode.phrase, rootNode.actionInfo.Value, rootNode.actionInfo.Tries, hexColor));
254                }
255
256                // to put nodes on the same level in graph
257                Dictionary<int, List<TreeNode>> levelMap = new Dictionary<int, List<TreeNode>>();
258
259                List<TreeNode> sameLevelNodes;
260
261                while (toDoNodes.Any())
262                {
263                    TreeNode currentNode = toDoNodes[0];
264                    toDoNodes.RemoveAt(0);
265                    // put currentNode into levelMap
266                    if (levelMap.TryGetValue(currentNode.level, out sameLevelNodes))
267                    {
268                        sameLevelNodes.Add(currentNode);
269                    }
270                    else
271                    {
272                        sameLevelNodes = new List<TreeNode>();
273                        sameLevelNodes.Add(currentNode);
274                        levelMap.Add(currentNode.level, sameLevelNodes);
275                    }
276
277                    // draw line from current node to all its children
278                    if (currentNode.children != null)
279                    {
280                        foreach (TreeNode childNode in currentNode.children)
281                        {
282                            toDoNodes.Add(childNode);
283                            // declare node
284
285                            string hexColor = GetHexNodeColor(Color.White, Color.OrangeRed, childNode.actionInfo.Value);
286                            dotFile.AppendLine(
287                                string.Format("{0} [label=\"{1}\\n{2:0.00}/{3}\", style=filled, fillcolor=\"{4}\"]",
288                                    childNode.GetHashCode(),
289                                    childNode.phrase, childNode.actionInfo.Value, childNode.actionInfo.Tries, hexColor));
290                            // add edge
291                            dotFile.AppendLine(string.Format("{0} -> {1}", currentNode.GetHashCode(),
292                                childNode.GetHashCode()));
293                        }
294                    }
295                }
296
297                // set same level ranks..
298                foreach (KeyValuePair<int, List<TreeNode>> entry in levelMap)
299                {
300                    dotFile.Append("{rank = same;");
301                    foreach (TreeNode node in entry.Value)
302                    {
303                        dotFile.Append(string.Format(" {0};", node.GetHashCode()));
304                    }
305                    dotFile.AppendLine("}");
306                }
307
308                dotFile.Append(" }");
309                byte[] output = wrapper.GenerateGraph(dotFile.ToString(), Enums.GraphReturnType.Svg);
310                return output;
311            }
312            return null;
313        }
314
315        protected String HexConverter(Color c)
316        {
317            return "#" + c.R.ToString("X2") + c.G.ToString("X2") + c.B.ToString("X2");
318        }
319
320        protected String GetHexNodeColor(Color weakColor, Color strongColor, double quality)
321        {
322            // convert quality to value between 0 and 1
323            double bestKnownQuality = problem.BestKnownQuality(this.maxLen);
324            double q = quality / bestKnownQuality;
325
326            // calculate difference between colors
327            byte rDiff = (byte)Math.Abs(weakColor.R - strongColor.R);
328            byte bDiff = (byte)Math.Abs(weakColor.B - strongColor.B);
329            byte gDiff = (byte)Math.Abs(weakColor.G - strongColor.G);
330
331            byte newR = weakColor.R > strongColor.R
332                ? Convert.ToByte(weakColor.R - Math.Round(rDiff * q))
333                : Convert.ToByte(weakColor.R + Math.Round(rDiff * q));
334
335            byte newB = weakColor.B > strongColor.B
336                ? Convert.ToByte(weakColor.B - Math.Round(bDiff * q))
337                : Convert.ToByte(weakColor.B + Math.Round(bDiff * q));
338
339            byte newG = weakColor.G > strongColor.G
340                ? Convert.ToByte(weakColor.G - Math.Round(gDiff * q))
341                : Convert.ToByte(weakColor.G + Math.Round(gDiff * q));
342
343            return HexConverter(Color.FromArgb(newR, newG, newB));
344        }
345
346        public void FreeAll()
347        {
348            rootNode = null;
349            GC.Collect();
350        }
351    }
352}
Note: See TracBrowser for help on using the repository browser.