Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs @ 12295

Last change on this file since 12295 was 12295, checked in by gkronber, 9 years ago

#2283: force selection of untried alternatives

File size: 6.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that keeps a value estimate Q and a try counter T for every problem state and
  /// picks follow-up chains epsilon-greedily (eps = 0.2). The state of a chain is the id of the single
  /// feature the problem returns for that chain (see CalcState). Follow states that have never been
  /// tried are valued +infinity, so the greedy branch selects untried alternatives first.
  /// </summary>
  /// <remarks>
  /// Resampling is not prevented in general: only chains that have been marked as done (terminal
  /// chains whose reward has been reported, and chains whose follow states are all done) are
  /// excluded from selection.
  /// </remarks>
  public sealed class GenericPolicy : IGrammarPolicy {
    // value estimate per state (state id as computed by CalcState)
    private readonly Dictionary<string, double> Q;
    // number of reward updates (tries) per state
    private readonly Dictionary<string, int> T;
    // candidate follow-up chains per chain, recorded the first time TrySelect is called for that chain
    private readonly Dictionary<string, List<string>> followStates;
    private readonly IProblem problem;
    // chains that are fully explored and must not be selected again
    private readonly HashSet<string> done;

    // scratch buffers reused by TrySelect so they are not allocated on every call; they only ever grow
    private double[] activeAfterStates;
    private int[] actionIndexMap;

    /// <summary>Creates a policy for the given problem.</summary>
    /// <param name="problem">Problem used to map chains to states and to detect terminal chains.</param>
    /// <exception cref="ArgumentNullException"><paramref name="problem"/> is null.</exception>
    public GenericPolicy(IProblem problem) {
      if (problem == null) throw new ArgumentNullException("problem");
      this.problem = problem;
      this.Q = new Dictionary<string, double>();
      this.T = new Dictionary<string, int>();
      this.followStates = new Dictionary<string, List<string>>();
      this.done = new HashSet<string>();
    }

    /// <summary>
    /// Selects one of <paramref name="afterStates"/> as the successor of <paramref name="curState"/>.
    /// Candidates that are done are skipped; untried candidates are preferred by the greedy branch.
    /// </summary>
    /// <param name="random">Random number generator used for the epsilon-greedy choice.</param>
    /// <param name="curState">The current chain.</param>
    /// <param name="afterStates">Candidate follow-up chains. The set seen on the first call for
    /// <paramref name="curState"/> is remembered and later used by UpdateReward.</param>
    /// <param name="selectedStateIdx">Index into <paramref name="afterStates"/> of the selected chain,
    /// or -1 if selection failed.</param>
    /// <returns><c>false</c> if every candidate is done (then <paramref name="curState"/> is marked as
    /// done as well); otherwise <c>true</c>.</returns>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // materialize once: the candidates are enumerated several times below
      IList<string> candidates = afterStates as IList<string> ?? afterStates.ToList();

      if (candidates.All(s => Done(s))) {
        // all follow states have already been visited => the current state is fully explored as well
        MarkAsDone(curState);
        selectedStateIdx = -1;
        return false;
      }

      int n = candidates.Count;
      if (activeAfterStates == null || activeAfterStates.Length < n) {
        activeAfterStates = new double[n];
        actionIndexMap = new int[n];
      }
      if (!followStates.ContainsKey(curState)) {
        followStates[curState] = new List<string>(candidates);
      }

      // collect values of all candidates that are not done; actionIndexMap maps the compacted
      // position back to the position in afterStates
      int numActive = 0;
      for (int i = 0; i < n; i++) {
        var afterState = candidates[i];
        if (Done(afterState)) continue;
        var state = CalcState(afterState);
        // untried alternatives get +inf so that the greedy branch always picks them first
        activeAfterStates[numActive] = GetTriesOfState(state) == 0 ? double.PositiveInfinity : GetValueOfState(state);
        actionIndexMap[numActive] = i;
        numActive++;
      }

      const double eps = 0.2;
      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(numActive), eps)];
      return true;
    }

    // Boltzmann (softmax) selection: index i is sampled with weight exp(beta * q_i).
    // Currently not used by TrySelect.
    // NOTE(review): exp(beta * q) becomes +inf for large q and for the +inf values TrySelect assigns to
    // untried states; verify SampleProportional copes with infinite weights before using this.
    private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) {
      var qList = qs.ToList(); // enumerate the values only once
      var w = qList.Select(q => Math.Exp(beta * q));

      var bestAction = Enumerable.Range(0, qList.Count).SampleProportional(random, w);
      Debug.Assert(bestAction >= 0);
      return bestAction;
    }

    // Epsilon-greedy selection. With probability 1 - eps the index of a maximal value is returned (ties,
    // as judged by IsAlmost, are broken via SelectRandom); otherwise a random index is returned.
    // eps == 0 is pure exploitation, eps == 1 is pure exploration.
    private int SelectEpsGreedy(Random random, IEnumerable<double> qs, double eps = 0.2) {
      if (random.NextDouble() >= eps) {
        // select best
        var bestActions = new List<int>();
        double bestQ = double.NegativeInfinity;

        int aIdx = -1;
        foreach (var q in qs) {
          aIdx++;

          if (q > bestQ) {
            bestActions.Clear();
            bestActions.Add(aIdx);
            bestQ = q;
          } else if (q.IsAlmost(bestQ)) {
            bestActions.Add(aIdx);
          }
        }
        Debug.Assert(bestActions.Any());
        return bestActions.SelectRandom(random);
      } else {
        // select random
        return SelectRandom(random, qs);
      }
    }

    // Returns a randomly chosen index into qs (the values themselves are ignored).
    private int SelectRandom(Random random, IEnumerable<double> qs) {
      return qs
         .Select((aInfo, idx) => Tuple.Create(aInfo, idx))
         .SelectRandom(random).Item2;
    }

    /// <summary>
    /// Propagates <paramref name="reward"/> backwards along a sampled trajectory. The last chain is
    /// moved towards the reward; every preceding chain receives no direct reward and is moved towards
    /// the discounted (gamma = 0.95) maximum value of its recorded follow states. The learning rate is
    /// max(1 / tries, 0.01). Finally the last chain is marked as done if the grammar reports it as terminal.
    /// </summary>
    /// <param name="chainTrajectory">Visited chains in order; must not be empty.</param>
    /// <param name="reward">Reward observed for the last chain.</param>
    /// <exception cref="InvalidOperationException"><paramref name="chainTrajectory"/> is empty.</exception>
    /// <exception cref="KeyNotFoundException">A non-final chain of the trajectory has no recorded
    /// follow states (it was never passed to TrySelect since the last Reset).</exception>
    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
      const double gamma = 0.95;
      const double minAlpha = 0.01;
      // materialize once: the reversed sequence is consumed twice below
      var reverseChains = chainTrajectory.Reverse().ToList();
      var terminalChain = reverseChains.First();

      var terminalState = CalcState(terminalChain);
      int tries = GetTriesOfState(terminalState) + 1;
      T[terminalState] = tries;
      double alpha = Math.Max(1.0 / tries, minAlpha);
      Q[terminalState] = (1 - alpha) * GetValueOfState(terminalState) + alpha * reward;

      foreach (var chain in reverseChains.Skip(1)) {
        List<string> next;
        if (!followStates.TryGetValue(chain, out next))
          throw new KeyNotFoundException(string.Format("No follow states have been recorded for chain '{0}'.", chain));
        var maxNextQ = next.Select(GetValue).Max();

        var state = CalcState(chain);
        tries = GetTriesOfState(state) + 1;
        T[state] = tries;
        alpha = Math.Max(1.0 / tries, minAlpha);
        Q[state] = (1 - alpha) * GetValueOfState(state) + gamma * alpha * maxNextQ; // direct contribution is zero
      }
      if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain);
    }

    /// <summary>Forgets everything learned: values, try counts, follow states and done chains.</summary>
    public void Reset() {
      Q.Clear();
      T.Clear(); // try counts must be reset too, otherwise untried states are not recognized after a reset
      done.Clear();
      followStates.Clear();
    }


    private bool Done(string chain) {
      return done.Contains(chain);
    }

    private void MarkAsDone(string chain) {
      done.Add(chain);
    }


    /// <summary>Returns how often the state of the given chain has been updated (0 if never).</summary>
    /// <param name="state">The chain whose state is looked up.</param>
    public int GetTries(string state) {
      return GetTriesOfState(CalcState(state));
    }

    /// <summary>Returns the value estimate of the state of the given chain (0.0 if never updated).</summary>
    /// <param name="chain">The chain whose state is looked up.</param>
    public double GetValue(string chain) {
      return GetValueOfState(CalcState(chain));
    }

    // try count for an already computed state id
    private int GetTriesOfState(string state) {
      int tries;
      return T.TryGetValue(state, out tries) ? tries : 0;
    }

    // value estimate for an already computed state id
    private double GetValueOfState(string state) {
      double value;
      return Q.TryGetValue(state, out value) ? value : 0.0; // TODO: check alternatives
    }

    // Maps a chain to its state id: the id of the single feature the problem returns for it.
    private string CalcState(string chain) {
      var f = problem.GetFeatures(chain);
      // this policy only works for problems that return exactly one feature (the 'state')
      if (f.Skip(1).Any())
        throw new ArgumentException("GenericPolicy requires problems that return exactly one feature per chain.", "chain");
      return f.First().Id;
    }

    /// <summary>
    /// Writes the best value, followed by up to 50 rows pairing the most-tried states with the
    /// highest-valued states whose id contains a comma, to the console.
    /// </summary>
    public void PrintStats() {
      if (Q.Count == 0) return; // nothing learned yet; Max() would throw on an empty sequence
      Console.WriteLine(Q.Values.Max());
      var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
      var topQs = Q.Keys.Where(key => key.Contains(",")).OrderByDescending(key => Q[key]).Take(50);
      foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
        var id1 = t.Item1;
        var id2 = t.Item2;
        Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
      }

    }
  }
}
Note: See TracBrowser for help on using the repository browser.