Ignore:
Timestamp:
01/07/15 09:21:46 (5 years ago)
Author:
gkronber
Message:

#2283: refactoring and bug fixes

File:
1 edited

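Legend:

Unmodified (row is prefixed with both the old and the new line number, fused together — e.g. "1110" means old line 11, new line 10)
Added (row is prefixed with only the new, r11732, line number)
Removed (row is prefixed with only the old, r11730, line number)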
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs

    r11730 r11732  
    77
    88namespace HeuristicLab.Algorithms.Bandits {
    9   public class EpsGreedyPolicy : BanditPolicy {
    10     private readonly Random random;
     9  public class EpsGreedyPolicy : IPolicy {
    1110    private readonly double eps;
    12     private readonly int[] tries;
    13     private readonly double[] sumReward;
    1411    private readonly RandomPolicy randomPolicy;
    1512
    16     public EpsGreedyPolicy(Random random, int numActions, double eps)
    17       : base(numActions) {
    18       this.random = random;
     13    public EpsGreedyPolicy(double eps) {
    1914      this.eps = eps;
    20       this.randomPolicy = new RandomPolicy(random, numActions);
    21       this.tries = new int[numActions];
    22       this.sumReward = new double[numActions];
     15      this.randomPolicy = new RandomPolicy();
    2316    }
    24 
    25     public override int SelectAction() {
    26       Debug.Assert(Actions.Any());
     17    public int SelectAction(Random random, IEnumerable<IPolicyActionInfo> actionInfos) {
     18      Debug.Assert(actionInfos.Any());
    2719      if (random.NextDouble() > eps) {
    2820        // select best
    29         var bestQ = double.NegativeInfinity;
     21        var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    3022        int bestAction = -1;
    31         foreach (var a in Actions) {
    32           if (tries[a] == 0) return a;
    33           var q = sumReward[a] / tries[a];
    34           if (bestQ < q) {
     23        double bestQ = double.NegativeInfinity;
     24        int aIdx = -1;
     25        foreach (var aInfo in myActionInfos) {
     26
     27          aIdx++;
     28          if (aInfo.Disabled) continue;
     29          if (aInfo.Tries == 0) return aIdx;
     30
     31
     32          var avgReward = aInfo.SumReward / aInfo.Tries;         
     33          //var q = avgReward;
     34          var q = aInfo.MaxReward;
     35          if (q > bestQ) {
    3536            bestQ = q;
    36             bestAction = a;
     37            bestAction = aIdx;
    3738          }
    3839        }
     
    4142      } else {
    4243        // select random
    43         return randomPolicy.SelectAction();
     44        return randomPolicy.SelectAction(random, actionInfos);
    4445      }
    4546    }
    46     public override void UpdateReward(int action, double reward) {
    47       Debug.Assert(Actions.Contains(action));
    4847
    49       randomPolicy.UpdateReward(action, reward); // does nothing
    50       tries[action]++;
    51       sumReward[action] += reward;
     48    public IPolicyActionInfo CreateActionInfo() {
     49      return new DefaultPolicyActionInfo();
    5250    }
    5351
    54     public override void DisableAction(int action) {
    55       base.DisableAction(action);
    56       randomPolicy.DisableAction(action);
    57       sumReward[action] = 0;
    58       tries[action] = -1;
    59     }
    6052
    61     public override void Reset() {
    62       base.Reset();
    63       randomPolicy.Reset();
    64       Array.Clear(tries, 0, tries.Length);
    65       Array.Clear(sumReward, 0, sumReward.Length);
    66     }
    67     public override void PrintStats() {
    68       for (int i = 0; i < sumReward.Length; i++) {
    69         if (tries[i] >= 0) {
    70           Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);
    71         } else {
    72           Console.Write("-", "");
    73         }
    74       }
    75       Console.WriteLine();
    76     }
    7753    public override string ToString() {
    7854      return string.Format("EpsGreedyPolicy({0:F2})", eps);
Note: See TracChangeset for help on using the changeset viewer.