Ignore:
Timestamp:
01/07/15 09:21:46 (5 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCTPolicy.cs

    r11730 r11732  
    88namespace HeuristicLab.Algorithms.Bandits {
    99  /* Kocsis et al. Bandit based Monte-Carlo Planning */
    10   public class UCTPolicy : BanditPolicy {
    11     private readonly int[] tries;
    12     private readonly double[] sumReward;
    13     private int totalTries = 0;
     10  public class UCTPolicy : IPolicy {
    1411    private readonly double c;
    1512
    16     public UCTPolicy(int numActions, double c = 1.0)
    17       : base(numActions) {
    18       this.tries = new int[numActions];
    19       this.sumReward = new double[numActions];
     13    public UCTPolicy(double c = 1.0) {
    2014      this.c = c;
    2115    }
    2216
    23     public override int SelectAction() {
     17
     18    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     19      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>().ToArray(); // TODO: performance
    2420      int bestAction = -1;
    2521      double bestQ = double.NegativeInfinity;
    26       foreach (var a in Actions) {
    27         if (tries[a] == 0) return a;
    28         var q = sumReward[a] / tries[a] + 2 * c * Math.Sqrt(Math.Log(totalTries) / tries[a]);
     22      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     23
     24      for (int a = 0; a < myActionInfos.Length; a++) {
     25        if (myActionInfos[a].Disabled) continue;
     26        if (myActionInfos[a].Tries == 0) return a;
     27        var q = myActionInfos[a].SumReward / myActionInfos[a].Tries + 2 * c * Math.Sqrt(Math.Log(totalTries) / myActionInfos[a].Tries);
    2928        if (q > bestQ) {
    3029          bestQ = q;
     
    3231        }
    3332      }
     33      Debug.Assert(bestAction > -1);
    3434      return bestAction;
    3535    }
    36     public override void UpdateReward(int action, double reward) {
    37       Debug.Assert(Actions.Contains(action));
    38       totalTries++;
    39       tries[action]++;
    40       sumReward[action] += reward;
     36
     37    public IPolicyActionInfo CreateActionInfo() {
     38      return new DefaultPolicyActionInfo();
    4139    }
    4240
    43     public override void DisableAction(int action) {
    44       base.DisableAction(action);
    45       totalTries -= tries[action];
    46       tries[action] = -1;
    47       sumReward[action] = 0;
    48     }
    49 
    50     public override void Reset() {
    51       base.Reset();
    52       totalTries = 0;
    53       Array.Clear(tries, 0, tries.Length);
    54       Array.Clear(sumReward, 0, sumReward.Length);
    55     }
    56     public override void PrintStats() {
    57       for (int i = 0; i < sumReward.Length; i++) {
    58         if (tries[i] >= 0) {
    59           Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
    60         } else {
    61           Console.Write("{0,5}", "");
    62         }
    63       }
    64       Console.WriteLine();
    65     }
    6641    public override string ToString() {
    6742      return string.Format("UCTPolicy({0:F2})", c);
Note: See TracChangeset for help on using the changeset viewer.