Changeset 11730 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs
- Timestamp:
- 01/02/15 16:08:21 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs
r11727 r11730 10 10 public class MctsSampler { 11 11 private class TreeNode { 12 public string ident; 12 13 public int randomTries; 14 public int policyTries; 13 15 public IPolicy policy; 14 16 public TreeNode[] children; 15 17 public bool done = false; 16 18 19 public TreeNode(string id) { 20 this.ident = id; 21 } 22 17 23 public override string ToString() { 18 return string.Format("Node( random-tries: {0}, done: {1}, policy: {2})", randomTries, done, policy);24 return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, randomTries + policyTries, done, policy); 19 25 } 20 26 } 27 21 28 22 29 public event Action<string, double> FoundNewBestSolution; … … 27 34 private readonly Random random; 28 35 private readonly int randomTries; 29 private readonly Func< int, IPolicy> policyFactory;36 private readonly Func<Random, int, IPolicy> policyFactory; 30 37 31 38 private List<Tuple<TreeNode, int>> updateChain; 32 39 private TreeNode rootNode; 33 40 41 public int treeDepth; 42 public int treeSize; 43 34 44 public MctsSampler(IProblem problem, int maxLen, Random random) : 35 this(problem, maxLen, random, 10, ( numActions) => new EpsGreedyPolicy(random, numActions, 0.1)) {45 this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) { 36 46 37 47 } 38 48 39 public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func< int, IPolicy> policyFactory) {49 public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<Random, int, IPolicy> policyFactory) { 40 50 this.maxLen = maxLen; 41 51 this.problem = problem; … … 47 57 public void Run(int maxIterations) { 48 58 double bestQuality = double.MinValue; 49 InitPolicies( );59 InitPolicies(problem.Grammar); 50 60 for (int i = 0; !rootNode.done && i < maxIterations; i++) { 51 var sentence = SampleSentence(problem.Grammar) ;61 var sentence = SampleSentence(problem.Grammar).ToString(); 52 62 var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen); 53 63 Debug.Assert(quality >= 0 && quality <= 1.0); … … 61 71 } 62 72 } 73 74 // clean up 75 InitPolicies(problem.Grammar); GC.Collect(); 63 76 } 64 77 65 private void InitPolicies() { 66 this.updateChain = new List<Tuple<TreeNode, int>>(); 67 rootNode = new TreeNode(); 78 public void PrintStats() { 79 var n = rootNode; 80 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries); 81 while (n.policy != null) { 82 Console.WriteLine(); 83 Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident)))); 84 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries)))); 85 //n.policy.PrintStats(); 86 n = n.children.OrderByDescending(c => c.policyTries).First(); 87 } 88 Console.ReadLine(); 68 89 } 69 90 70 private string SampleSentence(IGrammar grammar) { 71 updateChain.Clear(); 72 return CompleteSentence(grammar, grammar.SentenceSymbol.ToString()); 91 private void InitPolicies(IGrammar grammar) { 92 this.updateChain = new List<Tuple<TreeNode, int>>(); 93 94 rootNode = new TreeNode(grammar.SentenceSymbol.ToString()); 95 treeDepth = 0; 96 treeSize = 0; 73 97 } 74 98 75 public string CompleteSentence(IGrammar g, string phrase) { 99 private Sequence SampleSentence(IGrammar grammar) { 100 updateChain.Clear(); 101 var startPhrase = new Sequence(grammar.SentenceSymbol); 102 return CompleteSentence(grammar, startPhrase); 103 } 104 105 private Sequence CompleteSentence(IGrammar g, Sequence phrase) { 76 106 if (phrase.Length > maxLen) throw new ArgumentException(); 77 107 if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException(); 78 108 TreeNode n = rootNode; 79 bool done = phrase. All(g.IsTerminal); // terminal phrase means we are done109 bool done = phrase.IsTerminal; 80 110 int selectedAltIdx = -1; 111 var curDepth = 0; 81 112 while (!done) { 82 int ntIdx; char nt; 83 Grammar.FindFirstNonTerminal(g, phrase, out nt, out ntIdx); 113 char nt = phrase.FirstNonTerminal; 84 114 85 115 int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2 … … 90 120 if (n.randomTries < randomTries) { 91 121 n.randomTries++; 122 123 treeDepth = Math.Max(treeDepth, curDepth); 124 92 125 return g.CompleteSentenceRandomly(random, phrase, maxLen); 93 126 } else if (n.randomTries == randomTries && n.policy == null) { 94 n.policy = policyFactory(alts.Count()); 95 n.children = alts.Select(_ => new TreeNode()).ToArray(); // create a new node for each alternative 127 n.policy = policyFactory(random, alts.Count()); 128 //n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative 129 n.children = alts.Select(alt => new TreeNode(string.Empty)).ToArray(); // create a new node for each alternative 130 131 treeSize += n.children.Length; 96 132 } 97 133 n.policyTries++; 98 134 // => select using bandit policy 99 135 selectedAltIdx = n.policy.SelectAction(); 100 string selectedAlt = alts.ElementAt(selectedAltIdx); 136 Sequence selectedAlt = alts.ElementAt(selectedAltIdx); 137 101 138 // replace nt with alt 102 phrase = phrase.Remove(ntIdx, 1); 103 phrase = phrase.Insert(ntIdx, selectedAlt); 139 phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt); 104 140 105 141 updateChain.Add(Tuple.Create(n, selectedAltIdx)); 106 142 107 done = phrase.All(g.IsTerminal); // terminal phrase means we are done 143 curDepth++; 144 145 done = phrase.IsTerminal; 108 146 if (!done) { 109 147 // prepare for next iteration … … 116 154 n.children[selectedAltIdx].done = true; 117 155 156 treeDepth = Math.Max(treeDepth, curDepth); 118 157 return phrase; 119 158 } … … 127 166 var policy = node.policy; 128 167 var action = e.Item2; 168 //policy.UpdateReward(action, reward / updateChain.Count); 129 169 policy.UpdateReward(action, reward); 130 170
Note: See TracChangeset
for help on using the changeset viewer.