Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/20/15 20:25:00 (10 years ago)
Author:
gkronber
Message:

#2283: separated value-states from done-states in GenericGrammarPolicy and removed disabling of actions from bandit policies

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
Files:
2 added
17 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ActiveLearningPolicy.cs

    r11792 r11806  
    1111    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
    1212      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    13       int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     13      int totalTries = myActionInfos.Sum(a => a.Tries);
    1414      const double delta = 0.1;
    15       int k = myActionInfos.Where(a => !a.Disabled).Count();
     15      int k = myActionInfos.Count();
    1616      var bestActions = new List<int>();
    1717      var us = new List<double>();
     
    2020      foreach (var aInfo in myActionInfos) {
    2121        aIdx++;
    22         if (aInfo.Disabled) continue;
    2322        double q;
    2423        double u;
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BoltzmannExplorationPolicy.cs

    r11799 r11806  
    3737
    3838      var w = from aInfo in myActionInfos
    39               select aInfo.Disabled
    40                 ? 0.0
    41                 : Math.Exp(beta * valueFunction(aInfo));
     39              select Math.Exp(beta * valueFunction(aInfo));
    4240
    4341      var bestAction = Enumerable.Range(0, myActionInfos.Count()).SampleProportional(random, w);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ChernoffIntervalEstimationPolicy.cs

    r11792 r11806  
    2121      // select best
    2222      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    23       int k = myActionInfos.Count(a => !a.Disabled);
    24       int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     23      int k = myActionInfos.Count();
     24      int totalTries = myActionInfos.Sum(a => a.Tries);
    2525      double bestQ = double.NegativeInfinity;
    2626      var bestActions = new List<int>();
     
    2828      foreach (var aInfo in myActionInfos) {
    2929        aIdx++;
    30         if (aInfo.Disabled) continue;
    3130        double q;
    3231        if (aInfo.Tries == 0) {
     
    4645          bestActions.Clear();
    4746          bestActions.Add(aIdx);
    48         } else if (q == bestQ) {
     47        } else if (q.IsAlmost(bestQ)) {
    4948          bestActions.Add(aIdx);
    5049        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/DefaultPolicyActionInfo.cs

    r11747 r11806  
    99  // stores information that is relevant for most of the policies
    1010  public class DefaultPolicyActionInfo : IBanditPolicyActionInfo {
    11     private double knownValue;
    12     public bool Disabled { get { return Tries == -1; } }
    1311    public double SumReward { get; private set; }
    1412    public int Tries { get; private set; }
     
    1614    public double Value {
    1715      get {
    18         if (Disabled) return knownValue;
    19         else
    2016          return Tries > 0 ? SumReward / Tries : 0.0;
    2117      }
     
    2622
    2723    public void UpdateReward(double reward) {
    28       Debug.Assert(!Disabled);
    29 
    3024      Tries++;
    3125      SumReward += reward;
    3226      MaxReward = Math.Max(MaxReward, reward);
    3327    }
    34     public void Disable(double reward) {
    35       this.Tries = -1;
    36       this.SumReward = 0.0;
    37       this.knownValue = reward;
    38     }
     28
    3929    public void Reset() {
    4030      SumReward = 0.0;
    4131      Tries = 0;
    4232      MaxReward = 0.0;
    43       knownValue = 0.0;
    44     }
    45     public void PrintStats() {
    46       Console.WriteLine("avg reward {0,5:F2} disabled {1}", SumReward / Tries, Disabled);
    4733    }
    4834
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/EpsGreedyPolicy.cs

    r11793 r11806  
    3535        foreach (var aInfo in myActionInfos) {
    3636          aIdx++;
    37           if (aInfo.Disabled) continue;
    3837
    3938          var q = valueFunction(aInfo);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/GenericThompsonSamplingPolicy.cs

    r11799 r11806  
    2222      foreach (var aInfo in myActionInfos) {
    2323        aIdx++;
    24         if (aInfo.Disabled) continue;
    25         //if (aInfo.Tries == 0) return aIdx;
    2624        var q = aInfo.SampleExpectedReward(random);
    2725        if (q > bestQ) {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/MeanAndVariancePolicyActionInfo.cs

    r11747 r11806  
    3939      estimator.Reset();
    4040    }
    41 
    42     public void PrintStats() {
    43       Console.WriteLine("avg reward {0,5:F2} disabled {1}", AvgReward, Disabled);
    44     }
    4541  }
    4642}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ModelPolicyActionInfo.cs

    r11747 r11806  
    1010  public class ModelPolicyActionInfo : IBanditPolicyActionInfo {
    1111    private readonly IModel model;
    12     private double knownValue;
    13     public bool Disabled { get { return Tries == -1; } }
    1412    public double Value {
    1513      get {
    16         if (Disabled) return knownValue;
    17         else
    18           return model.SampleExpectedReward(new Random());
     14        return model.SampleExpectedReward(new Random());
    1915      }
    2016    }
     
    2622
    2723    public void UpdateReward(double reward) {
    28       Debug.Assert(!Disabled);
    2924      Tries++;
    3025      model.Update(reward);
     
    3530    }
    3631
    37     public void Disable(double reward) {
    38       this.Tries = -1;
    39       this.knownValue = reward;
    40     }
    41 
    4232    public void Reset() {
    4333      Tries = 0;
    44       knownValue = 0.0;
    4534      model.Reset();
    4635    }
    4736
    48     public void PrintStats() {
    49       model.PrintStats();
    50     }
    51 
    5237    public override string ToString() {
    53       return string.Format("disabled {0} model {1}", Disabled, model);
     38      return string.Format("model {1}", model);
    5439    }
    5540  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/RandomPolicy.cs

    r11742 r11806  
    1717      return actionInfos
    1818        .Select((aInfo, idx) => Tuple.Create(aInfo, idx))
    19         .Where(p => !p.Item1.Disabled)
    2019        .SelectRandom(random).Item2;
    2120    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs

    r11792 r11806  
    2828      public int Tries { get; private set; }
    2929      public int thresholdBin = 1;
    30       private double knownValue;
    3130
    3231      public double Value {
    3332        get {
    34           if (Disabled) return knownValue;
    3533          if (Tries == 0.0) return 0.0;
    3634          return rewardHistogram[thresholdBin] / (double)Tries;
    3735        }
    3836      }
    39 
    40       public bool Disabled { get { return Tries == -1; } }
    4137
    4238      public void UpdateReward(double reward) {
     
    4642      }
    4743
    48       public void Disable(double reward) {
    49         this.knownValue = reward;
    50         Tries = -1;
    51       }
    52 
    5344      public void Reset() {
    5445        Tries = 0;
    5546        thresholdBin = 1;
    56         this.knownValue = 0.0;
    5747        Array.Clear(rewardHistogram, 0, rewardHistogram.Length);
    58       }
    59 
    60       public void PrintStats() {
    61         if (Tries >= 0) {
    62           Console.Write("{0,6}", Tries);
    63         } else {
    64           Console.Write("{0,6}", "");
    65         }
    6648      }
    6749
     
    10183      var bestActions = new List<int>();
    10284      double bestQ = double.NegativeInfinity;
    103       int k = myActionInfos.Count(a => !a.Disabled);
    104       var totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     85      int k = myActionInfos.Count();
     86      var totalTries = myActionInfos.Sum(a => a.Tries);
    10587      int aIdx = -1;
    10688      foreach (var aInfo in myActionInfos) {
    10789        aIdx++;
    108         if (aInfo.Disabled) continue;
    10990        double q;
    11091        if (aInfo.Tries == 0) {
     
    11899          bestActions.Clear();
    119100          bestActions.Add(aIdx);
    120         } else if (q == bestQ) {
     101        } else if (q.IsAlmost(bestQ)) {
    121102          bestActions.Add(aIdx);
    122103        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs

    r11747 r11806  
    1313      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    1414      double bestQ = double.NegativeInfinity;
    15       int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     15      int totalTries = myActionInfos.Sum(a => a.Tries);
    1616
    1717      var bestActions = new List<int>();
     
    1919      foreach (var aInfo in myActionInfos) {
    2020        aIdx++;
    21         if (aInfo.Disabled) continue;
    2221        double q;
    2322        if (aInfo.Tries == 0) {
     
    3130          bestActions.Clear();
    3231          bestActions.Add(aIdx);
    33         } else if (q == bestQ) {
     32        } else if (q.IsAlmost(bestQ)) {
    3433          bestActions.Add(aIdx);
    3534        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1TunedPolicy.cs

    r11792 r11806  
    3737          bestActions.Clear();
    3838          bestActions.Add(aIdx);
    39         } else if (q == bestQ) {
     39        } else if (q.IsAlmost(bestQ)) {
    4040          bestActions.Add(aIdx);
    4141        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCBNormalPolicy.cs

    r11792 r11806  
    3333          bestActions.Clear();
    3434          bestActions.Add(aIdx);
    35         } else if (q == bestQ) {
     35        } else if (q.IsAlmost(bestQ)) {
    3636          bestActions.Add(aIdx);
    3737        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCTPolicy.cs

    r11747 r11806  
    2121      int bestAction = -1;
    2222      double bestQ = double.NegativeInfinity;
    23       int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     23      int totalTries = myActionInfos.Sum(a => a.Tries);
    2424
    2525      int aIdx = -1;
     
    2727      foreach (var aInfo in myActionInfos) {
    2828        aIdx++;
    29         if (aInfo.Disabled) continue;
    3029        double q;
    3130        if (aInfo.Tries == 0) {
     
    3837          bestQ = q;
    3938          bestActions.Add(aIdx);
    40         }
    41         if (q == bestQ) {
     39        } else if (q.IsAlmost(bestQ)) {
    4240          bestActions.Add(aIdx);
    4341        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/GrammarPolicies/GenericGrammarPolicy.cs

    r11799 r11806  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    910namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
    1011  // this represents grammar policies that use one of the available bandit policies for state selection
    11   public class GenericGrammarPolicy : IGrammarPolicy {
    12     protected Dictionary<string, IBanditPolicyActionInfo> stateInfo; // stores the necessary information for bandit policies for each state
    13     private readonly bool useCanonicalState;
     12  // any bandit policy can be used to select actions for states
     13  // a separate datastructure is used to store visited states and to prevent revisiting of states
     14  public sealed class GenericGrammarPolicy : IGrammarPolicy {
     15    private Dictionary<string, IBanditPolicyActionInfo> stateInfo; // stores the necessary information for bandit policies for each state (=canonical phrase)
     16    private HashSet<string> done;
     17    private readonly bool useCanonicalPhrases;
    1418    private readonly IProblem problem;
    1519    private readonly IBanditPolicy banditPolicy;
    1620
    17     public GenericGrammarPolicy(IProblem problem, IBanditPolicy banditPolicy, bool useCanonicalState = false) {
    18       this.useCanonicalState = useCanonicalState;
     21    public GenericGrammarPolicy(IProblem problem, IBanditPolicy banditPolicy, bool useCanonicalPhrases = false) {
     22      this.useCanonicalPhrases = useCanonicalPhrases;
    1923      this.problem = problem;
    2024      this.banditPolicy = banditPolicy;
    2125      this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
     26      this.done = new HashSet<string>();
    2227    }
     28
     29    private IBanditPolicyActionInfo[] activeAfterStates; // don't allocate each time
     30    private int[] actionIndexMap; // don't allocate each time
    2331
    2432    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
    2533      // fail if all states are done (corresponding state infos are disabled)
    26       if (afterStates.All(s => GetStateInfo(s).Disabled)) {
     34      if (afterStates.All(s => Done(s))) {
    2735        // fail because all follow states have already been visited => also disable the current state (if we can be sure that it has been fully explored)
     36        MarkAsDone(curState);
    2837
    29         GetStateInfo(curState).Disable(afterStates.Select(afterState => GetStateInfo(afterState).Value).Max());
    3038        selectedStateIdx = -1;
    3139        return false;
    3240      }
    3341
    34       selectedStateIdx = banditPolicy.SelectAction(random, afterStates.Select(s => GetStateInfo(s)));
     42      // determine active actions (not done yet) and create an array to map the selected index back to original actions
     43      if (activeAfterStates == null || activeAfterStates.Length < afterStates.Count()) {
     44        activeAfterStates = new IBanditPolicyActionInfo[afterStates.Count()];
     45        actionIndexMap = new int[afterStates.Count()];
     46      }
     47      var idx = 0; int originalIdx = 0;
     48      foreach (var afterState in afterStates) {
     49        if (!Done(afterState)) {
     50          activeAfterStates[idx] = GetStateInfo(afterState);
     51          actionIndexMap[idx] = originalIdx;
     52          idx++;
     53        }
     54        originalIdx++;
     55      }
     56
     57      selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(idx))];
    3558
    3659      return true;
    3760    }
     61
     62
    3863
    3964    private IBanditPolicyActionInfo GetStateInfo(string state) {
     
    4772    }
    4873
    49     public virtual void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
     74    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
    5075      foreach (var state in stateTrajectory) {
    5176        GetStateInfo(state).UpdateReward(reward);
     
    5378        // only the last state can be terminal
    5479        if (problem.Grammar.IsTerminal(state)) {
    55           GetStateInfo(state).Disable(reward);
     80          MarkAsDone(state);
    5681        }
    5782      }
    5883    }
    5984
    60     public virtual void Reset() {
     85
     86    public void Reset() {
    6187      stateInfo.Clear();
     88      done.Clear();
    6289    }
    6390
     
    74101    }
    75102
    76     protected string CanonicalState(string state) {
    77       if (useCanonicalState) {
     103    // the canonical states for the value function (banditInfos) and the done set must be distinguished
     104    // sequences of different length could have the same canonical representation and can have the same value (banditInfo)
     105    // however, if the canonical representation of a state is shorter than we must not mark the canonical state as done when all possible derivations from the initial state have been explored
     106    // eg. in the ant problem the canonical representation for ...lllA is ...rA
     107    // even though all possible derivations (of limited length) of lllA have been visited we must not mark the state rA as done
     108    private void MarkAsDone(string state) {
     109      var s = CanonicalState(state);
     110      // when the lengths of the canonical string and the original string are the same we also disable the actions
     111      // always disable terminals
     112      Debug.Assert(s.Length <= state.Length);
     113      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
     114        Debug.Assert(!done.Contains(s));
     115        done.Add(s);
     116      } else {
     117        // for non-terminals where the canonical string is shorter than the original string we can only disable the canonical representation for all states in the same level
     118        Debug.Assert(!done.Contains(s + state.Length));
     119        done.Add(s + state.Length); // encode the original length of the state, states in the same level of the tree are treated as equivalent
     120      }
     121    }
     122
     123    // symmetric to MarkDone
     124    private bool Done(string state) {
     125      var s = CanonicalState(state);
     126      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
     127        return done.Contains(s);
     128      } else {
     129        // it is not necessary to visit states if the canonical representation has already been fully explored
     130        if (done.Contains(s)) return true;
     131        if (done.Contains(s + state.Length)) return true;
     132        for (int i = 1; i < state.Length; i++) {
     133          if (done.Contains(s + i)) return true;
     134        }
     135        return false;
     136      }
     137    }
     138
     139    private string CanonicalState(string state) {
     140      if (useCanonicalPhrases) {
    78141        return problem.CanonicalRepresentation(state);
    79142      } else
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11793 r11806  
    4949    <Compile Include="BanditPolicies\ChernoffIntervalEstimationPolicy.cs" />
    5050    <Compile Include="BanditPolicies\ActiveLearningPolicy.cs" />
     51    <Compile Include="BanditPolicies\ModifiedUCTPolicy.cs" />
    5152    <Compile Include="BanditPolicies\DefaultPolicyActionInfo.cs" />
    5253    <Compile Include="BanditPolicies\EpsGreedyPolicy.cs" />
     
    6667    <Compile Include="Bandits\IBandit.cs" />
    6768    <Compile Include="Bandits\TruncatedNormalBandit.cs" />
     69    <Compile Include="GrammarPolicies\GenericTDPolicy.cs" />
    6870    <Compile Include="GrammarPolicies\GenericGrammarPolicy.cs">
    6971      <SubType>Code</SubType>
     
    7274      <SubType>Code</SubType>
    7375    </Compile>
    74     <Compile Include="GrammarPolicies\TDPolicy.cs" />
    7576    <Compile Include="GrammarPolicies\GrammarPolicy.cs" />
    7677    <Compile Include="GrammarPolicies\IGrammarPolicy.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IBanditPolicyActionInfo.cs

    r11770 r11806  
    11namespace HeuristicLab.Algorithms.Bandits {
    22  public interface IBanditPolicyActionInfo {
    3     bool Disabled { get; }
     3    //bool Disabled { get; }
    44    double Value { get; }
    55    int Tries { get; }
    66    void UpdateReward(double reward);
    7     void Disable(double reward);
     7    //void Disable(double reward);
    88    // reset causes the state of the action to be reinitialized (as after constructor-call)
    99    void Reset();
    10     void PrintStats();
    1110  }
    1211}
Note: See TracChangeset for help on using the changeset viewer.