Changeset 11745
- Timestamp:
- 01/10/15 14:06:29 (10 years ago)
- Location:
- branches/HeuristicLab.Problems.GrammaticalOptimization
- Files:
-
- 1 added
- 7 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs
r11742 r11745 10 10 public class UCB1Policy : IBanditPolicy { 11 11 public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) { 12 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>() .ToArray(); // TODO: performance12 var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>(); 13 13 int bestAction = -1; 14 14 double bestQ = double.NegativeInfinity; 15 15 int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries); 16 16 17 for (int a = 0; a < myActionInfos.Length; a++) { 18 if (myActionInfos[a].Disabled) continue; 19 if (myActionInfos[a].Tries == 0) return a; 20 var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries); 17 int aIdx = -1; 18 foreach (var aInfo in myActionInfos) { 19 aIdx++; 20 if (aInfo.Disabled) continue; 21 if (aInfo.Tries == 0) return aIdx; 22 var q = aInfo.SumReward / aInfo.Tries + Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries); 21 23 if (q > bestQ) { 22 24 bestQ = q; 23 bestAction = a ;25 bestAction = aIdx; 24 26 } 25 27 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs
r11742 r11745 5 5 using System.Text; 6 6 using HeuristicLab.Algorithms.Bandits; 7 using HeuristicLab.Common; 7 8 using HeuristicLab.Problems.GrammaticalOptimization; 8 9 … … 10 11 public class MctsContextualSampler { 11 12 private class TreeNode { 13 public string ident; 14 public ReadonlySequence alt; 12 15 public int randomTries; 13 public int policyTries;16 public int tries; 14 17 public TreeNode[] children; 15 public readonly ReadonlySequence phrase; 16 public readonly ReadonlySequence alt; 17 18 // phrase represents the phrase of the state and alt represents how the phrase has been reached from the parent state 19 public TreeNode(ReadonlySequence phrase, ReadonlySequence alt) { 20 this.phrase = phrase; 18 public bool done = false; 19 20 public TreeNode(string id, ReadonlySequence alt) { 21 this.ident = id; 21 22 this.alt = alt; 22 23 } 23 24 24 25 public override string ToString() { 25 return string.Format("Node({0} tries: {1} )", phrase, randomTries + policyTries);26 return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done); 26 27 } 27 28 } … … 35 36 private readonly Random random; 36 37 private readonly int randomTries; 37 private readonly IGrammarPolicy policy; 38 39 private List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>> updateChain; 38 39 private List<Tuple<TreeNode, TreeNode>> updateChain; 40 40 private TreeNode rootNode; 41 41 42 42 public int treeDepth; 43 43 public int treeSize; 44 45 // public MctsSampler(IProblem problem, int maxLen, Random random) : 46 // this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) { 47 // 48 // } 49 50 public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy policy) { 44 private double bestQuality; 45 46 public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) { 51 47 this.maxLen = maxLen; 52 48 this.problem = problem; 53 49 this.random = random; 54 50 this.randomTries = randomTries; 55 this.policy = policy; 51 this.v = new Dictionary<string, double>(1000000); 52 this.tries = new Dictionary<string, int>(1000000); 56 53 } 57 54 58 55 public void Run(int maxIterations) { 59 doublebestQuality = double.MinValue;56 bestQuality = double.MinValue; 60 57 InitPolicies(problem.Grammar); 61 for (int i = 0; ! policy.Done(rootNode.phrase)&& i < maxIterations; i++) {58 for (int i = 0; !rootNode.done && i < maxIterations; i++) { 62 59 var sentence = SampleSentence(problem.Grammar).ToString(); 63 60 var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen); … … 79 76 public void PrintStats() { 80 77 var n = rootNode; 81 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10} ", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);78 Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality); 82 79 while (n.children != null) { 80 Console.WriteLine("{0}", n.ident); 81 double maxVForRow = n.children.Select(ch => V(ch)).Max(); 82 if (maxVForRow == 0) maxVForRow = 1.0; 83 84 for (int i = 0; i < n.children.Length; i++) { 85 var ch = n.children[i]; 86 Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow); 87 Console.Write("{0,5}", ch.alt); 88 } 83 89 Console.WriteLine(); 84 Console.WriteLine("{0,5}->{1,-50}", n.alt, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.alt)))); 85 Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries)))); 90 for (int i = 0; i < n.children.Length; i++) { 91 var ch = n.children[i]; 92 Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow); 93 Console.Write("{0,5:F2}", V(ch) * 10); 94 } 95 Console.WriteLine(); 96 for (int i = 0; i < n.children.Length; i++) { 97 var ch = n.children[i]; 98 Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow); 99 Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString()); 100 } 101 Console.ForegroundColor = ConsoleColor.White; 102 Console.WriteLine(); 86 103 //n.policy.PrintStats(); 87 n = n.children. OrderByDescending(c => c.policyTries).First();88 } 89 Console.ReadLine();90 } 104 n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First(); 105 } 106 } 107 91 108 92 109 private void InitPolicies(IGrammar grammar) { 93 this.updateChain = new List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>>(); 94 95 rootNode = new TreeNode(new ReadonlySequence(grammar.SentenceSymbol), new ReadonlySequence("$")); 110 this.updateChain = new List<Tuple<TreeNode, TreeNode>>(); 111 this.v.Clear(); 112 this.tries.Clear(); 113 114 rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$")); 96 115 treeDepth = 0; 97 116 treeSize = 0; … … 100 119 private Sequence SampleSentence(IGrammar grammar) { 101 120 updateChain.Clear(); 102 var startPhrase = new Sequence(rootNode.phrase); 121 //var startPhrase = new Sequence("a*b+c*d+e*f+E"); 122 var startPhrase = new Sequence(grammar.SentenceSymbol); 103 123 return CompleteSentence(grammar, startPhrase); 104 124 } … … 109 129 TreeNode parent = null; 110 130 TreeNode n = rootNode; 111 bool done = false;112 131 var curDepth = 0; 113 while (!done) { 114 if (parent != null) 115 updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase)); 132 while (!phrase.IsTerminal) { 133 updateChain.Add(Tuple.Create(n, parent)); 116 134 117 135 if (n.randomTries < randomTries) { … … 128 146 129 147 if (n.randomTries == randomTries && n.children == null) { 148 // create a new node for each alternative 130 149 n.children = new TreeNode[alts.Count()]; 131 int cIdx= 0;150 var i = 0; 132 151 foreach (var alt in alts) { 133 152 var newPhrase = new Sequence(phrase); 134 newPhrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, alt); 135 n.children[cIdx++] = new TreeNode(new ReadonlySequence(newPhrase), new ReadonlySequence(alt)); 153 newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt); 154 if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1); 155 n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt)); 136 156 } 137 157 treeSize += n.children.Length; 138 158 } 139 140 n.policyTries++; 141 // => select using bandit policy 142 ReadonlySequence selectedAlt = policy.SelectAction(random, n.phrase, n.children.Select(c => c.alt)); 159 // => select using eps-greedy 160 int selectedAltIdx = SelectEpsGreedy(random, n.children); 161 162 //int selectedAltIdx = SelectActionUCB1(random, n.children); 163 Sequence selectedAlt = alts.ElementAt(selectedAltIdx); 143 164 144 165 // replace nt with alt … … 147 168 curDepth++; 148 169 149 done = phrase.IsTerminal;150 151 170 // prepare for next iteration 152 171 parent = n; 153 n = n.children .Single(ch => ch.alt == selectedAlt); // TODO: perf172 n = n.children[selectedAltIdx]; 154 173 } 155 174 } // while 156 175 157 n.policyTries++; 158 updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase)); 176 updateChain.Add(Tuple.Create(n, parent)); 177 178 // the last node is a leaf node (sentence is done), so we never need to visit this node again 179 n.done = true; 159 180 160 181 … … 168 189 169 190 foreach (var e in updateChain) { 170 var state = e.Item1; 171 var action = e.Item2; 172 var newState = e.Item3; 173 policy.UpdateReward(state, action, reward, newState); 174 //policy.UpdateReward(action, reward / updateChain.Count); 175 } 176 } 191 var node = e.Item1; 192 var parent = e.Item2; 193 node.tries++; 194 if (node.children != null && node.children.All(c => c.done)) { 195 node.done = true; 196 } 197 UpdateV(node, reward); 198 199 // the reward for the parent is either the just recieved reward or the value of the best action so far 200 double value = 0.0; 201 if (parent != null) { 202 var doneChilds = parent.children.Where(ch => ch.done); 203 if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max(); 204 } 205 //if (value > reward) reward = value; 206 } 207 } 208 209 private Dictionary<string, double> v; 210 private Dictionary<string, int> tries; 211 212 private void UpdateV(TreeNode n, double reward) { 213 var canonicalStr = problem.CanonicalRepresentation(n.ident); 214 //var canonicalStr = n.ident; 215 double stateV; 216 217 if (!v.TryGetValue(canonicalStr, out stateV)) { 218 v.Add(canonicalStr, reward); 219 tries.Add(canonicalStr, 1); 220 } else { 221 v[canonicalStr] = stateV + 0.005 * (reward - stateV); 222 //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV); 223 tries[canonicalStr]++; 224 } 225 } 226 227 private double V(TreeNode n) { 228 var canonicalStr = problem.CanonicalRepresentation(n.ident); 229 //var canonicalStr = n.ident; 230 double stateV; 231 232 if (!v.TryGetValue(canonicalStr, out stateV)) { 233 return 0.0; 234 } else { 235 return stateV; 236 } 237 } 238 239 private int SelectEpsGreedy(Random random, TreeNode[] children) { 240 if (random.NextDouble() < 0.2) { 241 242 return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2; 243 } else { 244 var bestQ = double.NegativeInfinity; 245 var bestChildIdx = new List<int>(); 246 for (int i = 0; i < children.Length; i++) { 247 if (children[i].done) continue; 248 // if (children[i].tries == 0) return i; 249 var q = V(children[i]); 250 if (q > bestQ) { 251 bestQ = q; 252 bestChildIdx.Clear(); 253 bestChildIdx.Add(i); 254 } else if (q == bestQ) { 255 bestChildIdx.Add(i); 256 } 257 } 258 Debug.Assert(bestChildIdx.Any()); 259 return bestChildIdx.SelectRandom(random); 260 } 261 } 262 private int SelectActionUCB1(Random random, TreeNode[] children) { 263 int bestAction = -1; 264 double bestQ = double.NegativeInfinity; 265 int totalTries = children.Sum(ch => ch.tries); 266 267 for (int a = 0; a < children.Length; a++) { 268 var ch = children[a]; 269 if (ch.done) continue; 270 if (ch.tries == 0) return a; 271 var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries); 272 if (q > bestQ) { 273 bestQ = q; 274 bestAction = a; 275 } 276 } 277 Debug.Assert(bestAction > -1); 278 return bestAction; 279 } 280 281 177 282 178 283 private void RaiseSolutionEvaluated(string sentence, double quality) { … … 184 289 if (handler != null) handler(sentence, quality); 185 290 } 291 292 186 293 } 187 294 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs
r11744 r11745 41 41 public int treeSize; 42 42 private double bestQuality; 43 44 // public MctsSampler(IProblem problem, int maxLen, Random random) :45 // this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {46 //47 // }48 43 49 44 public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) { -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Common/HeuristicLab.Common.csproj
r11727 r11745 33 33 <Reference Include="System" /> 34 34 <Reference Include="System.Core" /> 35 <Reference Include="System.Drawing" /> 35 36 <Reference Include="System.Xml.Linq" /> 36 37 <Reference Include="System.Data.DataSetExtensions" /> … … 40 41 </ItemGroup> 41 42 <ItemGroup> 43 <Compile Include="ConsoleEx.cs" /> 42 44 <Compile Include="Extensions.cs" /> 43 45 <Compile Include="Properties\AssemblyInfo.cs" /> -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs
r11742 r11745 146 146 var randSeed = 31415; 147 147 var nArms = 20; 148 Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01)); 149 Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01)); 150 Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01)); 151 Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01)); 152 Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true)); 153 Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy()); 154 Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1))); 155 Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1))); 148 149 Console.WriteLine("Generic Thompson (Gaussian Mixture)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianMixtureModel())); 150 // Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01)); 151 // Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01)); 152 // Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01)); 153 // Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01)); 154 // Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true)); 155 // Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy()); 156 // Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1))); 157 // Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1))); 156 158 157 159 /* -
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SymbolicRegressionPoly10Problem.cs
r11742 r11745 72 72 73 73 // right now only + and * is supported 74 public string CanonicalRepresentation(string terminalPhrase) { 75 var terms = terminalPhrase.Split('+'); 76 return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch))) 77 .OrderBy(term => term)); 74 public string CanonicalRepresentation(string phrase) { 75 var terms = phrase.Split('+').Select(t => t.Replace("*", "")); 76 var terminalTerms = terms.Where(t => t.All(ch => grammar.IsTerminal(ch))); 77 var nonTerminalTerms = terms.Where(t => t.Any(ch => grammar.IsNonTerminal(ch))); 78 79 return string.Join("+", terminalTerms.Select(term => CanonicalTerm(term)).OrderBy(term => term).Concat(nonTerminalTerms.Select(term => CanonicalTerm(term)))); 80 } 81 82 private string CanonicalTerm(string term) { 83 return string.Join("", term.OrderByDescending(ch => (byte)ch)); 78 84 } 79 85 } -
branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs
r11744 r11745 169 169 // good results e.g. with var alg = new MctsSampler(problem, 17, random, 1, (rand, numActions) => new ThresholdAscentPolicy(numActions, 500, 0.01)); 170 170 // GaussianModelWithUnknownVariance (and Q= 0.99-quantil) also works well for Ant 171 //var problem = new SantaFeAntProblem(); 171 //var problem = new SantaFeAntProblem(); 172 172 //var problem = new SymbolicRegressionProblem("Tower"); 173 173 //var problem = new PalindromeProblem(); … … 175 175 //var problem = new RoyalPairProblem(); 176 176 //var problem = new EvenParityProblem(); 177 var alg = new MctsSampler(problem, 25, random, 0, new GenericThompsonSamplingPolicy(new LogitNormalModel())); 178 //var alg = new TemporalDifferenceTreeSearchSampler(problem, 23, random, 0, new RandomPolicy()); 177 //var alg = new MctsSampler(problem, 23, random, 0, new GaussianThompsonSamplingPolicy(true)); 178 var alg = new MctsContextualSampler(problem, 23, random, 0); 179 //var alg = new TemporalDifferenceTreeSearchSampler(problem, 17, random, 10, new EpsGreedyPolicy(0.1)); 179 180 //var alg = new ExhaustiveBreadthFirstSearch(problem, 17); 180 181 //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions)); … … 187 188 bestQuality = quality; 188 189 bestSentence = sentence; 189 Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics); 190 //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics); 191 //Console.ReadLine(); 190 192 }; 191 193 alg.SolutionEvaluated += (sentence, quality) => { … … 193 195 globalStatistics.AddSentence(sentence, quality); 194 196 if (iterations % 100 == 0) { 195 Console.Clear(); 197 //if (iterations % 1000 == 0) Console.Clear(); 198 Console.SetCursorPosition(0, 0); 196 199 alg.PrintStats(); 197 200 } 201 198 202 if (iterations % 10000 == 0) { 199 203 //Console.WriteLine("{0,10} {1,10:F5} {2,10:F5} {3}", iterations, bestQuality, quality, sentence); 200 204 //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics); 201 Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);205 //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics); 202 206 } 203 207 };
Note: See TracChangeset
for help on using the changeset viewer.