Changeset 11747 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization
- Timestamp:
- 01/12/15 21:23:01 (10 years ago)
- Location:
- branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization
- Files:
-
- 4 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization.csproj
r11744 r11747 45 45 <Compile Include="AlternativesSampler.cs" /> 46 46 <Compile Include="AlternativesContextSampler.cs" /> 47 <Compile Include="MctsQLearningSampler.cs" /> 47 48 <Compile Include="TemporalDifferenceTreeSearchSampler.cs" /> 48 49 <Compile Include="ExhaustiveRandomFirstSearch.cs" /> -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs
r11745 r11747 15 15 public int randomTries; 16 16 public int tries; 17 public List<TreeNode> parents; 17 18 public TreeNode[] children; 18 19 public bool done = false; … … 21 22 this.ident = id; 22 23 this.alt = alt; 24 this.parents = new List<TreeNode>(); 23 25 } 24 26 … … 28 30 } 29 31 32 private Dictionary<string, TreeNode> treeNodes; 33 private TreeNode GetTreeNode(string id, ReadonlySequence alt) { 34 TreeNode n; 35 var canonicalId = problem.CanonicalRepresentation(id); 36 if (!treeNodes.TryGetValue(canonicalId, out n)) { 37 n = new TreeNode(canonicalId, alt); 38 tries.TryGetValue(canonicalId, out n.tries); 39 treeNodes[canonicalId] = n; 40 } 41 return n; 42 } 30 43 31 44 public event Action<string, double> FoundNewBestSolution; … … 51 64 this.v = new Dictionary<string, double>(1000000); 52 65 this.tries = new Dictionary<string, int>(1000000); 66 treeNodes = new Dictionary<string, TreeNode>(); 53 67 } 54 68 … … 57 71 InitPolicies(problem.Grammar); 58 72 for (int i = 0; !rootNode.done && i < maxIterations; i++) { 59 var sentence = SampleSentence(problem.Grammar).ToString(); 60 var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen); 61 Debug.Assert(quality >= 0 && quality <= 1.0); 62 DistributeReward(quality); 63 64 RaiseSolutionEvaluated(sentence, quality); 65 66 if (quality > bestQuality) { 67 bestQuality = quality; 68 RaiseFoundNewBestSolution(sentence, quality); 73 bool success; 74 var sentence = SampleSentence(problem.Grammar, out success).ToString(); 75 if (success) { 76 var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen); 77 Debug.Assert(quality >= 0 && quality <= 1.0); 78 DistributeReward(quality); 79 80 RaiseSolutionEvaluated(sentence, quality); 81 82 if (quality > bestQuality) { 83 bestQuality = quality; 84 RaiseFoundNewBestSolution(sentence, quality); 85 } 69 86 } 70 87 } … … 78 95 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality); 79 96 while (n.children != null) { 80 Console.WriteLine("{0 }", n.ident);81 double maxVForRow = n.children.Select(ch => V(ch)).Max();97 Console.WriteLine("{0,-30}", n.ident); 98 double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max(); 82 99 if (maxVForRow == 0) maxVForRow = 1.0; 83 100 84 101 for (int i = 0; i < n.children.Length; i++) { 85 102 var ch = n.children[i]; 86 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);103 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 87 104 Console.Write("{0,5}", ch.alt); 88 105 } … … 90 107 for (int i = 0; i < n.children.Length; i++) { 91 108 var ch = n.children[i]; 92 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);93 Console.Write("{0,5:F2}", V(ch) * 10);109 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 110 Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10); 94 111 } 95 112 Console.WriteLine(); 96 113 for (int i = 0; i < n.children.Length; i++) { 97 114 var ch = n.children[i]; 98 Console.ForegroundColor = ConsoleEx.ColorForValue( V(ch) / maxVForRow);115 Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow); 99 116 Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString()); 100 117 } … … 102 119 Console.WriteLine(); 103 120 //n.policy.PrintStats(); 104 n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();121 n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).First(); 105 122 } 106 123 } … … 112 129 this.tries.Clear(); 113 130 114 rootNode = newTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));131 rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$")); 115 132 treeDepth = 0; 116 133 treeSize = 0; 117 134 } 118 135 119 private Sequence SampleSentence(IGrammar grammar ) {136 private Sequence SampleSentence(IGrammar grammar, out bool success) { 120 137 updateChain.Clear(); 121 138 //var startPhrase = new Sequence("a*b+c*d+e*f+E"); 122 139 var startPhrase = new Sequence(grammar.SentenceSymbol); 123 return CompleteSentence(grammar, startPhrase );124 } 125 126 private Sequence CompleteSentence(IGrammar g, Sequence phrase ) {140 return CompleteSentence(grammar, startPhrase, out success); 141 } 142 143 private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) { 127 144 if (phrase.Length > maxLen) throw new ArgumentException(); 128 145 if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException(); … … 136 153 n.randomTries++; 137 154 treeDepth = Math.Max(treeDepth, curDepth); 155 success = true; 138 156 return g.CompleteSentenceRandomly(random, phrase, maxLen); 139 157 } else { … … 153 171 newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt); 154 172 if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1); 155 n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt)); 173 var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alt)); 174 treeNode.parents.Add(n); 175 n.children[i++] = treeNode; 156 176 } 157 177 treeSize += n.children.Length; 178 UpdateDone(n); 179 180 // it could happend that we already finished all variations starting from the branch 181 // stop 182 if (n.done) { 183 success = false; 184 return phrase; 185 } 158 186 } 187 //int selectedAltIdx = SelectRandom(random, n.children); 188 159 189 // => select using eps-greedy 160 190 int selectedAltIdx = SelectEpsGreedy(random, n.children); … … 167 197 168 198 curDepth++; 199 169 200 170 201 // prepare for next iteration 171 202 parent = n; 172 203 n = n.children[selectedAltIdx]; 204 //UpdateTD(parent, n, 0.0); 173 205 } 174 206 } // while … … 181 213 182 214 treeDepth = Math.Max(treeDepth, curDepth); 215 success = true; 183 216 return phrase; 184 217 } 185 218 219 220 //private void UpdateTD(TreeNode parent, TreeNode child, double reward) { 221 // double alpha = 1.0; 222 // var vParent = V(parent); 223 // var vChild = V(child); 224 // if (double.IsInfinity(vParent)) vParent = 0.0; 225 // if (double.IsInfinity(vChild)) vChild = 0.0; 226 // UpdateV(parent, (alpha * (reward + vChild - vParent))); 227 //} 228 186 229 private void DistributeReward(double reward) { 230 187 231 // iterate in reverse order (bottom up) 188 updateChain.Reverse(); 189 232 //updateChain.Reverse(); 233 UpdateDone(updateChain.Last().Item1); 234 //UpdateTD(updateChain.Last().Item2, updateChain.Last().Item1, reward); 235 //return; 236 237 BackPropReward(updateChain.Last().Item1, reward); 238 /* 190 239 foreach (var e in updateChain) { 191 240 var node = e.Item1; 192 var parent = e.Item2;241 //var parent = e.Item2; 193 242 node.tries++; 194 if (node.children != null && node.children.All(c => c.done)) {195 node.done = true;196 }243 //if (node.children != null && node.children.All(c => c.done)) { 244 // node.done = true; 245 //} 197 246 UpdateV(node, reward); 198 247 199 248 // the reward for the parent is either the just recieved reward or the value of the best action so far 200 double value = 0.0;201 if (parent != null) {202 var doneChilds = parent.children.Where(ch => ch.done);203 if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();204 }249 //double value = 0.0; 250 //if (parent != null) { 251 // var doneChilds = parent.children.Where(ch => ch.done); 252 // if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max(); 253 //} 205 254 //if (value > reward) reward = value; 206 } 207 } 255 }*/ 256 } 257 258 private void BackPropReward(TreeNode n, double reward) { 259 n.tries++; 260 UpdateV(n, reward); 261 foreach (var p in n.parents) BackPropReward(p, reward); 262 } 263 264 private void UpdateDone(TreeNode n) { 265 if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true; 266 if (n.done) foreach (var p in n.parents) UpdateDone(p); 267 } 268 208 269 209 270 private Dictionary<string, double> v; … … 219 280 tries.Add(canonicalStr, 1); 220 281 } else { 221 v[canonicalStr] = stateV + 0.005 * (reward - stateV);222 //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);282 //v[canonicalStr] = stateV + 0.005 * (reward - stateV); 283 v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV); 223 284 tries[canonicalStr]++; 224 285 } … … 229 290 //var canonicalStr = n.ident; 230 291 double stateV; 231 292 if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity; 232 293 if (!v.TryGetValue(canonicalStr, out stateV)) { 233 294 return 0.0; … … 237 298 } 238 299 300 private int SelectRandom(Random random, TreeNode[] children) { 301 return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2; 302 } 303 239 304 private int SelectEpsGreedy(Random random, TreeNode[] children) { 240 305 if (random.NextDouble() < 0.2) { 241 242 return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2; 306 return SelectRandom(random, children); 243 307 } else { 244 308 var bestQ = double.NegativeInfinity; -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs
r11745 r11747 5 5 using System.Text; 6 6 using HeuristicLab.Algorithms.Bandits; 7 using HeuristicLab.Common; 7 8 using HeuristicLab.Problems.GrammaticalOptimization; 8 9 … … 13 14 public int randomTries; 14 15 public IBanditPolicyActionInfo actionInfo; 16 public TreeNode parent; 15 17 public TreeNode[] children; 16 18 public bool done = false; 17 19 18 public TreeNode(string id ) {20 public TreeNode(string id, TreeNode parent) { 19 21 this.ident = id; 22 this.parent = parent; 20 23 } 21 24 … … 35 38 private readonly IBanditPolicy policy; 36 39 37 private List<TreeNode> updateChain;40 private TreeNode lastNode; // the bottom node in one episode 38 41 private TreeNode rootNode; 39 42 … … 75 78 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality); 76 79 while (n.children != null) { 80 Console.WriteLine("{0,-30}", n.ident); 81 double maxVForRow = n.children.Select(ch => ch.actionInfo.Value).Max(); 82 if (maxVForRow == 0) maxVForRow = 1.0; 83 84 for (int i = 0; i < n.children.Length; i++) { 85 var ch = n.children[i]; 86 SetColorForChild(ch, maxVForRow); 87 Console.Write("{0,5}", ch.ident); 88 } 77 89 Console.WriteLine(); 78 Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident)))); 79 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.actionInfo.Value * 10)))); 80 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.actionInfo.Tries.ToString())))); 90 for (int i = 0; i < n.children.Length; i++) { 91 var ch = n.children[i]; 92 SetColorForChild(ch, maxVForRow); 93 Console.Write("{0,5:F2}", ch.actionInfo.Value * 10); 94 } 95 Console.WriteLine(); 96 for (int i = 0; i < n.children.Length; i++) { 97 var ch = n.children[i]; 98 SetColorForChild(ch, maxVForRow); 99 Console.Write("{0,5}", ch.done ? "X" : ch.actionInfo.Tries.ToString()); 100 } 101 Console.ForegroundColor = ConsoleColor.White; 102 Console.WriteLine(); 81 103 //n.policy.PrintStats(); 82 n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First(); 83 } 104 //n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First(); 105 n = n.children.Where(ch=>!ch.done).OrderByDescending(c => c.actionInfo.Value).First(); 106 } 107 Console.WriteLine("-----------------------"); 108 } 109 110 private void SetColorForChild(TreeNode ch, double maxVForRow) { 111 //if (ch.done) Console.ForegroundColor = ConsoleColor.White; 112 //else 113 Console.ForegroundColor = ConsoleEx.ColorForValue(ch.actionInfo.Value / maxVForRow); 84 114 } 85 115 86 116 private void InitPolicies(IGrammar grammar) { 87 this.updateChain = new List<TreeNode>(); 88 89 rootNode = new TreeNode(grammar.SentenceSymbol.ToString() );117 118 119 rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), null); 90 120 rootNode.actionInfo = policy.CreateActionInfo(); 91 121 treeDepth = 0; … … 94 124 95 125 private Sequence SampleSentence(IGrammar grammar) { 96 updateChain.Clear();126 lastNode = null; 97 127 var startPhrase = new Sequence(grammar.SentenceSymbol); 128 //var startPhrase = new Sequence("a*b+c*d+e*f+E"); 129 98 130 return CompleteSentence(grammar, startPhrase); 99 131 } … … 105 137 var curDepth = 0; 106 138 while (!phrase.IsTerminal) { 107 updateChain.Add(n);108 139 109 140 if (n.randomTries < randomTries) { 110 141 n.randomTries++; 111 142 treeDepth = Math.Max(treeDepth, curDepth); 143 lastNode = n; 112 144 return g.CompleteSentenceRandomly(random, phrase, maxLen); 113 145 } else { … … 120 152 121 153 if (n.randomTries == randomTries && n.children == null) { 122 n.children = alts.Select(alt => new TreeNode(alt.ToString() )).ToArray(); // create a new node for each alternative154 n.children = alts.Select(alt => new TreeNode(alt.ToString(), n)).ToArray(); // create a new node for each alternative 123 155 foreach (var ch in n.children) ch.actionInfo = policy.CreateActionInfo(); 124 156 treeSize += n.children.Length; … … 138 170 } // while 139 171 140 updateChain.Add(n);172 lastNode = n; 141 173 142 174 … … 150 182 private void DistributeReward(double reward) { 151 183 // iterate in reverse order (bottom up) 152 updateChain.Reverse(); 153 154 foreach (var e in updateChain) { 155 var node = e; 156 if (node.done) node.actionInfo.Disable(); 184 185 var node = lastNode; 186 while (node != null) { 187 if (node.done) node.actionInfo.Disable(reward); 157 188 if (node.children != null && node.children.All(c => c.done)) { 158 189 node.done = true; 159 node.actionInfo.Disable(); 190 var bestActionValue = node.children.Select(c => c.actionInfo.Value).Max(); 191 node.actionInfo.Disable(bestActionValue); 160 192 } 161 193 if (!node.done) { 162 194 node.actionInfo.UpdateReward(reward); 163 195 } 196 node = node.parent; 164 197 } 165 198 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs
r11744 r11747 36 36 private readonly Random random; 37 37 private readonly int randomTries; 38 private readonly IBanditPolicy policy;39 38 40 39 private List<TreeNode> updateChain; … … 46 45 47 46 48 public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries , IBanditPolicy policy) {47 public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) { 49 48 this.maxLen = maxLen; 50 49 this.problem = problem; 51 50 this.random = random; 52 51 this.randomTries = randomTries; 53 this.policy = policy;54 52 } 55 53 … … 78 76 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality); 79 77 while (n.children != null) { 78 Console.WriteLine("{0,-30}", n.ident); 79 double maxVForRow = n.children.Select(ch => ch.q).Max(); 80 if (maxVForRow == 0) maxVForRow = 1.0; 81 82 for (int i = 0; i < n.children.Length; i++) { 83 var ch = n.children[i]; 84 Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow); 85 Console.Write("{0,5}", ch.ident); 86 } 80 87 Console.WriteLine(); 81 Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident)))); 82 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.q * 10)))); 83 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.tries.ToString())))); 88 for (int i = 0; i < n.children.Length; i++) { 89 var ch = n.children[i]; 90 Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow); 91 Console.Write("{0,5:F2}", ch.q * 10); 92 } 93 Console.WriteLine(); 94 for (int i = 0; i < n.children.Length; i++) { 95 var ch = n.children[i]; 96 Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow); 97 Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString()); 98 } 99 Console.ForegroundColor = ConsoleColor.White; 100 Console.WriteLine(); 84 101 //n.policy.PrintStats(); 85 102 n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).First(); 86 103 } 87 //Console.ReadLine();88 104 } 89 105 … … 127 143 } 128 144 // => select using bandit policy 129 int selectedAltIdx = Select Action(random, n.children);145 int selectedAltIdx = SelectEpsGreedy(random, n.children); 130 146 Sequence selectedAlt = alts.ElementAt(selectedAltIdx); 131 147 … … 152 168 153 169 // eps-greedy 154 private int Select Action(Random random, TreeNode[] children) {170 private int SelectEpsGreedy(Random random, TreeNode[] children) { 155 171 if (random.NextDouble() < 0.1) { 156 172 … … 158 174 } else { 159 175 var bestQ = double.NegativeInfinity; 160 var bestChildIdx = -1;176 var bestChildIdx = new List<int>(); 161 177 for (int i = 0; i < children.Length; i++) { 162 178 if (children[i].done) continue; 163 if (children[i].tries == 0) return i; 164 if (children[i].q > bestQ) { 165 bestQ = children[i].q; 166 bestChildIdx = i; 179 // if (children[i].tries == 0) return i; 180 var q = children[i].q; 181 if (q > bestQ) { 182 bestQ = q; 183 bestChildIdx.Clear(); 184 bestChildIdx.Add(i); 185 } else if (q == bestQ) { 186 bestChildIdx.Add(i); 167 187 } 168 188 } 169 Debug.Assert(bestChildIdx > -1);170 return bestChildIdx ;189 Debug.Assert(bestChildIdx.Any()); 190 return bestChildIdx.SelectRandom(random); 171 191 } 172 192 } 173 193 174 194 private void DistributeReward(double reward) { 175 const double alpha = 0.1;176 const double gamma = 1;177 // iterate in reverse order (bottom up)178 195 updateChain.Reverse(); 179 var nextQ = 0.0; 180 foreach (var e in updateChain) { 181 var node = e; 182 node.tries++; 196 foreach (var node in updateChain) { 183 197 if (node.children != null && node.children.All(c => c.done)) { 184 198 node.done = true; 185 199 } 186 // reward is recieved only for the last action 187 if (e == updateChain.First()) { 188 node.q = node.q + alpha * (reward + gamma * nextQ - node.q); 189 nextQ = node.q; 190 } else { 191 node.q = node.q + alpha * (0 + gamma * nextQ - node.q); 192 nextQ = node.q; 193 } 194 } 200 } 201 updateChain.Reverse(); 202 203 //const double alpha = 0.1; 204 const double gamma = 1; 205 double alpha; 206 foreach (var p in updateChain.Zip(updateChain.Skip(1), Tuple.Create)) { 207 var parent = p.Item1; 208 var child = p.Item2; 209 parent.tries++; 210 alpha = 1.0 / parent.tries; 211 //alpha = 0.01; 212 parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q); 213 } 214 // reward is recieved only for the last action 215 var n = updateChain.Last(); 216 n.tries++; 217 alpha = 1.0 / n.tries; 218 //alpha = 0.1; 219 n.q = n.q + alpha * reward; 195 220 } 196 221
Note: See TracChangeset
for help on using the changeset viewer.