Changeset 10427
- Timestamp:
- 01/29/14 19:30:53 (11 years ago)
- Location:
- branches/HeuristicLab.Problems.GPDL/CodeGenerator
- Files:
-
- 3 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs
r10426 r10427 34 34 private readonly ?IDENT?Problem problem; 35 35 private readonly Random random; 36 private readonly ?IDENT?RandomSearchSolver randomSearch; 37 36 38 private SearchTreeNode searchTree = new SearchTreeNode(); 37 39 … … 44 46 } 45 47 48 private const int RANDOM_TRIES = 1; 49 46 50 private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) { 47 const int RANDOM_TRIES = 1000;48 if(extensionPoints.Count == 0) {49 searchTree.done = true;50 return; // nothing to do51 }52 51 var extensionPoint = extensionPoints.Pop(); 53 52 Tree parent = extensionPoint.Item1; … … 56 55 int maxDepth = extensionPoint.Item4; 57 56 Debug.Assert(maxDepth >= 1); 57 Debug.Assert(Grammar.minDepth[state] <= maxDepth); 58 58 Tree t = null; 59 59 if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) { 60 t = SampleTreeRandom(state); 60 int steps = 0; int curDepth = this.maxDepth - maxDepth; int depth = this.maxDepth - maxDepth; 61 t = randomSearch.SampleTree(state, maxDepth, ref steps, ref curDepth, ref depth); 61 62 if(Grammar.subtreeCount[state] == 0) { 62 63 // when we produced a terminal continue filling up all other empty points 63 64 Debug.Assert(searchTree.children == null || searchTree.children.Length == 1); 64 if(searchTree.children == null) 65 searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ; 66 SampleTree(searchTree.children[0], extensionPoints); 67 if(searchTree.children[0].done) searchTree.done = true; 65 if(extensionPoints.Count == 0) { 66 searchTree.done = true; 67 } else { 68 if(searchTree.children == null) 69 searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ; 70 SampleTree(searchTree.children[0], extensionPoints); 71 if(searchTree.children[0].done) searchTree.done = true; 72 } 68 73 } else { 69 74 // fill up all remaining slots randomly … … 72 77 var pState = p.Item2; 73 78 var pIdx = p.Item3; 74 pParent.subtrees[pIdx] = SampleTreeRandom(pState); 75 } 76 } 77 } else { 78 if(Grammar.subtreeCount[state] == 1) { 79 if(searchTree.children == null) { 80 int nChildren = Grammar.transition[state].Length; 81 searchTree.children = new SearchTreeNode[nChildren]; 82 } 83 Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length); 84 Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries)); 85 var altIdx = SelectAlternative(searchTree); 86 t = new Tree(altIdx, new Tree[1]); 87 extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1)); 88 SampleTree(searchTree.children[altIdx], extensionPoints); 89 } else { 90 // multiple subtrees 91 var subtrees = new Tree[Grammar.subtreeCount[state]]; 92 t = new Tree(-1, subtrees); 93 for(int i = subtrees.Length - 1; i >= 0; i--) { 94 extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1)); 95 } 96 SampleTree(searchTree, extensionPoints); 97 } 79 var pMaxDepth = p.Item4; 80 curDepth = this.maxDepth - pMaxDepth; 81 depth = curDepth; 82 pParent.subtrees[pIdx] = randomSearch.SampleTree(pState, pMaxDepth, ref steps, ref curDepth, ref depth); 83 } 84 } 85 } else if(Grammar.subtreeCount[state] == 1) { 86 if(searchTree.children == null) { 87 int nChildren = Grammar.transition[state].Length; 88 searchTree.children = new SearchTreeNode[nChildren]; 89 } 90 Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length); 91 Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries)); 92 var altIdx = SelectAlternative(searchTree, state, maxDepth); 93 t = new Tree(altIdx, new Tree[1]); 94 extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1)); 95 SampleTree(searchTree.children[altIdx], extensionPoints); 96 searchTree.done = (from idx in Enumerable.Range(0, searchTree.children.Length) 97 where Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth - 1 98 select searchTree.children[idx]).All(c=>c != null && c.done); 99 } else { 100 // multiple subtrees 101 var subtrees = new Tree[Grammar.subtreeCount[state]]; 102 t = new Tree(-1, subtrees); 103 for(int i = subtrees.Length - 1; i >= 0; i--) { 104 extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1)); 105 } 106 SampleTree(searchTree, extensionPoints); 98 107 } 99 108 Debug.Assert(parent.subtrees[subtreeIdx] == null); … … 101 110 } 102 111 103 private int SelectAlternative(SearchTreeNode searchTree ) {112 private int SelectAlternative(SearchTreeNode searchTree, int state, int maxDepth) { 104 113 // any alternative not yet explored? 105 var altIdx = Array.FindIndex(searchTree.children, (e) => e == null); 114 var altIndexes = searchTree.children 115 .Select((e,i) => new {Elem = e, Idx = i}) 116 .Where(p => p.Elem == null && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth) 117 .Select(p => p.Idx); 118 int altIdx = altIndexes.Any()?altIndexes.First() : -1; 106 119 if(altIdx >= 0) { 107 120 searchTree.children[altIdx] = new SearchTreeNode(); 108 121 return altIdx; 109 122 } else { 110 altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000); 123 altIndexes = searchTree.children 124 .Select((e,i) => new {Elem = e, Idx = i}) 125 .Where(p => p.Elem != null && !p.Elem.done && p.Elem.tries < RANDOM_TRIES && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth) 126 .Select(p => p.Idx); 127 altIdx = altIndexes.Any()?altIndexes.First() : -1; 111 128 if(altIdx >= 0) return altIdx; 112 129 // select the least sampled alternative 113 //altIdx = 0; 114 //int minSamples = searchTree.children[altIdx].tries; 115 //for(int idx = 1; idx < searchTree.children.Length; idx++) { 116 // if(!searchTree.children[idx].done && searchTree.children[idx].tries < minSamples) { 130 //altIdx = -1; 131 //int minSamples = int.MaxValue; 132 //for(int idx = 0; idx < searchTree.children.Length; idx++) { 133 // if(searchTree.children[idx] == null) continue; 134 // if(!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && searchTree.children[idx].tries < minSamples) { 117 135 // minSamples = searchTree.children[idx].tries; 118 136 // altIdx = idx; … … 120 138 //} 121 139 // select the alternative with the largest average 122 altIdx = 0; 123 double bestAverage = UCB(searchTree, searchTree.children[altIdx]); 124 for(int idx = 1; idx < searchTree.children.Length; idx++) { 125 if (!searchTree.children[idx].done && UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) { 140 altIdx = -1; 141 double best = double.NegativeInfinity; 142 for(int idx = 0; idx < searchTree.children.Length; idx++) { 143 if(searchTree.children[idx] == null) continue; 144 if (!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && UCB(searchTree, searchTree.children[idx]) > best) { 126 145 altIdx = idx; 127 }128 }129 130 searchTree.done = searchTree.children.All(c=>c.done);146 best = UCB(searchTree, searchTree.children[idx]); 147 } 148 } 149 Debug.Assert(altIdx > -1); 131 150 return altIdx; 132 151 } … … 136 155 Debug.Assert(parent.tries >= n.tries); 137 156 Debug.Assert(n.tries > 0); 138 return n.sumQuality / n.tries + Math.Sqrt(( 10 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values157 return n.sumQuality / n.tries + Math.Sqrt((40 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values 139 158 } 140 159 … … 174 193 } 175 194 176 // same as in random search solver (could reuse random search)177 private Tree SampleTreeRandom(int state) {178 return SampleTreeRandom(state, 5);179 }180 private Tree SampleTreeRandom(int state, int maxDepth) {181 Tree t = null;182 183 // terminals184 if(Grammar.subtreeCount[state] == 0) {185 t = CreateTerminalNode(state, random, problem);186 } else {187 // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)188 if(Grammar.subtreeCount[state] == 1) {189 var targetStates = Grammar.transition[state];190 var altIdx = SampleAlternative(random, state, maxDepth);191 var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);192 t = new Tree(altIdx, new Tree[] { alternative });193 } else {194 // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence195 Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];196 for(int i = 0; i < Grammar.subtreeCount[state]; i++) {197 subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);198 }199 t = new Tree(-1, subtrees); // alternative index is ignored200 }201 }202 return t;203 }204 205 private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {206 switch(state) {207 ?CREATETERMINALNODECODE?208 default: { throw new ArgumentException(""Unknown state index"" + state); }209 }210 }211 212 private int SampleAlternative(Random random, int state, int maxDepth) {213 switch(state) {214 215 ?SAMPLEALTERNATIVECODE?216 217 default: throw new InvalidOperationException();218 }219 }220 221 195 public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) { 222 if(args.Length > 0 ) throw new ArgumentException(""Arguments ""+args.Aggregate("""", (str, s) => str+ "" "" + s)+"" are not supported ""); 196 this.randomSearch = new ?IDENT?RandomSearchSolver(problem, args); 197 if(args.Length > 0 ) ParseArguments(args); 223 198 this.problem = problem; 224 199 this.random = new Random(); 225 200 } 201 private void ParseArguments(string[] args) { 202 var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)""); 203 204 var helpRegex = new Regex(@""--help|/\?""); 205 206 foreach(var arg in args) { 207 var maxDepthMatch = maxDepthRegex.Match(arg); 208 var helpMatch = helpRegex.Match(arg); 209 if(helpMatch.Success) { 210 PrintUsage(); Environment.Exit(0); 211 } else if(maxDepthMatch.Success) { 212 maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture); 213 if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]""); 214 } else { 215 Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0); 216 } 217 } 218 } 219 private void PrintUsage() { 220 Console.WriteLine(""Find a solution using Monte-Carlo tree search.""); 221 Console.WriteLine(); 222 Console.WriteLine(""Parameters:""); 223 Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]""); 224 } 225 226 226 227 227 228 public void Start() { -
branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs
r10426 r10427 27 27 public static void Main(string[] args) { 28 28 var problem = new ?IDENT?Problem(); 29 var solver = new ?IDENT? RandomSearchSolver(problem, args);29 var solver = new ?IDENT?MonteCarloTreeSearchSolver(problem, args); 30 30 solver.Start(); 31 31 } -
branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs
r10426 r10427 19 19 private readonly Random random; 20 20 21 p rivateTree SampleTree(int maxDepth, out int steps, out int depth) {21 public Tree SampleTree(int maxDepth, out int steps, out int depth) { 22 22 steps = 0; 23 23 depth = 0; … … 26 26 } 27 27 28 p rivateTree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {28 public Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) { 29 29 curDepth += 1; 30 30 Debug.Assert(maxDepth > 0);
Note: See TracChangeset
for help on using the changeset viewer.