Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs @ 12298

Last change on this file since 12298 was 12298, checked in by gkronber, 9 years ago

#2283: experiments with q-learning

File size: 8.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that scores a chain as the weighted sum of the features returned by
  /// <c>problem.GetFeatures</c> (a linear value function over feature ids), selects follow states
  /// epsilon-greedily among those not yet marked as done, and adjusts the feature weights with
  /// one-step temporal-difference updates.
  /// </summary>
  /// <remarks>
  /// Original note: resampling is not prevented.
  /// NOTE(review): terminal chains are added to the done set in <c>UpdateReward</c> and done chains are
  /// skipped in <c>TrySelect</c>, so at least terminal chains are not re-selected — confirm which is intended.
  /// </remarks>
  public sealed class GenericPolicy : IGrammarPolicy {
    // step size (alpha) of every weight update
    private const double LearningRate = 0.01;
    // discount (gamma) applied to the best follow-state value in the one-step update
    private const double DiscountFactor = 1.0;
    // probability (epsilon) of picking a uniformly random non-done follow state instead of a greedy one
    private const double ExplorationRate = 0.01;

    // chains whose current value estimates are printed by PrintStats (duplicates kept as in the original list)
    private static readonly string[] printedChains = {
      "a*b", "c*d", "a*b+c*d", "e*f", "a*b+c*d+e*f",
      "a*b+a*b", "c*d+c*d",
      "a*a", "a*b","a*c","a*d","a*e","a*f","a*g","a*h","a*i","a*j",
      "a*b","c*d","e*f","a*c","a*f","a*i","a*i*g","c*f","c*f*j",
      "b+c","a+c","b+d","a+d",
      "a*b+c*d+e*f", "a*b+c*d+e*f+a", "a*b+c*d+e*f+b", "a*b+c*d+e*f+c", "a*b+c*d+e*f+d","a*b+c*d+e*f+e",  "a*b+c*d+e*f+f", "a*b+c*d+e*f+g", "a*b+c*d+e*f+h", "a*b+c*d+e*f+i", "a*b+c*d+e*f+j",
      "a*b+c*d+e*f+a*g*i+c*j*f"
    };

    private readonly Dictionary<string, double> Q; // feature id -> weight of the linear value function
    // tries per key; NOTE(review): nothing in this class increments it, so GetTries currently always returns 0
    private readonly Dictionary<string, int> T;
    private readonly Dictionary<string, List<string>> followStates; // follow states recorded the first time a chain is expanded
    private readonly IProblem problem;
    private readonly HashSet<string> done; // chains excluded from selection

    // scratch buffers reused across TrySelect calls so they are not allocated each time
    private double[] activeAfterStates;
    private int[] actionIndexMap;

    public GenericPolicy(IProblem problem) {
      this.problem = problem;
      this.Q = new Dictionary<string, double>();
      this.T = new Dictionary<string, int>();
      this.followStates = new Dictionary<string, List<string>>();
      this.done = new HashSet<string>();
    }

    /// <summary>
    /// Selects one of <paramref name="afterStates"/> epsilon-greedily (ignoring done states) and then
    /// updates the value of <paramref name="curState"/> towards the best follow-state value.
    /// </summary>
    /// <param name="random">Source of randomness for exploration and tie breaking.</param>
    /// <param name="curState">The chain being expanded.</param>
    /// <param name="afterStates">Candidate follow chains; enumerated exactly once.</param>
    /// <param name="selectedStateIdx">Index into <paramref name="afterStates"/> of the selected chain, or -1 on failure.</param>
    /// <returns>
    /// <c>false</c> if every follow state is done (this includes an empty sequence); the current state is then
    /// marked as done as well. Otherwise <c>true</c>.
    /// </returns>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // materialize once: the candidates are used for the done check, sizing, scoring and the value update
      var candidates = afterStates as IList<string> ?? afterStates.ToList();

      if (candidates.All(Done)) {
        // all follow states have already been visited => the current state is fully explored, too
        MarkAsDone(curState);
        selectedStateIdx = -1;
        return false;
      }

      int candidateCount = candidates.Count;
      if (activeAfterStates == null || activeAfterStates.Length < candidateCount) {
        activeAfterStates = new double[candidateCount];
        actionIndexMap = new int[candidateCount];
      }
      if (!followStates.ContainsKey(curState)) {
        followStates[curState] = new List<string>(candidates);
      }

      // score only the candidates that are not done and remember their original positions
      int activeCount = 0;
      for (int originalIdx = 0; originalIdx < candidateCount; originalIdx++) {
        var candidate = candidates[originalIdx];
        if (Done(candidate)) continue;
        activeAfterStates[activeCount] = CalculateValue(candidate);
        actionIndexMap[activeCount] = originalIdx;
        activeCount++;
      }

      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(activeCount), ExplorationRate)];

      // note: the max over follow states includes done ones as well
      UpdateValue(curState, candidates);

      return true;
    }

    // Value of a chain: sum over its features of (weight of feature id) * (feature value); unknown ids weigh 0.
    private double CalculateValue(string chain) {
      var sum = 0.0;
      foreach (var f in problem.GetFeatures(chain)) {
        sum += GetValue(f.Id) * f.Value;
      }
      return sum;
    }

    // One-step update of curChain towards DiscountFactor * (best value among alternatives).
    private void UpdateValue(string curChain, IEnumerable<string> alternatives) {
      var maxNextQ = alternatives.Select(CalculateValue).Max();
      var delta = DiscountFactor * maxNextQ - CalculateValue(curChain);
      AdjustWeights(curChain, delta);
    }

    // Update of the last chain of an episode towards the observed reward.
    private void UpdateLastValue(string terminalChain, double reward) {
      var delta = reward - CalculateValue(terminalChain);
      AdjustWeights(terminalChain, delta);
    }

    // Moves the weight of every feature of the chain by LearningRate * delta * feature value.
    private void AdjustWeights(string chain, double delta) {
      foreach (var f in problem.GetFeatures(chain)) {
        Q[f.Id] = GetValue(f.Id) + LearningRate * delta * f.Value;
      }
    }

    // Softmax selection with inverse temperature beta (currently unused by this class).
    private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) {
      var qArr = qs.ToArray();
      // shift by the maximum before exponentiating so large q values cannot overflow to +Infinity;
      // the common factor exp(-beta * maxQ) does not change the proportions
      var maxQ = qArr.Length == 0 ? 0.0 : qArr.Max();
      var w = qArr.Select(q => Math.Exp(beta * (q - maxQ)));

      var bestAction = Enumerable.Range(0, qArr.Length).SampleProportional(random, w);
      Debug.Assert(bestAction >= 0);
      return bestAction;
    }

    // With probability 1 - eps returns an index of a maximal q (ties broken randomly), otherwise a random index.
    // eps == 0 is pure exploitation, eps == 1 is pure exploration.
    private int SelectEpsGreedy(Random random, IEnumerable<double> qs, double eps = 0.2) {
      if (random.NextDouble() >= eps) {
        var bestActions = new List<int>();
        double bestQ = double.NegativeInfinity;

        int aIdx = -1;
        foreach (var q in qs) {
          aIdx++;

          if (q > bestQ) {
            bestActions.Clear();
            bestActions.Add(aIdx);
            bestQ = q;
          } else if (q.IsAlmost(bestQ)) {
            bestActions.Add(aIdx);
          }
        }
        Debug.Assert(bestActions.Count > 0);
        return bestActions.SelectRandom(random);
      } else {
        return SelectRandom(random, qs);
      }
    }

    // Uniformly random index into qs.
    private int SelectRandom(Random random, IEnumerable<double> qs) {
      return qs
         .Select((aInfo, idx) => Tuple.Create(aInfo, idx))
         .SelectRandom(random).Item2;
    }

    /// <summary>
    /// Updates the value of the last chain of the trajectory towards <paramref name="reward"/> and marks it
    /// as done if it is terminal. Earlier chains are not touched here because they are already updated
    /// after each step in <c>TrySelect</c>.
    /// </summary>
    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
      var terminalChain = chainTrajectory.Last();
      UpdateLastValue(terminalChain, reward);
      if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain);
    }

    /// <summary>Forgets all learned weights, tries, done chains and recorded follow states.</summary>
    public void Reset() {
      Q.Clear();
      T.Clear();
      done.Clear();
      followStates.Clear();
    }

    private bool Done(string chain) {
      return done.Contains(chain);
    }

    private void MarkAsDone(string chain) {
      done.Add(chain);
    }

    /// <summary>Returns the recorded number of tries for <paramref name="fId"/>, or 0 if none is recorded.</summary>
    public int GetTries(string fId) {
      int tries;
      return T.TryGetValue(fId, out tries) ? tries : 0;
    }

    /// <summary>Returns the weight of feature <paramref name="fId"/>, or 0.0 if it has not been learned yet.</summary>
    public double GetValue(string fId) {
      double value;
      return Q.TryGetValue(fId, out value) ? value : 0.0; // TODO: check alternatives
    }

    /// <summary>
    /// Writes the largest learned weight (if any) and the current value estimate of a fixed list of
    /// chains to the console.
    /// </summary>
    public void PrintStats() {
      // Enumerable.Max throws on an empty sequence, e.g. before any update or right after Reset()
      if (Q.Count > 0) {
        Console.WriteLine(Q.Values.Max());
      }

      foreach (var option in printedChains) {
        Console.WriteLine("{0,-10} {1:N5}", option, CalculateValue(option));
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.