Changeset 13657


Ignore:
Timestamp:
03/07/16 12:50:15 (4 years ago)
Author:
gkronber
Message:

#2581: update quality estimate in parent nodes when a branch is completely explored. added ucbtuned selection

Location:
trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/MctsSymbolicRegressionStatic.cs

    r13652 r13657  
    262262        alglib.minlmstate state;
    263263        alglib.minlmreport rep = null;
    264         alglib.minlmcreatevj(y.Length, optConsts, out state);       
     264        alglib.minlmcreatevj(y.Length, optConsts, out state);
    265265        alglib.minlmsetcond(state, 0.0, epsF, 0.0, nIters);
    266266        //alglib.minlmsetgradientcheck(state, 0.000001);
     
    333333      double c = mctsState.c;
    334334      double q = 0;
     335      double deltaQ = 0;
     336      double deltaSqrQ = 0;
     337      int deltaVisits = 0;
    335338      bool success = false;
    336339      do {
    337340        automaton.Reset();
    338         success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q);
     341        success = TryTreeSearchRec(rand, tree, c, automaton, eval, bestChildrenBuf, out q, out deltaQ, out deltaSqrQ, out deltaVisits);
    339342        mctsState.totalRollouts++;
    340343      } while (!success && !tree.done);
     
    347350    // see ConstraintHandler.cs for more info
    348351    private static bool TryTreeSearchRec(IRandom rand, Tree tree, double c, Automaton automaton, Func<byte[], int, double> eval, List<Tree> bestChildrenBuf,
    349       out double q) {
     352      out double q, // quality of the expression
     353      out double deltaQ, out double deltaSqrQ, out int deltaVisits // the updates for total quality and number of visits (can be negative if branches have been fully explored)
     354      ) {
    350355      Tree selectedChild = null;
    351356      Contract.Assert(tree.state == automaton.CurrentState);
     
    360365          automaton.GetCode(out code, out nParams);
    361366          q = eval(code, nParams);
    362           tree.visits++;
     367          tree.visits += 1;
    363368          tree.sumQuality += q;
     369          tree.sumSqrQuality += q * q;
     370          deltaQ = q;
     371          deltaVisits = 1;
     372          deltaSqrQ = q * q;
    364373          return true; // we reached a final state
    365374        } else {
     
    371380            // stuck in a dead end (no final state and no allowed follow states)
    372381            q = 0;
     382            deltaQ = 0;
     383            deltaSqrQ = 0.0;
     384            deltaVisits = 0;
    373385            tree.done = true;
    374386            tree.children = null;
     
    380392            tree.children[i] = new Tree() { children = null, done = false, state = possibleFollowStates[i], visits = 0 };
    381393
    382           selectedChild = SelectFinalOrRandom(automaton, tree, rand);
     394          selectedChild = nFs > 1 ? SelectFinalOrRandom(automaton, tree, rand) : tree.children[0];
    383395        }
    384396      } else {
    385397        // tree.children != null
    386398        // UCT selection within tree
    387         selectedChild = SelectUct(tree, rand, c, bestChildrenBuf);
     399        selectedChild = tree.children.Length > 1 ? SelectUctTuned(tree, rand, c, bestChildrenBuf) : tree.children[0];
    388400      }
    389401      // make selected step and recurse
    390402      automaton.Goto(selectedChild.state);
    391       var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf, out q);
     403      var success = TryTreeSearchRec(rand, selectedChild, c, automaton, eval, bestChildrenBuf,
     404        out q, out deltaQ, out deltaSqrQ, out deltaVisits);
    392405      if (success) {
    393406        // only update if successful
    394         tree.sumQuality += q;
    395         tree.visits++;
    396       }
    397 
    398       // tree.done = tree.children.All(ch => ch.done);
    399       tree.done = true; for (int i = 0; i < tree.children.Length && tree.done; i++) tree.done = tree.children[i].done;
    400       if (tree.done) {
     407        tree.sumQuality += deltaQ;
     408        tree.sumSqrQuality += deltaSqrQ;
     409        tree.visits += deltaVisits;
     410      }
     411
     412      if (tree.children.All(ch => ch.done)) {
     413        tree.done = true;
     414        // update parent nodes to remove information from this branch
     415        if (tree.children.Length > 1) {
     416          deltaQ = -(tree.sumQuality - deltaQ);
     417          deltaSqrQ = -(tree.sumSqrQuality - deltaSqrQ);
     418          deltaVisits = -(tree.visits - deltaVisits);
     419        }
    401420        tree.children = null; // cut off the sub-branch if it has been fully explored
    402         // TODO: update all qualities and visits to remove the information gained from this whole branch
    403421      }
    404422      return success;
     
    437455    }
    438456
     457    private static Tree SelectUctTuned(Tree tree, IRandom rand, double c, List<Tree> bestChildrenBuf) {
     458      // determine total tries of still active children
     459      int totalTries = 0;
     460      bestChildrenBuf.Clear();
     461      for (int i = 0; i < tree.children.Length; i++) {
     462        var ch = tree.children[i];
     463        if (ch.done) continue;
     464        if (ch.visits == 0) bestChildrenBuf.Add(ch);
     465        else totalTries += tree.children[i].visits;
     466      }
     467      // if there are unvisited children select a random child
     468      if (bestChildrenBuf.Any()) {
     469        return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
     470      }
     471      Contract.Assert(totalTries > 0); // the tree is not done yet so there is at least on child that is not done
     472      double logTotalTries = Math.Log(totalTries);
     473      var bestQ = double.NegativeInfinity;
     474      for (int i = 0; i < tree.children.Length; i++) {
     475        var ch = tree.children[i];
     476        if (ch.done) continue;
     477        var varianceBound = ch.QualityVariance + Math.Sqrt(2.0 * logTotalTries / ch.visits);
     478        if (varianceBound > 0.25) varianceBound = 0.25;
     479        var childQ = ch.AverageQuality + c * Math.Sqrt(logTotalTries / ch.visits * varianceBound);
     480        if (childQ > bestQ) {
     481          bestChildrenBuf.Clear();
     482          bestChildrenBuf.Add(ch);
     483          bestQ = childQ;
     484        } else if (childQ >= bestQ) {
     485          bestChildrenBuf.Add(ch);
     486        }
     487      }
     488      return bestChildrenBuf[rand.Next(bestChildrenBuf.Count)];
     489    }
     490
    439491    private static Tree SelectFinalOrRandom(Automaton automaton, Tree tree, IRandom rand) {
    440492      // if one of the new children leads to a final state then go there
  • trunk/sources/HeuristicLab.Algorithms.DataAnalysis/3.4/MctsSymbolicRegression/Tree.cs

    r13650 r13657  
    2626    public int visits;
    2727    public double sumQuality;
     28    public double sumSqrQuality; // for variance
    2829    public double AverageQuality { get { return sumQuality / (double)visits; } }
     30    public double QualityVariance { get { return sumSqrQuality / (double)visits - AverageQuality * AverageQuality; } }
    2931    public bool done;
    3032    public Tree[] children;
Note: See TracChangeset for help on using the changeset viewer.