Changeset 10427


Ignore:
Timestamp:
01/29/14 19:30:53 (7 years ago)
Author:
gkronber
Message:

#2026 integrated max depth into MCTS solver

Location:
branches/HeuristicLab.Problems.GPDL/CodeGenerator
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs

    r10426 r10427  
    3434    private readonly ?IDENT?Problem problem;
    3535    private readonly Random random;
     36    private readonly ?IDENT?RandomSearchSolver randomSearch;
     37
    3638    private SearchTreeNode searchTree = new SearchTreeNode();
    3739   
     
    4446    }
    4547
     48    private  const int RANDOM_TRIES = 1;
     49
    4650    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int, int>> extensionPoints) {
    47       const int RANDOM_TRIES = 1000;
    48       if(extensionPoints.Count == 0) {
    49         searchTree.done = true;
    50         return; // nothing to do
    51       }
    5251      var extensionPoint = extensionPoints.Pop();
    5352      Tree parent = extensionPoint.Item1;
     
    5655      int maxDepth = extensionPoint.Item4;
    5756      Debug.Assert(maxDepth >= 1);
     57      Debug.Assert(Grammar.minDepth[state] <= maxDepth);
    5858      Tree t = null;
    5959      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
    60         t = SampleTreeRandom(state);
     60        int steps = 0; int curDepth = this.maxDepth - maxDepth; int depth = this.maxDepth - maxDepth;
     61        t = randomSearch.SampleTree(state, maxDepth, ref steps, ref curDepth, ref depth);
    6162        if(Grammar.subtreeCount[state] == 0) {
    6263          // when we produced a terminal continue filling up all other empty points
    6364          Debug.Assert(searchTree.children == null || searchTree.children.Length == 1);
    64           if(searchTree.children == null)
    65             searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
    66           SampleTree(searchTree.children[0], extensionPoints);
    67           if(searchTree.children[0].done) searchTree.done = true;
     65          if(extensionPoints.Count == 0) {
     66            searchTree.done = true;
     67          } else {
     68            if(searchTree.children == null)
     69              searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
     70            SampleTree(searchTree.children[0], extensionPoints);
     71            if(searchTree.children[0].done) searchTree.done = true;
     72          }
    6873        } else {
    6974          // fill up all remaining slots randomly
     
    7277            var pState = p.Item2;
    7378            var pIdx = p.Item3;
    74             pParent.subtrees[pIdx] = SampleTreeRandom(pState);
    75           }
    76         }
    77       } else {
    78         if(Grammar.subtreeCount[state] == 1) {
    79           if(searchTree.children == null) {
    80             int nChildren = Grammar.transition[state].Length;
    81             searchTree.children = new SearchTreeNode[nChildren];
    82           }
    83           Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
    84           Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
    85           var altIdx = SelectAlternative(searchTree);
    86           t = new Tree(altIdx, new Tree[1]);
    87           extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
    88           SampleTree(searchTree.children[altIdx], extensionPoints);
    89         } else {
    90           // multiple subtrees
    91           var subtrees = new Tree[Grammar.subtreeCount[state]];
    92           t = new Tree(-1, subtrees);
    93           for(int i = subtrees.Length - 1; i >= 0; i--) {
    94             extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
    95           }
    96           SampleTree(searchTree, extensionPoints);         
    97         }
     79            var pMaxDepth = p.Item4;
     80            curDepth = this.maxDepth - pMaxDepth;
     81            depth = curDepth;
     82            pParent.subtrees[pIdx] = randomSearch.SampleTree(pState, pMaxDepth, ref steps, ref curDepth, ref depth);
     83          }
     84        }
     85      } else if(Grammar.subtreeCount[state] == 1) {
     86         if(searchTree.children == null) {
     87           int nChildren = Grammar.transition[state].Length;
     88           searchTree.children = new SearchTreeNode[nChildren];
     89         }
     90         Debug.Assert(searchTree.children.Length == Grammar.transition[state].Length);
     91         Debug.Assert(searchTree.tries - RANDOM_TRIES == searchTree.children.Where(c=>c!=null).Sum(c=>c.tries));
     92         var altIdx = SelectAlternative(searchTree, state, maxDepth);
     93         t = new Tree(altIdx, new Tree[1]);
     94         extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][altIdx], 0, maxDepth - 1));
     95         SampleTree(searchTree.children[altIdx], extensionPoints);
     96         searchTree.done = (from idx in Enumerable.Range(0, searchTree.children.Length)
     97                            where Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth - 1
     98                            select searchTree.children[idx]).All(c=>c != null && c.done);
     99      } else {
     100        // multiple subtrees
     101        var subtrees = new Tree[Grammar.subtreeCount[state]];
     102        t = new Tree(-1, subtrees);
     103        for(int i = subtrees.Length - 1; i >= 0; i--) {
     104          extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i, maxDepth - 1));
     105        }
     106        SampleTree(searchTree, extensionPoints);   
    98107      }     
    99108      Debug.Assert(parent.subtrees[subtreeIdx] == null);
     
    101110    }
    102111
    103     private int SelectAlternative(SearchTreeNode searchTree) {
     112    private int SelectAlternative(SearchTreeNode searchTree, int state, int maxDepth) {
    104113      // any alternative not yet explored?
    105       var altIdx = Array.FindIndex(searchTree.children, (e) => e == null);
     114      var altIndexes = searchTree.children
     115                     .Select((e,i) => new {Elem = e, Idx = i})
     116                     .Where(p => p.Elem == null && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
     117                     .Select(p => p.Idx);
     118      int altIdx = altIndexes.Any()?altIndexes.First() : -1;
    106119      if(altIdx >= 0) {
    107120        searchTree.children[altIdx] = new SearchTreeNode();
    108121        return altIdx;
    109122      } else {
    110         altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000);
     123        altIndexes = searchTree.children
     124                       .Select((e,i) => new {Elem = e, Idx = i})
     125                       .Where(p => p.Elem != null && !p.Elem.done && p.Elem.tries < RANDOM_TRIES && Grammar.minDepth[Grammar.transition[state][p.Idx]] <= maxDepth)
     126                       .Select(p => p.Idx);
     127        altIdx = altIndexes.Any()?altIndexes.First() : -1;
    111128        if(altIdx >= 0) return altIdx;
    112129        // select the least sampled alternative
    113         //altIdx = 0;
    114         //int minSamples = searchTree.children[altIdx].tries;
    115         //for(int idx = 1; idx < searchTree.children.Length; idx++) {
    116         //  if(!searchTree.children[idx].done && searchTree.children[idx].tries < minSamples) {
     130        //altIdx = -1;
     131        //int minSamples = int.MaxValue;
     132        //for(int idx = 0; idx < searchTree.children.Length; idx++) {
     133        //  if(searchTree.children[idx] == null) continue;
     134        //  if(!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth && searchTree.children[idx].tries < minSamples) {
    117135        //    minSamples = searchTree.children[idx].tries;
    118136        //    altIdx = idx;
     
    120138        //}
    121139        // select the alternative with the largest average
    122         altIdx = 0;
    123         double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
    124         for(int idx = 1; idx < searchTree.children.Length; idx++) {
    125           if (!searchTree.children[idx].done && UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
     140        altIdx = -1;
     141        double best = double.NegativeInfinity;
     142        for(int idx = 0; idx < searchTree.children.Length; idx++) {
     143          if(searchTree.children[idx] == null) continue;
     144          if (!searchTree.children[idx].done && Grammar.minDepth[Grammar.transition[state][idx]] <= maxDepth  && UCB(searchTree, searchTree.children[idx]) > best) {
    126145            altIdx = idx;
    127           }
    128         }
    129 
    130         searchTree.done = searchTree.children.All(c=>c.done);
     146            best = UCB(searchTree, searchTree.children[idx]);
     147          }
     148        }
     149        Debug.Assert(altIdx > -1);
    131150        return altIdx;
    132151      }
     
    136155      Debug.Assert(parent.tries >= n.tries);
    137156      Debug.Assert(n.tries > 0);
    138       return n.sumQuality / n.tries + Math.Sqrt((10 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
     157      return n.sumQuality / n.tries + Math.Sqrt((40 * Math.Log(parent.tries)) / n.tries ); // constant is dependent fitness function values
    139158    }
    140159
     
    174193    }
    175194
    176     // same as in random search solver (could reuse random search)
    177     private Tree SampleTreeRandom(int state) {
    178       return SampleTreeRandom(state, 5);
    179     }
    180     private Tree SampleTreeRandom(int state, int maxDepth) {
    181       Tree t = null;
    182 
    183       // terminals
    184       if(Grammar.subtreeCount[state] == 0) {
    185         t = CreateTerminalNode(state, random, problem);
    186       } else {
    187         // if the symbol has alternatives then we must choose one randomly (only one sub-tree in this case)
    188         if(Grammar.subtreeCount[state] == 1) {
    189           var targetStates = Grammar.transition[state];
    190           var altIdx = SampleAlternative(random, state, maxDepth);
    191           var alternative = SampleTreeRandom(targetStates[altIdx], maxDepth - 1);
    192           t = new Tree(altIdx, new Tree[] { alternative });
    193         } else {
    194           // if the symbol contains only one sequence we must use create sub-trees for each symbol in the sequence
    195           Tree[] subtrees = new Tree[Grammar.subtreeCount[state]];
    196           for(int i = 0; i < Grammar.subtreeCount[state]; i++) {
    197             subtrees[i] = SampleTreeRandom(Grammar.transition[state][i], maxDepth - 1);
    198           }
    199           t = new Tree(-1, subtrees); // alternative index is ignored
    200         }
    201       }
    202       return t;
    203     }
    204 
    205     private static Tree CreateTerminalNode(int state, Random random, ?IDENT?Problem problem) {
    206       switch(state) {
    207         ?CREATETERMINALNODECODE?
    208         default: { throw new ArgumentException(""Unknown state index"" + state); }
    209       }
    210     }
    211 
    212     private int SampleAlternative(Random random, int state, int maxDepth) {
    213       switch(state) {
    214 
    215 ?SAMPLEALTERNATIVECODE?
    216 
    217         default: throw new InvalidOperationException();
    218       }
    219     }
    220 
    221195    public ?IDENT?MonteCarloTreeSearchSolver(?IDENT?Problem problem, string[] args) {
    222       if(args.Length > 0 ) throw new ArgumentException(""Arguments ""+args.Aggregate("""", (str, s) => str+ "" "" + s)+"" are not supported "");
     196      this.randomSearch = new ?IDENT?RandomSearchSolver(problem, args);
     197      if(args.Length > 0 ) ParseArguments(args);
    223198      this.problem = problem;
    224199      this.random = new Random();
    225200    }
     201    private void ParseArguments(string[] args) {
     202      var maxDepthRegex = new Regex(@""--maxDepth=(?<d>.+)"");
     203
     204      var helpRegex = new Regex(@""--help|/\?"");
     205     
     206      foreach(var arg in args) {
     207        var maxDepthMatch = maxDepthRegex.Match(arg);
     208        var helpMatch = helpRegex.Match(arg);
     209        if(helpMatch.Success) {
     210          PrintUsage(); Environment.Exit(0);
     211        } else if(maxDepthMatch.Success) {
     212           maxDepth = int.Parse(maxDepthMatch.Groups[""d""].Captures[0].Value, System.Globalization.CultureInfo.InvariantCulture);
     213           if(maxDepth < 1 || maxDepth > 100) throw new ArgumentException(""max depth must lie in range [1 ... 100]"");
     214        } else {
     215           Console.WriteLine(""Unknown switch {0}"", arg); PrintUsage(); Environment.Exit(0);
     216        }
     217      }
     218    }
     219    private void PrintUsage() {
     220      Console.WriteLine(""Find a solution using Monte-Carlo tree search."");
     221      Console.WriteLine();
     222      Console.WriteLine(""Parameters:"");
     223      Console.WriteLine(""\t--maxDepth=<depth>\tSets the maximal depth of sampled trees [Default: 20]"");
     224    }
     225
     226
    226227
    227228    public void Start() {
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/ProblemCodeGen.cs

    r10426 r10427  
    2727    public static void Main(string[] args) {
    2828      var problem = new ?IDENT?Problem();
    29       var solver = new ?IDENT?RandomSearchSolver(problem, args);
     29      var solver = new ?IDENT?MonteCarloTreeSearchSolver(problem, args);
    3030      solver.Start();
    3131    }
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/RandomSearchCodeGen.cs

    r10426 r10427  
    1919    private readonly Random random;
    2020
    21     private Tree SampleTree(int maxDepth, out int steps, out int depth) {
     21    public Tree SampleTree(int maxDepth, out int steps, out int depth) {
    2222      steps = 0;
    2323      depth = 0;
     
    2626    }
    2727
    28     private Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {
     28    public Tree SampleTree(int state, int maxDepth, ref int steps, ref int curDepth, ref int depth) {
    2929      curDepth += 1;
    3030      Debug.Assert(maxDepth > 0);
Note: See TracChangeset for help on using the changeset viewer.