Changeset 11747 for branches/HeuristicLab.Problems.GrammaticalOptimization
- Timestamp: 01/12/15 21:23:01 (10 years ago)
- Location: branches/HeuristicLab.Problems.GrammaticalOptimization
- Files: 3 added, 21 edited
Legend (notation used in the diffs below):
- unchanged context lines carry no prefix
- lines prefixed with + were added
- lines prefixed with - were removed
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BernoulliPolicyActionInfo.cs
(r11742 → r11747)
  namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
    public class BernoulliPolicyActionInfo : IBanditPolicyActionInfo {
+     private double knownValue;
      public bool Disabled { get { return NumSuccess == -1; } }
      public int NumSuccess { get; private set; }
      public int NumFailure { get; private set; }
      public int Tries { get { return NumSuccess + NumFailure; } }
-     public double Value { get { return NumSuccess / (double)(Tries); } }
+     public double Value {
+       get {
+         if (Disabled) return knownValue;
+         else
+           return NumSuccess / (double)(Tries);
+       }
+     }
      public void UpdateReward(double reward) {
        Debug.Assert(!Disabled);
        ...
        else NumFailure++;
      }
-     public void Disable() {
+     public void Disable(double reward) {
        this.NumSuccess = -1;
        this.NumFailure = -1;
+       this.knownValue = reward;
      }
      public void Reset() {
        NumSuccess = 0;
        NumFailure = 0;
+       knownValue = 0.0;
      }
      public void PrintStats() {
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BoltzmannExplorationPolicy.cs
(r11742 → r11747)
      private readonly Func<DefaultPolicyActionInfo, double> valueFunction;

-     public BoltzmannExplorationPolicy(double eps) : this(eps, DefaultPolicyActionInfo.AverageReward) { }
+     public BoltzmannExplorationPolicy(double beta) : this(beta, DefaultPolicyActionInfo.AverageReward) { }

      public BoltzmannExplorationPolicy(double beta, Func<DefaultPolicyActionInfo, double> valueFunction) {
      ...
        // select best
        var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
+       Debug.Assert(myActionInfos.Any(a => !a.Disabled));
+       // try any of the untries actions randomly
+       // for RoyalSequence it is much better to select the actions in the order of occurrence (all terminal alternatives first)
+       //if (myActionInfos.Any(aInfo => !aInfo.Disabled && aInfo.Tries == 0)) {
+       //  return myActionInfos
+       //    .Select((aInfo, idx) => new { aInfo, idx })
+       //    .Where(p => !p.aInfo.Disabled)
+       //    .Where(p => p.aInfo.Tries == 0)
+       //    .SelectRandom(random).idx;
+       //}

        var w = from aInfo in myActionInfos
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/DefaultPolicyActionInfo.cs
(r11742 → r11747)
  // stores information that is relevant for most of the policies
  public class DefaultPolicyActionInfo : IBanditPolicyActionInfo {
+   private double knownValue;
    public bool Disabled { get { return Tries == -1; } }
    public double SumReward { get; private set; }
    public int Tries { get; private set; }
    public double MaxReward { get; private set; }
-   public double Value { get { return SumReward / Tries; } }
+   public double Value {
+     get {
+       if (Disabled) return knownValue;
+       else
+         return Tries > 0 ? SumReward / Tries : 0.0;
+     }
+   }
    public DefaultPolicyActionInfo() {
      MaxReward = double.MinValue;
    ...
      MaxReward = Math.Max(MaxReward, reward);
    }
-   public void Disable() {
+   public void Disable(double reward) {
      this.Tries = -1;
      this.SumReward = 0.0;
+     this.knownValue = reward;
    }
    public void Reset() {
    ...
      Tries = 0;
      MaxReward = 0.0;
+     knownValue = 0.0;
    }
    public void PrintStats() {
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/MeanAndVariancePolicyActionInfo.cs
(r11742 → r11747)
    public bool Disabled { get { return disabled; } }
    private OnlineMeanAndVarianceEstimator estimator = new OnlineMeanAndVarianceEstimator();
+   private double knownValue;
    public int Tries { get { return estimator.N; } }
    public double SumReward { get { return estimator.Sum; } }
    public double AvgReward { get { return estimator.Avg; } }
    public double RewardVariance { get { return estimator.Variance; } }
-   public double Value { get { return AvgReward; } }
+   public double Value {
+     get {
+       if (disabled) return knownValue;
+       else
+         return AvgReward;
+     }
+   }

    public void UpdateReward(double reward) {
    ...
    }

-   public void Disable() {
+   public void Disable(double reward) {
      disabled = true;
+     this.knownValue = reward;
    }

    public void Reset() {
      disabled = false;
+     knownValue = 0.0;
      estimator.Reset();
    }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ModelPolicyActionInfo.cs
(r11744 → r11747)
    public class ModelPolicyActionInfo : IBanditPolicyActionInfo {
      private readonly IModel model;
+     private double knownValue;
      public bool Disabled { get { return Tries == -1; } }
-     public double Value { get { return model.SampleExpectedReward(new Random()); } }
+     public double Value {
+       get {
+         if (Disabled) return knownValue;
+         else
+           return model.SampleExpectedReward(new Random());
+       }
+     }

      public int Tries { get; private set; }
      ...
      }

-     public void Disable() {
+     public void Disable(double reward) {
        this.Tries = -1;
+       this.knownValue = reward;
      }

      public void Reset() {
        Tries = 0;
+       knownValue = 0.0;
        model.Reset();
      }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs
(r11744 → r11747)
      public int Tries { get; private set; }
      public int thresholdBin = 1;
-     public double Value { get { return rewardHistogram[thresholdBin] / (double)Tries; } }
+     private double knownValue;
+
+     public double Value {
+       get {
+         if (Disabled) return knownValue;
+         if (Tries == 0.0) return 0.0;
+         return rewardHistogram[thresholdBin] / (double)Tries;
+       }
+     }

      public bool Disabled { get { return Tries == -1; } }
      ...
      }

-     public void Disable() {
+     public void Disable(double reward) {
+       this.knownValue = reward;
        Tries = -1;
      }
      ...
        Tries = 0;
        thresholdBin = 1;
+       this.knownValue = 0.0;
        Array.Clear(rewardHistogram, 0, rewardHistogram.Length);
      }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs
(r11745 → r11747)
  using System.Text;
  using System.Threading.Tasks;
+ using HeuristicLab.Common;

  namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
  ...
      public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
        var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
-       int bestAction = -1;
        double bestQ = double.NegativeInfinity;
        int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);

+       var bestActions = new List<int>();
        int aIdx = -1;
        foreach (var aInfo in myActionInfos) {
          aIdx++;
          if (aInfo.Disabled) continue;
-         if (aInfo.Tries == 0) return aIdx;
-         var q = aInfo.SumReward / aInfo.Tries + Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
+         double q;
+         if (aInfo.Tries == 0) {
+           q = double.PositiveInfinity;
+         } else {
+           q = aInfo.SumReward / aInfo.Tries + 0.5 * Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
+         }
          if (q > bestQ) {
            bestQ = q;
-           bestAction = aIdx;
+           bestActions.Clear();
+           bestActions.Add(aIdx);
+         } else if (q == bestQ) {
+           bestActions.Add(aIdx);
          }
        }
-       Debug.Assert(bestAction > -1);
-       return bestAction;
+       Debug.Assert(bestActions.Any());
+       return bestActions.SelectRandom(random);
      }
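Editor's note: the selection rule above is standard UCB1 (mean reward plus an exploration bonus), here damped by the extra 0.5 factor and combined with uniform random tie-breaking. A minimal self-contained sketch of the same rule over plain arrays (illustrative names, not the HeuristicLab types):

using System;
using System.Collections.Generic;
using System.Linq;

static class Ucb1Sketch {
  // Picks the arm with the highest UCB1 score; untried arms score +infinity,
  // ties are broken uniformly at random.
  public static int Select(Random rnd, double[] sumReward, int[] tries, double c = 0.5) {
    int totalTries = tries.Sum();
    double bestQ = double.NegativeInfinity;
    var best = new List<int>();
    for (int i = 0; i < tries.Length; i++) {
      double q = tries[i] == 0
        ? double.PositiveInfinity
        : sumReward[i] / tries[i] + c * Math.Sqrt(2 * Math.Log(totalTries) / tries[i]);
      if (q > bestQ) { bestQ = q; best.Clear(); best.Add(i); }
      else if (q == bestQ) best.Add(i);
    }
    return best[rnd.Next(best.Count)];
  }
}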
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCTPolicy.cs
(r11742 → r11747)
  using System.Text;
  using System.Threading.Tasks;
+ using HeuristicLab.Common;
+
  namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
    /* Kocsis et al. Bandit based Monte-Carlo Planning */
  ...

        int aIdx = -1;
+       var bestActions = new List<int>();
        foreach (var aInfo in myActionInfos) {
          aIdx++;
          if (aInfo.Disabled) continue;
-         if (aInfo.Tries == 0) return aIdx;
-         var q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
+         double q;
+         if (aInfo.Tries == 0) {
+           q = double.PositiveInfinity;
+         } else {
+           q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
+         }
          if (q > bestQ) {
+           bestActions.Clear();
            bestQ = q;
-           bestAction = aIdx;
+           bestActions.Add(aIdx);
          }
+         if (q == bestQ) {
+           bestActions.Add(aIdx);
+         }
        }
-       Debug.Assert(bestAction > -1);
-       return bestAction;
+       Debug.Assert(bestActions.Any());
+       return bestActions.SelectRandom(random);
      }
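Editor's note: unlike UCB1Policy above, the equality test here is a plain if rather than an else if, so an action that has just raised bestQ is added to bestActions twice and therefore gets double weight in the random tie-break. If uniform tie-breaking is the intent (an assumption on my part), the loop body would presumably mirror the UCB1 version:

// presumed intent: uniform tie-breaking, same pattern as in UCB1Policy
if (q > bestQ) {
  bestQ = q;
  bestActions.Clear();
  bestActions.Add(aIdx);
} else if (q == bestQ) {  // 'else if' avoids adding the new best action twice
  bestActions.Add(aIdx);
}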
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj
(r11744 → r11747)
      <Compile Include="BanditPolicies\BoltzmannExplorationPolicy.cs" />
      <Compile Include="BanditPolicies\ChernoffIntervalEstimationPolicy.cs" />
+     <Compile Include="BanditPolicies\ActiveLearningPolicy.cs" />
      <Compile Include="BanditPolicies\DefaultPolicyActionInfo.cs" />
      <Compile Include="BanditPolicies\EpsGreedyPolicy.cs" />
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IBanditPolicyActionInfo.cs
(r11742 → r11747)
    int Tries { get; }
    void UpdateReward(double reward);
-   void Disable();
+   void Disable(double reward);
    // reset causes the state of the action to be reinitialized (as after constructor-call)
    void Reset();
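Editor's note: the new parameter lets a disabled (fully explored) action keep reporting a value instead of falling back to zero; the ActionInfo classes above store it in knownValue and return it from Value. A small self-contained stub illustrating the changed contract (not the actual HeuristicLab types):

using System;

class ActionInfoStub {
  private double knownValue;
  private double sumReward;
  public int Tries { get; private set; }
  public bool Disabled { get { return Tries == -1; } }
  public double Value { get { return Disabled ? knownValue : (Tries > 0 ? sumReward / Tries : 0.0); } }
  public void UpdateReward(double r) { Tries++; sumReward += r; }
  public void Disable(double reward) { Tries = -1; sumReward = 0.0; knownValue = reward; }
  public void Reset() { Tries = 0; sumReward = 0.0; knownValue = 0.0; }
}

static class DisableDemo {
  static void Main() {
    var a = new ActionInfoStub();
    a.UpdateReward(0.3);
    a.UpdateReward(0.5);
    Console.WriteLine(a.Value);  // 0.4: running average while the action is active
    a.Disable(0.8);              // subtree exhausted; remember the best value found there
    Console.WriteLine(a.Value);  // 0.8: frozen value, still visible to the parent's policy
  }
}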
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs
(r11744 → r11747)
  namespace HeuristicLab.Algorithms.Bandits.Models {
    public class GaussianMixtureModel : IModel {
-     private readonly double[] componentMeans;
-     private readonly double[] componentVars;
-     private readonly double[] componentProbs;
+     private double[] componentMeans;
+     private double[] componentVars;
+     private double[] componentProbs;
+     private readonly List<double> allRewards = new List<double>();

      private int numComponents;
      ...
      public GaussianMixtureModel(int nComponents = 5) {
        this.numComponents = nComponents;
-       this.componentProbs = new double[nComponents];
-       this.componentMeans = new double[nComponents];
-       this.componentVars = new double[nComponents];
+
+       Reset();
      }
      ...
      public void Update(double reward) {
-       // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
-       throw new NotImplementedException();
+       allRewards.Add(reward);
+       throw new NotSupportedException("this does not yet work");
+       if (allRewards.Count < 1000 && allRewards.Count % 10 == 0) {
+         // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
+         Reset();
+         for (int i = 0; i < 20; i++) {
+           var responsibilities = allRewards.Select(r => CalcResponsibility(r)).ToArray();
+
+           var sumWeightedRewards = new double[numComponents];
+           var sumResponsibilities = new double[numComponents];
+           foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
+             for (int k = 0; k < numComponents; k++) {
+               sumWeightedRewards[k] += p.Item2[k] * p.Item1;
+               sumResponsibilities[k] += p.Item2[k];
+             }
+           }
+           for (int k = 0; k < numComponents; k++) {
+             componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
+           }
+
+           sumWeightedRewards = new double[numComponents];
+           foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
+             for (int k = 0; k < numComponents; k++) {
+               sumWeightedRewards[k] += p.Item2[k] * Math.Pow(p.Item1 - componentMeans[k], 2);
+             }
+           }
+           for (int k = 0; k < numComponents; k++) {
+             componentVars[k] = sumWeightedRewards[k] / sumResponsibilities[k];
+             componentProbs[k] = sumResponsibilities[k] / sumResponsibilities.Sum();
+           }
+         }
+       }
+     }
+
+     private double[] CalcResponsibility(double r) {
+       var res = new double[numComponents];
+       for (int k = 0; k < numComponents; k++) {
+         componentVars[k] = Math.Max(componentVars[k], 0.001);
+         res[k] = componentProbs[k] * alglib.normaldistribution((r - componentMeans[k]) / Math.Sqrt(componentVars[k]));
+         res[k] = Math.Max(res[k], 0.0001);
+       }
+       var sum = res.Sum();
+       for (int k = 0; k < numComponents; k++) {
+         res[k] /= sum;
+       }
+       return res;
      }
      ...
      public void Reset() {
-       Array.Clear(componentMeans, 0, numComponents);
-       Array.Clear(componentVars, 0, numComponents);
-       Array.Clear(componentProbs, 0, numComponents);
+       var rand = new Random();
+       this.componentProbs = Enumerable.Range(0, numComponents).Select((_) => rand.NextDouble()).ToArray();
+       var sum = componentProbs.Sum();
+       for (int i = 0; i < componentProbs.Length; i++) componentProbs[i] /= sum;
+       this.componentMeans = Enumerable.Range(0, numComponents).Select((_) => Rand.RandNormal(rand)).ToArray();
+       this.componentVars = Enumerable.Range(0, numComponents).Select((_) => 0.01).ToArray();
      }
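Editor's note: as committed, the throw new NotSupportedException before the if guard makes the whole EM block in Update unreachable (the author marks it as not yet working), and the responsibilities are based on alglib.normaldistribution, which as far as I can tell is the standard normal CDF rather than the density. For reference, a minimal self-contained sketch of one soft k-means (EM) pass for a 1-D Gaussian mixture using the Gaussian density, following MacKay's Algorithm 22.2; names and clamping constants are illustrative:

using System;
using System.Collections.Generic;
using System.Linq;

static class SoftKMeans1D {
  // One EM pass over the observed rewards; means, vars and probs are updated in place.
  public static void EmStep(IList<double> rewards, double[] means, double[] vars, double[] probs) {
    int k = means.Length;
    var resp = new double[rewards.Count, k];
    var sumR = new double[k];    // responsibility mass per component
    var sumRx = new double[k];   // responsibility-weighted sum of rewards

    // E-step: responsibilities from the Gaussian density
    for (int j = 0; j < rewards.Count; j++) {
      double norm = 0.0;
      for (int i = 0; i < k; i++) {
        double sd = Math.Sqrt(Math.Max(vars[i], 1e-3));
        double z = (rewards[j] - means[i]) / sd;
        resp[j, i] = Math.Max(probs[i] * Math.Exp(-0.5 * z * z) / (sd * Math.Sqrt(2 * Math.PI)), 1e-4);
        norm += resp[j, i];
      }
      for (int i = 0; i < k; i++) {
        resp[j, i] /= norm;
        sumR[i] += resp[j, i];
        sumRx[i] += resp[j, i] * rewards[j];
      }
    }

    // M-step: means first, then variances and mixing weights
    for (int i = 0; i < k; i++) means[i] = sumRx[i] / sumR[i];
    var sumRxx = new double[k];
    for (int j = 0; j < rewards.Count; j++)
      for (int i = 0; i < k; i++)
        sumRxx[i] += resp[j, i] * Math.Pow(rewards[j] - means[i], 2);
    double total = sumR.Sum();
    for (int i = 0; i < k; i++) {
      vars[i] = sumRxx[i] / sumR[i];
      probs[i] = sumR[i] / total;
    }
  }
}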
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization.csproj
(r11744 → r11747)
      <Compile Include="AlternativesSampler.cs" />
      <Compile Include="AlternativesContextSampler.cs" />
+     <Compile Include="MctsQLearningSampler.cs" />
      <Compile Include="TemporalDifferenceTreeSearchSampler.cs" />
      <Compile Include="ExhaustiveRandomFirstSearch.cs" />
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs
(r11745 → r11747)
      public int randomTries;
      public int tries;
+     public List<TreeNode> parents;
      public TreeNode[] children;
      public bool done = false;
      ...
        this.ident = id;
        this.alt = alt;
+       this.parents = new List<TreeNode>();
      }
      ...
    }

+   private Dictionary<string, TreeNode> treeNodes;
+   private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
+     TreeNode n;
+     var canonicalId = problem.CanonicalRepresentation(id);
+     if (!treeNodes.TryGetValue(canonicalId, out n)) {
+       n = new TreeNode(canonicalId, alt);
+       tries.TryGetValue(canonicalId, out n.tries);
+       treeNodes[canonicalId] = n;
+     }
+     return n;
+   }

    public event Action<string, double> FoundNewBestSolution;
    ...
      this.v = new Dictionary<string, double>(1000000);
      this.tries = new Dictionary<string, int>(1000000);
+     treeNodes = new Dictionary<string, TreeNode>();
    }
    ...
      InitPolicies(problem.Grammar);
      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
-       var sentence = SampleSentence(problem.Grammar).ToString();
-       var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
-       Debug.Assert(quality >= 0 && quality <= 1.0);
-       DistributeReward(quality);
-
-       RaiseSolutionEvaluated(sentence, quality);
-
-       if (quality > bestQuality) {
-         bestQuality = quality;
-         RaiseFoundNewBestSolution(sentence, quality);
+       bool success;
+       var sentence = SampleSentence(problem.Grammar, out success).ToString();
+       if (success) {
+         var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
+         Debug.Assert(quality >= 0 && quality <= 1.0);
+         DistributeReward(quality);
+
+         RaiseSolutionEvaluated(sentence, quality);
+
+         if (quality > bestQuality) {
+           bestQuality = quality;
+           RaiseFoundNewBestSolution(sentence, quality);
+         }
        }
      }
    ...
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
      while (n.children != null) {
-       Console.WriteLine("{0}", n.ident);
-       double maxVForRow = n.children.Select(ch => V(ch)).Max();
+       Console.WriteLine("{0,-30}", n.ident);
+       double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max();
        if (maxVForRow == 0) maxVForRow = 1.0;

        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
-         Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
+         Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
          Console.Write("{0,5}", ch.alt);
        }
        ...
        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
-         Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
-         Console.Write("{0,5:F2}", V(ch) * 10);
+         Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
+         Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10);
        }
        Console.WriteLine();
        for (int i = 0; i < n.children.Length; i++) {
          var ch = n.children[i];
-         Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
+         Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
        }
        ...
        Console.WriteLine();
        //n.policy.PrintStats();
-       n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();
+       n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).First();
      }
    }
    ...
      this.tries.Clear();

-     rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
+     rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
      treeDepth = 0;
      treeSize = 0;
    }

-   private Sequence SampleSentence(IGrammar grammar) {
+   private Sequence SampleSentence(IGrammar grammar, out bool success) {
      updateChain.Clear();
      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
      var startPhrase = new Sequence(grammar.SentenceSymbol);
-     return CompleteSentence(grammar, startPhrase);
-   }
-
-   private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
+     return CompleteSentence(grammar, startPhrase, out success);
+   }
+
+   private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
      if (phrase.Length > maxLen) throw new ArgumentException();
      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
      ...
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
+         success = true;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        } else {
        ...
            newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
            if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
-           n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
+           var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
+           treeNode.parents.Add(n);
+           n.children[i++] = treeNode;
          }
          treeSize += n.children.Length;
+         UpdateDone(n);
+
+         // it could happend that we already finished all variations starting from the branch
+         // stop
+         if (n.done) {
+           success = false;
+           return phrase;
+         }
        }
+       //int selectedAltIdx = SelectRandom(random, n.children);
+
        // => select using eps-greedy
        int selectedAltIdx = SelectEpsGreedy(random, n.children);
        ...
        curDepth++;

        // prepare for next iteration
        parent = n;
        n = n.children[selectedAltIdx];
+       //UpdateTD(parent, n, 0.0);
      }
    } // while
    ...
    treeDepth = Math.Max(treeDepth, curDepth);
+   success = true;
    return phrase;
  }

+ //private void UpdateTD(TreeNode parent, TreeNode child, double reward) {
+ //  double alpha = 1.0;
+ //  var vParent = V(parent);
+ //  var vChild = V(child);
+ //  if (double.IsInfinity(vParent)) vParent = 0.0;
+ //  if (double.IsInfinity(vChild)) vChild = 0.0;
+ //  UpdateV(parent, (alpha * (reward + vChild - vParent)));
+ //}
+
  private void DistributeReward(double reward) {
    // iterate in reverse order (bottom up)
-   updateChain.Reverse();
-
+   //updateChain.Reverse();
+   UpdateDone(updateChain.Last().Item1);
+   //UpdateTD(updateChain.Last().Item2, updateChain.Last().Item1, reward);
+   //return;
+
+   BackPropReward(updateChain.Last().Item1, reward);
+   /*
    foreach (var e in updateChain) {
      var node = e.Item1;
-     var parent = e.Item2;
+     //var parent = e.Item2;
      node.tries++;
-     if (node.children != null && node.children.All(c => c.done)) {
-       node.done = true;
-     }
+     //if (node.children != null && node.children.All(c => c.done)) {
+     //  node.done = true;
+     //}
      UpdateV(node, reward);

      // the reward for the parent is either the just recieved reward or the value of the best action so far
-     double value = 0.0;
-     if (parent != null) {
-       var doneChilds = parent.children.Where(ch => ch.done);
-       if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
-     }
+     //double value = 0.0;
+     //if (parent != null) {
+     //  var doneChilds = parent.children.Where(ch => ch.done);
+     //  if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
+     //}
      //if (value > reward) reward = value;
-   }
- }
+   }*/
+ }
+
+ private void BackPropReward(TreeNode n, double reward) {
+   n.tries++;
+   UpdateV(n, reward);
+   foreach (var p in n.parents) BackPropReward(p, reward);
+ }
+
+ private void UpdateDone(TreeNode n) {
+   if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
+   if (n.done) foreach (var p in n.parents) UpdateDone(p);
+ }

  private Dictionary<string, double> v;
  ...
      tries.Add(canonicalStr, 1);
    } else {
-     v[canonicalStr] = stateV + 0.005 * (reward - stateV);
-     //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
+     //v[canonicalStr] = stateV + 0.005 * (reward - stateV);
+     v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
      tries[canonicalStr]++;
    }
  ...
    //var canonicalStr = n.ident;
    double stateV;
+   if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
    if (!v.TryGetValue(canonicalStr, out stateV)) {
      return 0.0;
    ...
  }

+ private int SelectRandom(Random random, TreeNode[] children) {
+   return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
+ }
+
  private int SelectEpsGreedy(Random random, TreeNode[] children) {
    if (random.NextDouble() < 0.2) {
-     return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
+     return SelectRandom(random, children);
    } else {
      var bestQ = double.NegativeInfinity;
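Editor's note: with GetTreeNode the search structure becomes a DAG keyed by the canonical phrase, so a node can have several parents; BackPropReward recurses into all of them and UpdateDone propagates exhaustion upwards. Without a visited set, a state reachable over several distinct parent paths is updated once per path in a single episode, and the recursion cost grows with the number of paths. If a single update per state per episode is wanted (an assumption on my part; multiple updates may well be intended), a guarded variant could look like this sketch (illustrative types, not the sampler's TreeNode):

using System;
using System.Collections.Generic;

class DagNode {                 // stand-in for the sampler's TreeNode
  public int Tries;
  public double V;
  public readonly List<DagNode> Parents = new List<DagNode>();
}

static class DagBackProp {
  // Back-propagates a reward through the state DAG, touching each node at most once per episode.
  public static void BackPropOnce(DagNode start, double reward) {
    var visited = new HashSet<DagNode>();
    var stack = new Stack<DagNode>();
    stack.Push(start);
    while (stack.Count > 0) {
      var n = stack.Pop();
      if (!visited.Add(n)) continue;       // skip nodes already updated in this episode
      n.Tries++;
      n.V += (reward - n.V) / n.Tries;     // incremental average, cf. UpdateV
      foreach (var p in n.Parents) stack.Push(p);
    }
  }
}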
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs
(r11745 → r11747)
  using System.Text;
  using HeuristicLab.Algorithms.Bandits;
+ using HeuristicLab.Common;
  using HeuristicLab.Problems.GrammaticalOptimization;
  ...
      public int randomTries;
      public IBanditPolicyActionInfo actionInfo;
+     public TreeNode parent;
      public TreeNode[] children;
      public bool done = false;

-     public TreeNode(string id) {
+     public TreeNode(string id, TreeNode parent) {
        this.ident = id;
+       this.parent = parent;
      }
      ...
    private readonly IBanditPolicy policy;

-   private List<TreeNode> updateChain;
+   private TreeNode lastNode; // the bottom node in one episode
    private TreeNode rootNode;
    ...
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality);
      while (n.children != null) {
+       Console.WriteLine("{0,-30}", n.ident);
+       double maxVForRow = n.children.Select(ch => ch.actionInfo.Value).Max();
+       if (maxVForRow == 0) maxVForRow = 1.0;
+
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         SetColorForChild(ch, maxVForRow);
+         Console.Write("{0,5}", ch.ident);
+       }
        Console.WriteLine();
-       Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
-       Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.actionInfo.Value * 10))));
-       Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.actionInfo.Tries.ToString()))));
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         SetColorForChild(ch, maxVForRow);
+         Console.Write("{0,5:F2}", ch.actionInfo.Value * 10);
+       }
+       Console.WriteLine();
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         SetColorForChild(ch, maxVForRow);
+         Console.Write("{0,5}", ch.done ? "X" : ch.actionInfo.Tries.ToString());
+       }
+       Console.ForegroundColor = ConsoleColor.White;
+       Console.WriteLine();
        //n.policy.PrintStats();
-       n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First();
-     }
+       //n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First();
+       n = n.children.Where(ch=>!ch.done).OrderByDescending(c => c.actionInfo.Value).First();
+     }
+     Console.WriteLine("-----------------------");
+   }
+
+   private void SetColorForChild(TreeNode ch, double maxVForRow) {
+     //if (ch.done) Console.ForegroundColor = ConsoleColor.White;
+     //else
+     Console.ForegroundColor = ConsoleEx.ColorForValue(ch.actionInfo.Value / maxVForRow);
    }

    private void InitPolicies(IGrammar grammar) {
-     this.updateChain = new List<TreeNode>();
-
-     rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
+
+
+     rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), null);
      rootNode.actionInfo = policy.CreateActionInfo();
      treeDepth = 0;
      ...
    }

    private Sequence SampleSentence(IGrammar grammar) {
-     updateChain.Clear();
+     lastNode = null;
      var startPhrase = new Sequence(grammar.SentenceSymbol);
+     //var startPhrase = new Sequence("a*b+c*d+e*f+E");
+
      return CompleteSentence(grammar, startPhrase);
    }
    ...
      var curDepth = 0;
      while (!phrase.IsTerminal) {
-       updateChain.Add(n);

        if (n.randomTries < randomTries) {
          n.randomTries++;
          treeDepth = Math.Max(treeDepth, curDepth);
+         lastNode = n;
          return g.CompleteSentenceRandomly(random, phrase, maxLen);
        } else {
        ...
          if (n.randomTries == randomTries && n.children == null) {
-           n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
+           n.children = alts.Select(alt => new TreeNode(alt.ToString(), n)).ToArray(); // create a new node for each alternative
            foreach (var ch in n.children) ch.actionInfo = policy.CreateActionInfo();
            treeSize += n.children.Length;
        ...
      } // while

-     updateChain.Add(n);
+     lastNode = n;
      ...

    private void DistributeReward(double reward) {
      // iterate in reverse order (bottom up)
-     updateChain.Reverse();
-
-     foreach (var e in updateChain) {
-       var node = e;
-       if (node.done) node.actionInfo.Disable();
+
+     var node = lastNode;
+     while (node != null) {
+       if (node.done) node.actionInfo.Disable(reward);
        if (node.children != null && node.children.All(c => c.done)) {
          node.done = true;
-         node.actionInfo.Disable();
+         var bestActionValue = node.children.Select(c => c.actionInfo.Value).Max();
+         node.actionInfo.Disable(bestActionValue);
        }
        if (!node.done) {
          node.actionInfo.UpdateReward(reward);
        }
+       node = node.parent;
      }
    }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs
(r11744 → r11747)
    private readonly Random random;
    private readonly int randomTries;
-   private readonly IBanditPolicy policy;

    private List<TreeNode> updateChain;
    ...

-   public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
+   public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) {
      this.maxLen = maxLen;
      this.problem = problem;
      this.random = random;
      this.randomTries = randomTries;
-     this.policy = policy;
    }
    ...
      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);
      while (n.children != null) {
+       Console.WriteLine("{0,-30}", n.ident);
+       double maxVForRow = n.children.Select(ch => ch.q).Max();
+       if (maxVForRow == 0) maxVForRow = 1.0;
+
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
+         Console.Write("{0,5}", ch.ident);
+       }
        Console.WriteLine();
-       Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
-       Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.q * 10))));
-       Console.WriteLine("{0,5} {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.tries.ToString()))));
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
+         Console.Write("{0,5:F2}", ch.q * 10);
+       }
+       Console.WriteLine();
+       for (int i = 0; i < n.children.Length; i++) {
+         var ch = n.children[i];
+         Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
+         Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
+       }
+       Console.ForegroundColor = ConsoleColor.White;
+       Console.WriteLine();
        //n.policy.PrintStats();
        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).First();
      }
-     //Console.ReadLine();
    }
    ...
      }
      // => select using bandit policy
-     int selectedAltIdx = SelectAction(random, n.children);
+     int selectedAltIdx = SelectEpsGreedy(random, n.children);
      Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    ...

    // eps-greedy
-   private int SelectAction(Random random, TreeNode[] children) {
+   private int SelectEpsGreedy(Random random, TreeNode[] children) {
      if (random.NextDouble() < 0.1) {
      ...
      } else {
        var bestQ = double.NegativeInfinity;
-       var bestChildIdx = -1;
+       var bestChildIdx = new List<int>();
        for (int i = 0; i < children.Length; i++) {
          if (children[i].done) continue;
-         if (children[i].tries == 0) return i;
-         if (children[i].q > bestQ) {
-           bestQ = children[i].q;
-           bestChildIdx = i;
+         // if (children[i].tries == 0) return i;
+         var q = children[i].q;
+         if (q > bestQ) {
+           bestQ = q;
+           bestChildIdx.Clear();
+           bestChildIdx.Add(i);
+         } else if (q == bestQ) {
+           bestChildIdx.Add(i);
          }
        }
-       Debug.Assert(bestChildIdx > -1);
-       return bestChildIdx;
+       Debug.Assert(bestChildIdx.Any());
+       return bestChildIdx.SelectRandom(random);
      }
    }

    private void DistributeReward(double reward) {
-     const double alpha = 0.1;
-     const double gamma = 1;
-     // iterate in reverse order (bottom up)
      updateChain.Reverse();
-     var nextQ = 0.0;
-     foreach (var e in updateChain) {
-       var node = e;
-       node.tries++;
+     foreach (var node in updateChain) {
        if (node.children != null && node.children.All(c => c.done)) {
          node.done = true;
        }
-       // reward is recieved only for the last action
-       if (e == updateChain.First()) {
-         node.q = node.q + alpha * (reward + gamma * nextQ - node.q);
-         nextQ = node.q;
-       } else {
-         node.q = node.q + alpha * (0 + gamma * nextQ - node.q);
-         nextQ = node.q;
-       }
-     }
+     }
+     updateChain.Reverse();
+
+     //const double alpha = 0.1;
+     const double gamma = 1;
+     double alpha;
+     foreach (var p in updateChain.Zip(updateChain.Skip(1), Tuple.Create)) {
+       var parent = p.Item1;
+       var child = p.Item2;
+       parent.tries++;
+       alpha = 1.0 / parent.tries;
+       //alpha = 0.01;
+       parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q);
+     }
+     // reward is recieved only for the last action
+     var n = updateChain.Last();
+     n.tries++;
+     alpha = 1.0 / n.tries;
+     //alpha = 0.1;
+     n.q = n.q + alpha * reward;
    }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.SymbReg/SymbolicRegressionProblem.cs
(r11742 → r11747)
    // right now only + and * is supported
    public string CanonicalRepresentation(string terminalPhrase) {
-     return terminalPhrase;
-     //var terms = terminalPhrase.Split('+');
-     //return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
-     //  .OrderBy(term => term));
+     //return terminalPhrase;
+     var terms = terminalPhrase.Split('+');
+     return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
+       .OrderBy(term => term));
    }
  }
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestInstances.cs
(r11730 → r11747)
      Assert.AreEqual(0.116199534934045, p.Evaluate("c*f*j"), 1.0E-7);

+     Assert.AreEqual(0.824522210419616, p.Evaluate("a*b+c*d+e*f"), 1E-7);
+
      Assert.AreEqual(1.0, p.Evaluate("a*b+c*d+e*f+a*g*i+c*f*j"), 1.0E-7);
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.csproj
(r11732 → r11747)
    </ItemGroup>
    <ItemGroup>
+     <Compile Include="RoyalPhraseSequenceProblem.cs" />
+     <Compile Include="RoyalSequenceProblem.cs" />
      <Compile Include="ExpressionInterpreter.cs" />
      <Compile Include="Grammar.cs" />
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SantaFeAntProblem.cs
(r11742 → r11747)
      public string CanonicalRepresentation(string terminalPhrase) {
-       return terminalPhrase.Replace("rl", "").Replace("lr", "");
+       //return terminalPhrase;
+       string oldPhrase;
+       do {
+         oldPhrase = terminalPhrase;
+         terminalPhrase.Replace("ll", "rr").Replace("rl", "lr");
+       } while (terminalPhrase != oldPhrase);
+       return terminalPhrase;
      }
    }
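Editor's note: String.Replace returns a new string and the result is discarded here, so terminalPhrase never changes, the loop exits after one pass, and the phrase is returned unchanged. Assuming the intent is to rewrite "ll" -> "rr" and "rl" -> "lr" until a fixed point is reached, the assignment is probably what was meant:

// sketch of the presumed intent: apply the rewrites until the phrase stops changing
public static string Canonicalize(string terminalPhrase) {
  string oldPhrase;
  do {
    oldPhrase = terminalPhrase;
    terminalPhrase = terminalPhrase.Replace("ll", "rr").Replace("rl", "lr");
  } while (terminalPhrase != oldPhrase);
  return terminalPhrase;
}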
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SymbolicRegressionPoly10Problem.cs
(r11745 → r11747)
    private const string grammarString = @"
  G(E):
-   E -> a | b | c | d | e | f | g | h | i | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | i+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | i*E | j*E
+   E -> a | b | c | d | e | f | g | h | i | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | i+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | i*E | j*E
    ";
  (the removed and added rule read identically here; the change appears to be whitespace only)
branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs
(r11745 → r11747)
        Tuple.Create((IProblem)new SymbolicRegressionPoly10Problem(), 23),
      })
-     foreach (var randomTries in new int[] { 1, 10, /* 5, 100 /*, 500, 1000 */}) {
+     foreach (var randomTries in new int[] { 0, 1, 10, /* 5, 100 /*, 500, 1000 */}) {
        foreach (var policy in policies) {
          var myRandomTries = randomTries;
      ...

      private static void RunDemo() {
+       // TODO: unify MCTS, TD and ContextMCTS Solvers (stateInfos)
        // TODO: test with eps-greedy using max instead of average as value (seems to work well for symb-reg! explore further!)
        // TODO: separate value function from policy
      ...
        var random = new Random();

-       var problem = new SymbolicRegressionPoly10Problem(); // good results e.g. 10 randomtries and EpsGreedyPolicy(0.2, (aInfo)=>aInfo.MaxReward)
+       var phraseLen = 1;
+       var sentenceLen = 25;
+       var numPhrases = sentenceLen / phraseLen;
+       var problem = new RoyalPhraseSequenceProblem(random, 10, numPhrases, phraseLen: 1, k: 1, correctReward: 1, incorrectReward: 0);
+
+       //var problem = new SymbolicRegressionPoly10Problem(); // good results e.g. 10 randomtries and EpsGreedyPolicy(0.2, (aInfo)=>aInfo.MaxReward)
        // Ant
        // good results e.g. with var alg = new MctsSampler(problem, 17, random, 1, (rand, numActions) => new ThresholdAscentPolicy(numActions, 500, 0.01));
      ...
        //var problem = new RoyalPairProblem();
        //var problem = new EvenParityProblem();
-       //var alg = new MctsSampler(problem, 23, random, 0, new GaussianThompsonSamplingPolicy(true));
-       var alg = new MctsContextualSampler(problem, 23, random, 0);
-       //var alg = new TemporalDifferenceTreeSearchSampler(problem, 17, random, 10, new EpsGreedyPolicy(0.1));
-       //var alg = new ExhaustiveBreadthFirstSearch(problem, 17);
+       // symbreg length = 11 q = 0.824522210419616
+       var alg = new MctsSampler(problem, sentenceLen, random, 0, new BoltzmannExplorationPolicy(200));
+       //var alg = new MctsQLearningSampler(problem, sentenceLen, random, 0, null);
+       //var alg = new MctsQLearningSampler(problem, 30, random, 0, new EpsGreedyPolicy(0.2));
+       //var alg = new MctsContextualSampler(problem, 23, random, 0); // must visit each canonical solution only once
+       //var alg = new TemporalDifferenceTreeSearchSampler(problem, 30, random, 1);
+       //var alg = new ExhaustiveBreadthFirstSearch(problem, 7);
        //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions));
        //var alg = new ExhaustiveDepthFirstSearch(problem, 17);
        // var alg = new AlternativesSampler(problem, 17);
        // var alg = new RandomSearch(problem, random, 17);
-       // var alg = new ExhaustiveRandomFirstSearch(problem, random, 17);
+       //var alg = new ExhaustiveRandomFirstSearch(problem, random, 17);

        alg.FoundNewBestSolution += (sentence, quality) => {
      ...
          alg.PrintStats();
        }
+       //Console.WriteLine(sentence);

        if (iterations % 10000 == 0) {
-         //Console.WriteLine("{0,10} {1,10:F5} {2,10:F5} {3}", iterations, bestQuality, quality, sentence);
-         //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
          //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
        }