Changeset 11745


Ignore:
Timestamp:
01/10/15 14:06:29 (7 years ago)
Author:
gkronber
Message:

#2283: worked on contextual MCTS

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization
Files:
1 added
7 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs

    r11742 r11745  
    1010  public class UCB1Policy : IBanditPolicy {
    1111    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
    12       var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
     12      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    1313      int bestAction = -1;
    1414      double bestQ = double.NegativeInfinity;
    1515      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
    1616
    17       for (int a = 0; a < myActionInfos.Length; a++) {
    18         if (myActionInfos[a].Disabled) continue;
    19         if (myActionInfos[a].Tries == 0) return a;
    20         var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + Math.Sqrt((2 * Math.Log(totalTries)) / myActionInfos[a].Tries);
     17      int aIdx = -1;
     18      foreach (var aInfo in myActionInfos) {
     19        aIdx++;
     20        if (aInfo.Disabled) continue;
     21        if (aInfo.Tries == 0) return aIdx;
     22        var q = aInfo.SumReward / aInfo.Tries + Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
    2123        if (q > bestQ) {
    2224          bestQ = q;
    23           bestAction = a;
     25          bestAction = aIdx;
    2426        }
    2527      }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs

    r11742 r11745  
    55using System.Text;
    66using HeuristicLab.Algorithms.Bandits;
     7using HeuristicLab.Common;
    78using HeuristicLab.Problems.GrammaticalOptimization;
    89
     
    1011  public class MctsContextualSampler {
    1112    private class TreeNode {
     13      public string ident;
     14      public ReadonlySequence alt;
    1215      public int randomTries;
    13       public int policyTries;
     16      public int tries;
    1417      public TreeNode[] children;
    15       public readonly ReadonlySequence phrase;
    16       public readonly ReadonlySequence alt;
    17 
    18       // phrase represents the phrase of the state and alt represents how the phrase has been reached from the parent state
    19       public TreeNode(ReadonlySequence phrase, ReadonlySequence alt) {
    20         this.phrase = phrase;
     18      public bool done = false;
     19
     20      public TreeNode(string id, ReadonlySequence alt) {
     21        this.ident = id;
    2122        this.alt = alt;
    2223      }
    2324
    2425      public override string ToString() {
    25         return string.Format("Node({0} tries: {1})", phrase, randomTries + policyTries);
     26        return string.Format("Node({0} tries: {1}, done: {2})", ident, tries, done);
    2627      }
    2728    }
     
    3536    private readonly Random random;
    3637    private readonly int randomTries;
    37     private readonly IGrammarPolicy policy;
    38 
    39     private List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>> updateChain;
     38
     39    private List<Tuple<TreeNode, TreeNode>> updateChain;
    4040    private TreeNode rootNode;
    4141
    4242    public int treeDepth;
    4343    public int treeSize;
    44 
    45     // public MctsSampler(IProblem problem, int maxLen, Random random) :
    46     //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    47     //
    48     // }
    49 
    50     public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries, IGrammarPolicy policy) {
     44    private double bestQuality;
     45
     46    public MctsContextualSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    5147      this.maxLen = maxLen;
    5248      this.problem = problem;
    5349      this.random = random;
    5450      this.randomTries = randomTries;
    55       this.policy = policy;
     51      this.v = new Dictionary<string, double>(1000000);
     52      this.tries = new Dictionary<string, int>(1000000);
    5653    }
    5754
    5855    public void Run(int maxIterations) {
    59       double bestQuality = double.MinValue;
     56      bestQuality = double.MinValue;
    6057      InitPolicies(problem.Grammar);
    61       for (int i = 0; !policy.Done(rootNode.phrase) && i < maxIterations; i++) {
     58      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
    6259        var sentence = SampleSentence(problem.Grammar).ToString();
    6360        var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
     
    7976    public void PrintStats() {
    8077      var n = rootNode;
    81       Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
     78      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
    8279      while (n.children != null) {
     80        Console.WriteLine("{0}", n.ident);
     81        double maxVForRow = n.children.Select(ch => V(ch)).Max();
     82        if (maxVForRow == 0) maxVForRow = 1.0;
     83
     84        for (int i = 0; i < n.children.Length; i++) {
     85          var ch = n.children[i];
     86          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     87          Console.Write("{0,5}", ch.alt);
     88        }
    8389        Console.WriteLine();
    84         Console.WriteLine("{0,5}->{1,-50}", n.alt, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.alt))));
    85         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
     90        for (int i = 0; i < n.children.Length; i++) {
     91          var ch = n.children[i];
     92          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     93          Console.Write("{0,5:F2}", V(ch) * 10);
     94        }
     95        Console.WriteLine();
     96        for (int i = 0; i < n.children.Length; i++) {
     97          var ch = n.children[i];
     98          Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     99          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
     100        }
     101        Console.ForegroundColor = ConsoleColor.White;
     102        Console.WriteLine();
    86103        //n.policy.PrintStats();
    87         n = n.children.OrderByDescending(c => c.policyTries).First();
    88       }
    89       Console.ReadLine();
    90     }
     104        n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();
     105      }
     106    }
     107
    91108
    92109    private void InitPolicies(IGrammar grammar) {
    93       this.updateChain = new List<Tuple<ReadonlySequence, ReadonlySequence, ReadonlySequence>>();
    94 
    95       rootNode = new TreeNode(new ReadonlySequence(grammar.SentenceSymbol), new ReadonlySequence("$"));
     110      this.updateChain = new List<Tuple<TreeNode, TreeNode>>();
     111      this.v.Clear();
     112      this.tries.Clear();
     113
     114      rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    96115      treeDepth = 0;
    97116      treeSize = 0;
     
    100119    private Sequence SampleSentence(IGrammar grammar) {
    101120      updateChain.Clear();
    102       var startPhrase = new Sequence(rootNode.phrase);
     121      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
     122      var startPhrase = new Sequence(grammar.SentenceSymbol);
    103123      return CompleteSentence(grammar, startPhrase);
    104124    }
     
    109129      TreeNode parent = null;
    110130      TreeNode n = rootNode;
    111       bool done = false;
    112131      var curDepth = 0;
    113       while (!done) {
    114         if (parent != null)
    115           updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));
     132      while (!phrase.IsTerminal) {
     133        updateChain.Add(Tuple.Create(n, parent));
    116134
    117135        if (n.randomTries < randomTries) {
     
    128146
    129147          if (n.randomTries == randomTries && n.children == null) {
     148            // create a new node for each alternative
    130149            n.children = new TreeNode[alts.Count()];
    131             int cIdx = 0;
     150            var i = 0;
    132151            foreach (var alt in alts) {
    133152              var newPhrase = new Sequence(phrase);
    134               newPhrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, alt);
    135               n.children[cIdx++] = new TreeNode(new ReadonlySequence(newPhrase), new ReadonlySequence(alt));
     153              newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
     154              if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
     155              n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
    136156            }
    137157            treeSize += n.children.Length;
    138158          }
    139 
    140           n.policyTries++;
    141           // => select using bandit policy
    142           ReadonlySequence selectedAlt = policy.SelectAction(random, n.phrase, n.children.Select(c => c.alt));
     159          // => select using eps-greedy
     160          int selectedAltIdx = SelectEpsGreedy(random, n.children);
     161
     162          //int selectedAltIdx = SelectActionUCB1(random, n.children);
     163          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    143164
    144165          // replace nt with alt
     
    147168          curDepth++;
    148169
    149           done = phrase.IsTerminal;
    150 
    151170          // prepare for next iteration
    152171          parent = n;
    153           n = n.children.Single(ch => ch.alt == selectedAlt); // TODO: perf
     172          n = n.children[selectedAltIdx];
    154173        }
    155174      } // while
    156175
    157       n.policyTries++;
    158       updateChain.Add(Tuple.Create(parent.phrase, n.alt, n.phrase));
     176      updateChain.Add(Tuple.Create(n, parent));
     177
     178      // the last node is a leaf node (sentence is done), so we never need to visit this node again
     179      n.done = true;
    159180
    160181
     
    168189
    169190      foreach (var e in updateChain) {
    170         var state = e.Item1;
    171         var action = e.Item2;
    172         var newState = e.Item3;
    173         policy.UpdateReward(state, action, reward, newState);
    174         //policy.UpdateReward(action, reward / updateChain.Count);
    175       }
    176     }
     191        var node = e.Item1;
     192        var parent = e.Item2;
     193        node.tries++;
     194        if (node.children != null && node.children.All(c => c.done)) {
     195          node.done = true;
     196        }
     197        UpdateV(node, reward);
     198
     199        // the reward for the parent is either the just recieved reward or the value of the best action so far
     200        double value = 0.0;
     201        if (parent != null) {
     202          var doneChilds = parent.children.Where(ch => ch.done);
     203          if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
     204        }
     205        //if (value > reward) reward = value;
     206      }
     207    }
     208
     209    private Dictionary<string, double> v;
     210    private Dictionary<string, int> tries;
     211
     212    private void UpdateV(TreeNode n, double reward) {
     213      var canonicalStr = problem.CanonicalRepresentation(n.ident);
     214      //var canonicalStr = n.ident;
     215      double stateV;
     216
     217      if (!v.TryGetValue(canonicalStr, out  stateV)) {
     218        v.Add(canonicalStr, reward);
     219        tries.Add(canonicalStr, 1);
     220      } else {
     221        v[canonicalStr] = stateV + 0.005 * (reward - stateV);
     222        //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
     223        tries[canonicalStr]++;
     224      }
     225    }
     226
     227    private double V(TreeNode n) {
     228      var canonicalStr = problem.CanonicalRepresentation(n.ident);
     229      //var canonicalStr = n.ident;
     230      double stateV;
     231
     232      if (!v.TryGetValue(canonicalStr, out  stateV)) {
     233        return 0.0;
     234      } else {
     235        return stateV;
     236      }
     237    }
     238
     239    private int SelectEpsGreedy(Random random, TreeNode[] children) {
     240      if (random.NextDouble() < 0.2) {
     241
     242        return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
     243      } else {
     244        var bestQ = double.NegativeInfinity;
     245        var bestChildIdx = new List<int>();
     246        for (int i = 0; i < children.Length; i++) {
     247          if (children[i].done) continue;
     248          // if (children[i].tries == 0) return i;
     249          var q = V(children[i]);
     250          if (q > bestQ) {
     251            bestQ = q;
     252            bestChildIdx.Clear();
     253            bestChildIdx.Add(i);
     254          } else if (q == bestQ) {
     255            bestChildIdx.Add(i);
     256          }
     257        }
     258        Debug.Assert(bestChildIdx.Any());
     259        return bestChildIdx.SelectRandom(random);
     260      }
     261    }
     262    private int SelectActionUCB1(Random random, TreeNode[] children) {
     263      int bestAction = -1;
     264      double bestQ = double.NegativeInfinity;
     265      int totalTries = children.Sum(ch => ch.tries);
     266
     267      for (int a = 0; a < children.Length; a++) {
     268        var ch = children[a];
     269        if (ch.done) continue;
     270        if (ch.tries == 0) return a;
     271        var q = V(ch) + Math.Sqrt((2 * Math.Log(totalTries)) / ch.tries);
     272        if (q > bestQ) {
     273          bestQ = q;
     274          bestAction = a;
     275        }
     276      }
     277      Debug.Assert(bestAction > -1);
     278      return bestAction;
     279    }
     280
     281
    177282
    178283    private void RaiseSolutionEvaluated(string sentence, double quality) {
     
    184289      if (handler != null) handler(sentence, quality);
    185290    }
     291
     292
    186293  }
    187294}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11744 r11745  
    4141    public int treeSize;
    4242    private double bestQuality;
    43 
    44     // public MctsSampler(IProblem problem, int maxLen, Random random) :
    45     //   this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    46     //
    47     // }
    4843
    4944    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Common/HeuristicLab.Common.csproj

    r11727 r11745  
    3333    <Reference Include="System" />
    3434    <Reference Include="System.Core" />
     35    <Reference Include="System.Drawing" />
    3536    <Reference Include="System.Xml.Linq" />
    3637    <Reference Include="System.Data.DataSetExtensions" />
     
    4041  </ItemGroup>
    4142  <ItemGroup>
     43    <Compile Include="ConsoleEx.cs" />
    4244    <Compile Include="Extensions.cs" />
    4345    <Compile Include="Properties\AssemblyInfo.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs

    r11742 r11745  
    146146      var randSeed = 31415;
    147147      var nArms = 20;
    148       Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
    149       Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
    150       Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
    151       Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
    152       Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
    153       Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy());
    154       Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1)));
    155       Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));
     148
     149      Console.WriteLine("Generic Thompson (Gaussian Mixture)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianMixtureModel()));
     150      // Console.WriteLine("Threshold Ascent (20)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(20, 0.01));
     151      // Console.WriteLine("Threshold Ascent (100)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(100, 0.01));
     152      // Console.WriteLine("Threshold Ascent (500)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(500, 0.01));
     153      // Console.WriteLine("Threshold Ascent (1000)"); TestPolicyGaussianMixture(randSeed, nArms, new ThresholdAscentPolicy(1000, 0.01));
     154      // Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy(true));
     155      // Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(randSeed, nArms, new GaussianThompsonSamplingPolicy());
     156      // Console.WriteLine("Generic Thompson (Gaussian fixed variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 0.1)));
     157      // Console.WriteLine("Generic Thompson (Gaussian unknown variance)"); TestPolicyGaussianMixture(randSeed, nArms, new GenericThompsonSamplingPolicy(new GaussianModel(0.5, 1, 1, 1)));
    156158
    157159      /*
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SymbolicRegressionPoly10Problem.cs

    r11742 r11745  
    7272
    7373    // right now only + and * is supported
    74     public string CanonicalRepresentation(string terminalPhrase) {
    75       var terms = terminalPhrase.Split('+');
    76       return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
    77         .OrderBy(term => term));
     74    public string CanonicalRepresentation(string phrase) {
     75      var terms = phrase.Split('+').Select(t => t.Replace("*", ""));
     76      var terminalTerms = terms.Where(t => t.All(ch => grammar.IsTerminal(ch)));
     77      var nonTerminalTerms = terms.Where(t => t.Any(ch => grammar.IsNonTerminal(ch)));
     78
     79      return string.Join("+", terminalTerms.Select(term => CanonicalTerm(term)).OrderBy(term => term).Concat(nonTerminalTerms.Select(term => CanonicalTerm(term))));
     80    }
     81
     82    private string CanonicalTerm(string term) {
     83      return string.Join("", term.OrderByDescending(ch => (byte)ch));
    7884    }
    7985  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs

    r11744 r11745  
    169169      // good results e.g. with       var alg = new MctsSampler(problem, 17, random, 1, (rand, numActions) => new ThresholdAscentPolicy(numActions, 500, 0.01));
    170170      // GaussianModelWithUnknownVariance (and Q= 0.99-quantil) also works well for Ant
    171       //var problem = new SantaFeAntProblem(); 
     171      //var problem = new SantaFeAntProblem();
    172172      //var problem = new SymbolicRegressionProblem("Tower");
    173173      //var problem = new PalindromeProblem();
     
    175175      //var problem = new RoyalPairProblem();
    176176      //var problem = new EvenParityProblem();
    177       var alg = new MctsSampler(problem, 25, random, 0, new GenericThompsonSamplingPolicy(new LogitNormalModel()));
    178       //var alg = new TemporalDifferenceTreeSearchSampler(problem, 23, random, 0, new RandomPolicy());
     177      //var alg = new MctsSampler(problem, 23, random, 0, new GaussianThompsonSamplingPolicy(true));
     178      var alg = new MctsContextualSampler(problem, 23, random, 0);
     179      //var alg = new TemporalDifferenceTreeSearchSampler(problem, 17, random, 10, new EpsGreedyPolicy(0.1));
    179180      //var alg = new ExhaustiveBreadthFirstSearch(problem, 17);
    180181      //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions));
     
    187188        bestQuality = quality;
    188189        bestSentence = sentence;
    189         Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
     190        //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
     191        //Console.ReadLine();
    190192      };
    191193      alg.SolutionEvaluated += (sentence, quality) => {
     
    193195        globalStatistics.AddSentence(sentence, quality);
    194196        if (iterations % 100 == 0) {
    195           Console.Clear();
     197          //if (iterations % 1000 == 0) Console.Clear();
     198          Console.SetCursorPosition(0, 0);
    196199          alg.PrintStats();
    197200        }
     201
    198202        if (iterations % 10000 == 0) {
    199203          //Console.WriteLine("{0,10} {1,10:F5} {2,10:F5} {3}", iterations, bestQuality, quality, sentence);
    200204          //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
    201           Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
     205          //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
    202206        }
    203207      };
Note: See TracChangeset for help on using the changeset viewer.