Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/28/14 19:23:46 (10 years ago)
Author:
gkronber
Message:

#2026 implemented prevention of resampling of known nodes.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GPDL/CodeGenerator/MonteCarloTreeSearchCodeGen.cs

    r10411 r10415  
    1515    public double sumQuality = 0.0;
    1616    public double bestQuality = double.NegativeInfinity;
    17     public bool ready;
     17    public bool done;
    1818    public SearchTreeNode[] children;
    19 
     19    public double[] Ucb {
     20      get {
     21        return (from c in children
     22                select ?IDENT?Solver.UCB(this, c)
     23               ).ToArray();
     24      }
     25    }
    2026    public SearchTreeNode() {
    2127    }
     
    3743
    3844    private void SampleTree(SearchTreeNode searchTree, Stack<Tuple<Tree, int, int>> extensionPoints) {
    39       const int RANDOM_TRIES = 100;
    40       if(extensionPoints.Count == 0) return; // nothing to do
     45      const int RANDOM_TRIES = 1000;
     46      if(extensionPoints.Count == 0) {
     47        searchTree.done = true;
     48        return; // nothing to do
     49      }
    4150      var extensionPoint = extensionPoints.Pop();
    4251      Tree parent = extensionPoint.Item1;
     
    4453      int subtreeIdx = extensionPoint.Item3;
    4554      Tree t = null;
    46      
    4755      if(searchTree.tries < RANDOM_TRIES || Grammar.subtreeCount[state] == 0) {
    4856        t = SampleTreeRandom(state);
     
    5361            searchTree.children = new SearchTreeNode[] { new SearchTreeNode() } ;
    5462          SampleTree(searchTree.children[0], extensionPoints);
     63          if(searchTree.children[0].done) searchTree.done = true;
    5564        } else {
    5665          // fill up all remaining slots randomly
     
    8190            extensionPoints.Push(Tuple.Create(t, Grammar.transition[state][i], i));
    8291          }
    83           SampleTree(searchTree, extensionPoints);
    84         }
    85       }
     92          SampleTree(searchTree, extensionPoints);         
     93        }
     94      }     
    8695      Debug.Assert(parent.subtrees[subtreeIdx] == null);
    8796      parent.subtrees[subtreeIdx] = t;
     
    95104        return altIdx;
    96105      } else {
     106        altIdx = Array.FindIndex(searchTree.children, (e) => !e.done && e.tries < 1000);
     107        if(altIdx >= 0) return altIdx;
    97108        // select the least sampled alternative
     109        //altIdx = 0;
     110        //int minSamples = searchTree.children[altIdx].tries;
     111        //for(int idx = 1; idx < searchTree.children.Length; idx++) {
     112        //  if(!searchTree.children[idx].done && searchTree.children[idx].tries < minSamples) {
     113        //    minSamples = searchTree.children[idx].tries;
     114        //    altIdx = idx;
     115        //  }
     116        //}
     117        // select the alternative with the largest average
    98118        altIdx = 0;
    99         int minSamples = searchTree.children[altIdx].tries;
     119        double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
    100120        for(int idx = 1; idx < searchTree.children.Length; idx++) {
    101           if(searchTree.children[idx].tries < minSamples) {
    102             minSamples = searchTree.children[idx].tries;
     121          if (!searchTree.children[idx].done && UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
    103122            altIdx = idx;
    104123          }
    105124        }
    106         // select the alternative with the largest average
    107         // altIdx = 0;
    108         // double bestAverage = UCB(searchTree, searchTree.children[altIdx]);
    109         // for(int idx = 1; idx < searchTree.children.Length; idx++) {
    110         //   if (UCB(searchTree, searchTree.children[idx]) > UCB(searchTree, searchTree.children[altIdx])) {
    111         //     altIdx = idx;
    112         //   }
    113         // }
     125
     126        searchTree.done = searchTree.children.All(c=>c.done);
    114127        return altIdx;
    115128      }
    116129    }
    117130
    118     private double UCB(SearchTreeNode parent, SearchTreeNode n) {
     131    public static double UCB(SearchTreeNode parent, SearchTreeNode n) {
    119132      Debug.Assert(parent.tries >= n.tries);
    120133      Debug.Assert(n.tries > 0);
    121       return n.sumQuality / n.tries + Math.Sqrt((2 * Math.Log(parent.tries)) / n.tries ); // constant depends on the fitness function values
     134      return n.sumQuality / n.tries + Math.Sqrt((10 * Math.Log(parent.tries)) / n.tries ); // constant depends on the fitness function values
    122135    }
    123136
     
    144157        if(t.subtrees != null) {
    145158          Debug.Assert(t.subtrees.Length == 1);
    146           if(searchTree.children!=null) {
     159          if(searchTree.children != null) {
    147160            trees.Push(t.subtrees[0]);
    148161            UpdateSearchTree(searchTree.children[t.altIdx], trees, quality);
    149162          }
    150163        } else {
    151           if(searchTree.children!=null) {
     164          if(searchTree.children != null) {
    152165            Debug.Assert(searchTree.children.Length == 1);
    153166            UpdateSearchTree(searchTree.children[0], trees, quality);
     
    225238      var sw = new System.Diagnostics.Stopwatch();
    226239      sw.Start();
    227       while (true) {
     240      while (!searchTree.done) {
    228241
    229242        int steps, depth;
     
    297310            GenerateReturnStatement(terminalAltIndexes, sb);
    298311            sb.Append("} else {");
    299             GenerateReturnStatement(nonTerminalAltIndexes, sb);
     312            GenerateReturnStatement(nonTerminalAltIndexes.Concat(terminalAltIndexes), sb);
    300313            sb.Append("}").EndBlock();
    301314          } else {
     
    317330          foreach (var constr in terminal.Constraints) {
    318331            if (constr.Type == ConstraintNodeType.Set) {
    319               sb.Append("{").BeginBlock();
    320               sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
    321               sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
    322               sb.AppendLine("}");
     332              throw new NotImplementedException("Support for terminal symbols with attributes is not yet implemented.");
     333              // sb.Append("{").BeginBlock();
     334              // sb.AppendFormat("var elements = problem.GetAllowed{0}_{1}().ToArray();", terminal.Ident, constr.Ident).AppendLine();
     335              // sb.AppendFormat("t.{0} = elements[random.Next(elements.Length)]; ", constr.Ident).EndBlock();
     336              // sb.AppendLine("}");
    323337            } else {
    324338              throw new NotSupportedException("The MTCS solver does not support RANGE constraints.");
Note: See TracChangeset for help on using the changeset viewer.