Changeset 12098


Ignore:
Timestamp:
02/27/15 21:52:10 (3 years ago)
Author:
aballeit
Message:

#2283: implemented MCTS

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization
Files:
2 added
3 deleted
5 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/Base/TreeNode.cs

    r12050 r12098  
    55using System.Threading.Tasks;
    66using HeuristicLab.Algorithms.Bandits;
     7using HeuristicLab.Algorithms.Bandits.BanditPolicies;
    78using HeuristicLab.Problems.GrammaticalOptimization;
    89
    9 namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
     10namespace HeuristicLab.Algorithms.MonteCarloTreeSearch.Base
    1011{
    1112    public class TreeNode
     
    1516        public List<TreeNode> children;
    1617        public IBanditPolicyActionInfo actionInfo;
    17         public bool expandable;
    18         public List<int> unvisitedNonTerminals;
    1918
    20         public TreeNode(TreeNode parent, string phrase, bool expandable, List<int> unvisitedNonTerminals)
     19        public TreeNode(TreeNode parent, string phrase)
    2120        {
     21            this.parent = parent;
    2222            this.phrase = phrase;
    23             this.expandable = expandable;
    24             this.unvisitedNonTerminals = unvisitedNonTerminals;
     23            actionInfo = new DefaultPolicyActionInfo();
     24        }
     25        public bool IsLeaf()
     26        {
     27            return children == null || !children.Any();
     28        }
     29
     30        internal IEnumerable<IBanditPolicyActionInfo> GetChildActionInfos()
     31        {
     32            return children.Select(n => n.actionInfo);
    2533        }
    2634    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/HeuristicLab.Algorithms.MonteCarloTreeSearch.csproj

    r12050 r12098  
    4040  </ItemGroup>
    4141  <ItemGroup>
     42    <Compile Include="Simulation\ISimulation.cs" />
    4243    <Compile Include="MonteCarloTreeSearch.cs" />
    4344    <Compile Include="Properties\AssemblyInfo.cs" />
    44     <Compile Include="TreeNode.cs" />
     45    <Compile Include="Base\TreeNode.cs" />
     46    <Compile Include="Simulation\RandomSimulation.cs" />
    4547  </ItemGroup>
    4648  <ItemGroup>
     
    5254      <Project>{eea07488-1a51-412a-a52c-53b754a628b3}</Project>
    5355      <Name>HeuristicLab.Algorithms.GrammaticalOptimization</Name>
     56    </ProjectReference>
     57    <ProjectReference Include="..\HeuristicLab.Common\HeuristicLab.Common.csproj">
     58      <Project>{3a2fbbcb-f9df-4970-87f3-f13337d941ad}</Project>
     59      <Name>HeuristicLab.Common</Name>
    5460    </ProjectReference>
    5561    <ProjectReference Include="..\HeuristicLab.Problems.GrammaticalOptimization\HeuristicLab.Problems.GrammaticalOptimization.csproj">
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.MonteCarloTreeSearch/MonteCarloTreeSearch.cs

    r12050 r12098  
    22using System.Collections.Generic;
    33using System.Linq;
    4 using System.Resources;
    5 using System.Text;
    6 using System.Threading.Tasks;
    74using HeuristicLab.Algorithms.Bandits;
    8 using HeuristicLab.Algorithms.Bandits.GrammarPolicies;
    9 using HeuristicLab.Algorithms.MonteCarloTreeSearch;
    10 using HeuristicLab.Algorithms.MonteCarloTreeSearch.Expansion;
     5using HeuristicLab.Algorithms.GrammaticalOptimization;
     6using HeuristicLab.Algorithms.MonteCarloTreeSearch.Base;
     7using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
     8using HeuristicLab.Common;
    119using HeuristicLab.Problems.GrammaticalOptimization;
    1210
    13 namespace HeuristicLab.Algorithms.GrammaticalOptimization.Solvers
     11namespace HeuristicLab.Algorithms.MonteCarloTreeSearch
    1412{
    1513    public class MonteCarloTreeSearch : SolverBase
     
    1715        private readonly int maxLen;
    1816        private readonly IProblem problem;
     17        private readonly IGrammar grammar;
    1918        private readonly Random random;
    2019        private readonly IBanditPolicy behaviourPolicy;
    21         private readonly IExpansionPolicy expansionPolicy;
    22         private readonly ISimulationPolicy simulationPolicy;
     20        private readonly ISimulation simulation;
    2321        private TreeNode rootNode;
    24         private List<IBanditPolicyActionInfo> actions;
    25         private List<TreeNode> nodes;
    2622
    27         public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy,
    28             IExpansionPolicy expansionPolicy, ISimulationPolicy simulationPolicy)
     23        public MonteCarloTreeSearch(IProblem problem, int maxLen, Random random, IBanditPolicy behaviourPolicy, ISimulation simulationPolicy)
    2924        {
    3025            this.problem = problem;
     26            this.grammar = problem.Grammar;
    3127            this.maxLen = maxLen;
    3228            this.random = random;
    3329            this.behaviourPolicy = behaviourPolicy;
    34             this.expansionPolicy = expansionPolicy;
    35             this.simulationPolicy = simulationPolicy;
     30            this.simulation = simulationPolicy;
    3631        }
    3732
     
    4540        {
    4641            Reset();
    47             for (int i = 0; !StopRequested && !Done() && i < maxIterations; i++)
     42            for (int i = 0; !StopRequested && i < maxIterations; i++)
    4843            {
    49                 // select by behaviour policy
    50                 TreeNode currentNode;
    51                 do
     44                TreeNode currentNode = rootNode;
     45
     46                while (!currentNode.IsLeaf())
    5247                {
    53                     int currentActionIndex = behaviourPolicy.SelectAction(random, actions);
    54                     currentNode = nodes[currentActionIndex];
    55                 } while (!Expandable(currentNode));
     48                    int currentActionIndex = behaviourPolicy.SelectAction(random,
     49                        currentNode.GetChildActionInfos());
     50                    currentNode = currentNode.children[currentActionIndex];
     51                }
    5652
    57                 // expand tree
    58                 currentNode = expansionPolicy.ExpandTreeNode(currentNode);
    59                 // simulate
    60                 double reward = simulationPolicy.Simulate(currentNode);
    61                 // propagate/reward
    62                 Propagate(currentNode, reward);
     53                string phrase = currentNode.phrase;
     54
     55                if (!grammar.IsTerminal(phrase))
     56                {
     57                    ExpandTreeNode(currentNode);
     58
     59                    currentNode =
     60                        currentNode.children[behaviourPolicy.SelectAction(random, currentNode.GetChildActionInfos())];
     61                }
     62                double quality = simulation.Simulate(currentNode);
     63                OnSolutionEvaluated(phrase, quality);
     64
     65                Propagate(currentNode, quality);
     66            }
     67        }
     68
     69        private void ExpandTreeNode(TreeNode treeNode)
     70        {
     71            // create children on the first visit
     72            if (treeNode.children == null)
     73            {
     74                treeNode.children = new List<TreeNode>();
     75
     76                var phrase = new Sequence(treeNode.phrase);
     77                // create subnodes for each nt-symbol in phrase
     78                for (int i = 0; i < phrase.Length; i++)
     79                {
     80                    char symbol = phrase[i];
     81                    if (grammar.IsNonTerminal(symbol))
     82                    {
     83                        // create subnode for each alternative of symbol
     84                        foreach (Sequence alternative in grammar.GetAlternatives(symbol))
     85                        {
     86                            Sequence newSequence = new Sequence(phrase);
     87                            newSequence.ReplaceAt(i, 1, alternative);
     88                            if (newSequence.Length <= maxLen)
     89                            {
     90                                TreeNode childNode = new TreeNode(treeNode, newSequence.ToString());
     91                                treeNode.children.Add(childNode);
     92                            }
     93                        }
     94                    }
     95                }
    6396            }
    6497        }
     
    68101            StopRequested = false;
    69102            bestQuality = 0.0;
    70             rootNode = new TreeNode(null, problem.Grammar.SentenceSymbol.ToString(), true, new List<int>() { 0 });
     103            rootNode = new TreeNode(null, grammar.SentenceSymbol.ToString());
    71104        }
    72105
    73         private bool Done()
    74         {
    75             return !rootNode.expandable;
    76         }
    77 
    78         private bool Expandable(TreeNode node)
    79         {
    80             return !problem.Grammar.IsTerminal(node.phrase);
    81         }
    82 
    83         private void Propagate(TreeNode node, double reward)
     106        private void Propagate(TreeNode node, double quality)
    84107        {
    85108            var currentNode = node;
    86109            do
    87110            {
    88                 currentNode.actionInfo.UpdateReward(reward);
    89                 currentNode = node.parent;
     111                currentNode.actionInfo.UpdateReward(quality);
     112                currentNode = currentNode.parent;
    90113            } while (currentNode != null);
     114        }
     115
     116        public void PrintStats()
     117        {
     118            //Console.WriteLine("depth: {0,5} tries: {1,5} best phrase {2,50} bestQ {3:F3}", maxSearchDepth, tries, bestPhrase, bestQuality);
     119
     120            //// use behaviour strategy to generate the currently prefered sentence
     121            //var policy = behaviourPolicy;
     122
     123            //var n = rootNode;
     124
     125            //while (n != null)
     126            //{
     127            //    var phrase = n.phrase;
     128            //    Console.ForegroundColor = ConsoleColor.White;
     129            //    Console.WriteLine("{0,-30}", phrase);
     130            //    var children = n.children;
     131            //    if (children == null || !children.Any()) break;
     132            //    var values = children.Select(ch => policy.GetValue(ch.phrase));
     133            //    var maxValue = values.Max();
     134            //    if (maxValue == 0) maxValue = 1.0;
     135
     136            //    // write phrases
     137            //    foreach (var ch in children)
     138            //    {
     139            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
     140            //        Console.Write(" {0,-4}", ch.phrase.Substring(Math.Max(0, ch.phrase.Length - 3), Math.Min(3, ch.phrase.Length)));
     141            //    }
     142            //    Console.WriteLine();
     143
     144            //    // write values
     145            //    foreach (var ch in children)
     146            //    {
     147            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
     148            //        Console.Write(" {0:F2}", policy.GetValue(ch.phrase) * 10.0);
     149            //    }
     150            //    Console.WriteLine();
     151
     152            //    // write tries
     153            //    foreach (var ch in children)
     154            //    {
     155            //        SetColorForValue(policy.GetValue(ch.phrase) / maxValue);
     156            //        Console.Write(" {0,4}", policy.GetTries(ch.phrase));
     157            //    }
     158            //    Console.WriteLine();
     159            //    int selectedChildIdx;
     160            //    if (!policy.TrySelect(random, phrase, children.Select(ch => ch.phrase), out selectedChildIdx))
     161            //    {
     162            //        break;
     163            //    }
     164            //    n = n.children[selectedChildIdx];
     165            //}
     166
     167            //Console.ForegroundColor = ConsoleColor.White;
     168            //Console.WriteLine("-------------------");
     169        }
     170
     171        private void SetColorForValue(double v)
     172        {
     173            Console.ForegroundColor = ConsoleEx.ColorForValue(v);
    91174        }
    92175    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Main.csproj

    r11981 r12098  
    5252      <Name>HeuristicLab.Algorithms.GrammaticalOptimization</Name>
    5353    </ProjectReference>
     54    <ProjectReference Include="..\HeuristicLab.Algorithms.MonteCarloTreeSearch\HeuristicLab.Algorithms.MonteCarloTreeSearch.csproj">
     55      <Project>{2c115235-8fa9-4f7f-b3a0-a0144f8a35ca}</Project>
     56      <Name>HeuristicLab.Algorithms.MonteCarloTreeSearch</Name>
     57    </ProjectReference>
    5458    <ProjectReference Include="..\HeuristicLab.Problems.GrammaticalOptimization\HeuristicLab.Problems.GrammaticalOptimization.csproj">
    5559      <Project>{cb9dccf6-667e-4a13-b82d-dbd6b45a045e}</Project>
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs

    r12050 r12098  
    44using HeuristicLab.Algorithms.Bandits.BanditPolicies;
    55using HeuristicLab.Algorithms.GrammaticalOptimization;
     6using HeuristicLab.Algorithms.MonteCarloTreeSearch;
     7using HeuristicLab.Algorithms.MonteCarloTreeSearch.Simulation;
    68using HeuristicLab.Problems.GrammaticalOptimization;
    79
     
    1820
    1921
    20 namespace Main {
    21   class Program {
    22     static void Main(string[] args) {
    23       CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
     22namespace Main
     23{
     24    class Program
     25    {
     26        static void Main(string[] args)
     27        {
     28            CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
    2429
    25       RunDemo();
    26     }
     30            RunDemo();
     31        }
    2732
    2833
    29     private static void RunDemo() {
     34        private static void RunDemo()
     35        {
    3036
    3137
    32       int maxIterations = 100000;
    33       int iterations = 0;
     38            int maxIterations = 100000;
     39            int iterations = 0;
    3440
    35       var globalStatistics = new SentenceSetStatistics();
    36       var random = new Random();
     41            var globalStatistics = new SentenceSetStatistics();
     42            var random = new Random();
    3743
    38       //var problem = new SymbolicRegressionPoly10Problem();
    39       //var problem = new SantaFeAntProblem();             
    40       var problem = new RoyalPairProblem();
    41       //var problem = new EvenParityProblem();
    42       var alg = new SequentialSearch(problem, 23, random, 0,
    43        new HeuristicLab.Algorithms.Bandits.GrammarPolicies.GenericGrammarPolicy(problem, new UCB1TunedPolicy()));
     44            //var problem = new SymbolicRegressionPoly10Problem();
     45            //var problem = new SantaFeAntProblem();             
     46            var problem = new RoyalPairProblem();
     47            //var problem = new EvenParityProblem();
     48            //var alg = new SequentialSearch(problem, 23, random, 0,
     49            // new HeuristicLab.Algorithms.Bandits.GrammarPolicies.GenericGrammarPolicy(problem, new UCB1TunedPolicy()));
     50            var alg = new MonteCarloTreeSearch(problem, 23, random, new UCB1Policy(), new RandomSimulation(problem, random, 23));
    4451
    4552
    46       alg.FoundNewBestSolution += (sentence, quality) => {
    47         //Console.WriteLine("{0}", globalStatistics);
    48       };
     53            alg.FoundNewBestSolution += (sentence, quality) =>
     54            {
     55                //Console.WriteLine("{0}", globalStatistics);
     56            };
    4957
    50       alg.SolutionEvaluated += (sentence, quality) => {
    51         iterations++;
    52         globalStatistics.AddSentence(sentence, quality);
     58            alg.SolutionEvaluated += (sentence, quality) =>
     59            {
     60                iterations++;
     61                globalStatistics.AddSentence(sentence, quality);
    5362
    54         // comment this if you don't want to see solver statistics
    55         if (iterations % 100 == 0) {
    56           if (iterations % 10000 == 0) Console.Clear();
    57           Console.SetCursorPosition(0, 0);
    58           alg.PrintStats();
     63                // comment this if you don't want to see solver statistics
     64                if (iterations % 100 == 0)
     65                {
     66                    if (iterations % 10000 == 0) Console.Clear();
     67                    Console.SetCursorPosition(0, 0);
     68                    alg.PrintStats();
     69                }
     70
     71                // uncomment this if you want to collect statistics of the generated sentences
     72                // if (iterations % 1000 == 0) {
     73                //   Console.WriteLine("{0}", globalStatistics);
     74                // }
     75            };
     76
     77            var sw = new Stopwatch();
     78            sw.Start();
     79            alg.Run(maxIterations);
     80            sw.Stop();
     81
     82            Console.Clear();
     83            alg.PrintStats();
     84            Console.WriteLine(globalStatistics);
     85            Console.WriteLine("{0:F2} sec {1,10:F1} sols/sec {2,10:F1} ns/sol",
     86              sw.Elapsed.TotalSeconds,
     87              maxIterations / (double)sw.Elapsed.TotalSeconds,
     88              (double)sw.ElapsedMilliseconds * 1000 / maxIterations);
    5989        }
    60 
    61         // uncomment this if you want to collect statistics of the generated sentences
    62         // if (iterations % 1000 == 0) {
    63         //   Console.WriteLine("{0}", globalStatistics);
    64         // }
    65       };
    66 
    67       var sw = new Stopwatch();
    68       sw.Start();
    69       alg.Run(maxIterations);
    70       sw.Stop();
    71 
    72       Console.Clear();
    73       alg.PrintStats();
    74       Console.WriteLine(globalStatistics);
    75       Console.WriteLine("{0:F2} sec {1,10:F1} sols/sec {2,10:F1} ns/sol",
    76         sw.Elapsed.TotalSeconds,
    77         maxIterations / (double)sw.Elapsed.TotalSeconds,
    78         (double)sw.ElapsedMilliseconds * 1000 / maxIterations);
    7990    }
    80   }
    8191}
Note: See TracChangeset for help on using the changeset viewer.