Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/08/15 10:09:47 (9 years ago)
Author:
gkronber
Message:

#2283: worked on Q-Learning for poly10

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs

    r12291 r12294  
    1111  // resampling is not prevented
    1212  public sealed class GenericPolicy : IGrammarPolicy {
    13     private Dictionary<string, IBanditPolicyActionInfo> stateInfo; // stores the necessary information for bandit policies for each state
     13    private Dictionary<string, double> Q; // stores the necessary information for bandit policies for each state
     14    private Dictionary<string, int> T; // tries;
     15    private Dictionary<string, List<string>> followStates;
    1416    private readonly IProblem problem;
    15     private readonly IBanditPolicy banditPolicy;
    1617    private readonly HashSet<string> done; // contains all visited chains
    1718
    18     public GenericPolicy(IProblem problem, IBanditPolicy banditPolicy) {
     19    public GenericPolicy(IProblem problem) {
    1920      this.problem = problem;
    20       this.banditPolicy = banditPolicy;
    21       this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
     21      this.Q = new Dictionary<string, double>();
     22      this.T = new Dictionary<string, int>();
     23      this.followStates = new Dictionary<string, List<string>>();
    2224      this.done = new HashSet<string>();
    2325    }
    2426
    25     private IBanditPolicyActionInfo[] activeAfterStates; // don't allocate each time
     27    private double[] activeAfterStates; // don't allocate each time
    2628    private int[] actionIndexMap; // don't allocate each time
    2729
     
    3739
    3840      if (activeAfterStates == null || activeAfterStates.Length < afterStates.Count()) {
    39         activeAfterStates = new IBanditPolicyActionInfo[afterStates.Count()];
     41        activeAfterStates = new double[afterStates.Count()];
    4042        actionIndexMap = new int[afterStates.Count()];
     43      }
     44      if (!followStates.ContainsKey(curState)) {
     45        followStates[curState] = new List<string>(afterStates);
    4146      }
    4247      var idx = 0; int originalIdx = 0;
    4348      foreach (var afterState in afterStates) {
    4449        if (!Done(afterState)) {
    45           activeAfterStates[idx] = GetStateInfo(afterState);
     50          activeAfterStates[idx] = GetValue(afterState);
    4651          actionIndexMap[idx] = originalIdx;
    4752          idx++;
     
    5055      }
    5156
    52       selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(idx))];
     57      //var eps = Math.Max(500.0 / (GetTries(curState) + 1), 0.01);
     58      //var eps = 10.0 / Math.Sqrt(GetTries(curState) + 1);
     59      var eps = 0.2;
     60      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(idx), eps)];
    5361
    5462      return true;
    5563    }
    5664
     65    private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) {
     66      // select best
    5767
     68      // try any of the untries actions randomly
     69      // for RoyalSequence it is much better to select the actions in the order of occurrence (all terminal alternatives first)
     70      //if (myActionInfos.Any(aInfo => !aInfo.Disabled && aInfo.Tries == 0)) {
     71      //  return myActionInfos
     72      //  .Select((aInfo, idx) => new { aInfo, idx })
     73      //  .Where(p => !p.aInfo.Disabled)
     74      //  .Where(p => p.aInfo.Tries == 0)
     75      //  .SelectRandom(random).idx;
     76      //}
    5877
    59     private IBanditPolicyActionInfo GetStateInfo(string state) {
    60       var s = CalcState(state);
    61       IBanditPolicyActionInfo info;
    62       if (!stateInfo.TryGetValue(s, out info)) {
    63         info = banditPolicy.CreateActionInfo();
    64         stateInfo[s] = info;
    65       }
    66       return info;
     78      var w = from q in qs
     79              select Math.Exp(beta * q);
     80
     81      var bestAction = Enumerable.Range(0, qs.Count()).SampleProportional(random, w);
     82      Debug.Assert(bestAction >= 0);
     83      return bestAction;
    6784    }
    6885
    69     public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
    70       foreach (var state in stateTrajectory.Reverse()) {
    71         GetStateInfo(state).UpdateReward(reward);
     86    private int SelectEpsGreedy(Random random, IEnumerable<double> qs, double eps = 0.2) {
     87      if (random.NextDouble() >= eps) { // eps == 0 should be equivalent to pure exploitation, eps == 1 is pure exploration
     88        // select best
     89        var bestActions = new List<int>();
     90        double bestQ = double.NegativeInfinity;
    7291
    73         // actually only the last state can be terminal
    74         if (problem.Grammar.IsTerminal(state)) {
    75           MarkAsDone(state);
     92        int aIdx = -1;
     93        foreach (var q in qs) {
     94          aIdx++;
     95
     96          if (q > bestQ) {
     97            bestActions.Clear();
     98            bestActions.Add(aIdx);
     99            bestQ = q;
     100          } else if (q.IsAlmost(bestQ)) {
     101            bestActions.Add(aIdx);
     102          }
    76103        }
     104        Debug.Assert(bestActions.Any());
     105        return bestActions.SelectRandom(random);
     106      } else {
     107        // select random
     108        return SelectRandom(random, qs);
    77109      }
    78110    }
    79111
     112    private int SelectRandom(Random random, IEnumerable<double> qs) {
     113      return qs
     114         .Select((aInfo, idx) => Tuple.Create(aInfo, idx))
     115         .SelectRandom(random).Item2;
     116    }
     117
     118
     119    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
     120      const double gamma = 0.95;
     121      const double minAlpha = 0.01;
     122      var reverseChains = chainTrajectory.Reverse();
     123      var terminalChain = reverseChains.First();
     124
     125      var terminalState = CalcState(terminalChain);
     126      T[terminalState] = GetTries(terminalChain) + 1;
     127      double alpha = Math.Max(1.0 / GetTries(terminalChain), minAlpha);
     128      Q[terminalState] = (1 - alpha) * GetValue(terminalChain) + alpha * reward;
     129
     130      foreach (var chain in reverseChains.Skip(1)) {
     131
     132        var maxNextQ = followStates[chain]
     133          //.Where(s=>!Done(s))
     134          .Select(GetValue).Max();
     135        T[CalcState(chain)] = GetTries(chain) + 1;
     136
     137        alpha = Math.Max(1.0 / GetTries(chain), minAlpha);
     138        Q[CalcState(chain)] = (1 - alpha) * GetValue(chain) + gamma * alpha * maxNextQ; // direct contribution is zero
     139      }
     140      if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain);
     141    }
     142
    80143    public void Reset() {
    81       stateInfo.Clear();
     144      Q.Clear();
    82145      done.Clear();
     146      followStates.Clear();
    83147    }
    84148
     
    95159    public int GetTries(string state) {
    96160      var s = CalcState(state);
    97       if (stateInfo.ContainsKey(s)) return stateInfo[s].Tries;
     161      if (T.ContainsKey(s)) return T[s];
    98162      else return 0;
    99163    }
    100164
    101     public double GetValue(string state) {
    102       var s = CalcState(state);
    103       if (stateInfo.ContainsKey(s)) return stateInfo[s].Value;
     165    public double GetValue(string chain) {
     166      var s = CalcState(chain);
     167      if (Q.ContainsKey(s)) return Q[s];
    104168      else return 0.0; // TODO: check alternatives
    105169    }
     
    111175      return f.First().Id;
    112176    }
     177
     178    public void PrintStats() {
     179      Console.WriteLine(Q.Values.Max());
     180      var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
     181      var topQs = Q.Keys.Where(key=>key.Contains(",")).OrderByDescending(key => Q[key]).Take(50);
     182      foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
     183        var id1 = t.Item1;
     184        var id2 = t.Item2;
     185        Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
     186      }
     187
     188    }
    113189  }
    114190}
Note: See TracChangeset for help on using the changeset viewer.