1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Configuration;
|
---|
4 | using System.Diagnostics;
|
---|
5 | using System.Linq;
|
---|
6 | using System.Security.Policy;
|
---|
7 | using System.Text;
|
---|
8 | using System.Threading;
|
---|
9 | using System.Threading.Tasks;
|
---|
10 | using HeuristicLab.Algorithms.Bandits.BanditPolicies;
|
---|
11 | using HeuristicLab.Common;
|
---|
12 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
13 |
|
---|
14 | namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
|
---|
/// <summary>
/// Grammar policy that keeps a table of value estimates per canonical state and
/// greedily selects the not-yet-exhausted follow state with the highest estimate,
/// after first trying every follow state whose try count is still zero.
/// After each sampled trajectory the estimates are updated backwards: the last
/// state is moved towards the observed reward, and every earlier state is moved
/// towards the (already updated) estimate of its successor.
/// </summary>
public class TDPolicy : GrammarPolicy {

  // Canonical states that must not be selected any more (terminal states that were
  // already sampled, and states whose follow states are all in this set).
  private readonly HashSet<string> done;

  // Value estimate per canonical state; states without an entry are read as 0.0.
  private readonly Dictionary<string, double> v;

  // Secondary eps-greedy policy. It is updated and reset together with this policy,
  // but TrySelect below does not consult it.
  private readonly IGrammarPolicy epsGreedy;

  /// <summary>
  /// Creates the policy with empty value and done tables.
  /// </summary>
  /// <param name="problem">Problem whose grammar is used to detect terminal states.</param>
  /// <param name="useCanonicalRepresentation">Forwarded to the base policy and to the internal eps-greedy policy.</param>
  public TDPolicy(IProblem problem, bool useCanonicalRepresentation = false)
    : base(problem, useCanonicalRepresentation) {
    this.done = new HashSet<string>();
    this.v = new Dictionary<string, double>();
    this.epsGreedy = new GenericGrammarPolicy(problem, new EpsGreedyPolicy(0.1), useCanonicalRepresentation);
  }

  /// <summary>
  /// Selects one of <paramref name="afterStates"/> that is not marked as done.
  /// </summary>
  /// <param name="random">Not used; the selection is deterministic.</param>
  /// <param name="curState">State from which the follow states are reached.</param>
  /// <param name="afterStates">Candidate follow states; enumerated exactly once.</param>
  /// <param name="selectedStateIdx">
  /// Position of the selected state within the original <paramref name="afterStates"/>
  /// sequence (done states are skipped but still counted), or -1 if nothing could be selected.
  /// </param>
  /// <returns>
  /// <c>true</c> if a state was selected; <c>false</c> if every follow state is done,
  /// in which case <paramref name="curState"/> is marked as done as well.
  /// </returns>
  public override bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
    selectedStateIdx = -1;
    var bestQ = double.NegativeInfinity;
    var originalIdx = -1;

    foreach (var state in afterStates) {
      // count every state so that the reported index refers to the caller's sequence
      originalIdx++;

      // only select states that are not yet done
      if (done.Contains(CanonicalState(state))) continue;

      // try each state at least once
      if (GetTries(state) == 0) {
        selectedStateIdx = originalIdx;
        return true;
      }

      var q = V(state);
      // the first selectable state always becomes the initial best, so that a
      // NaN or negative-infinity estimate cannot leave the result at -1
      if (selectedStateIdx < 0 || q > bestQ) {
        bestQ = q;
        selectedStateIdx = originalIdx;
      }
    }

    if (selectedStateIdx < 0) {
      // fail because all follow states have already been visited => also disable the current state
      done.Add(CanonicalState(curState));
      return false;
    }
    return true;
  }

  // Returns the stored value estimate for the canonical form of the state, or 0.0 if none is stored.
  private double V(string state) {
    double value;
    return v.TryGetValue(CanonicalState(state), out value) ? value : 0.0;
  }

  // Moves the estimate of the state towards target using the step size 1 / GetTries(state).
  // NOTE(review): this relies on GetTries(state) being non-zero here, presumably because
  // base.UpdateReward has already counted the visit - confirm against GrammarPolicy.
  private void MoveEstimateTowards(string state, double target) {
    var current = V(state);
    v[CanonicalState(state)] = current + 1.0 / GetTries(state) * (target - current);
  }

  /// <summary>
  /// Forwards the reward to the base policy and the internal eps-greedy policy and
  /// then updates the value estimates of all states on the trajectory.
  /// </summary>
  /// <param name="stateTrajectory">Visited states in order of visit; must not be empty.</param>
  /// <param name="reward">Reward observed for the trajectory.</param>
  /// <exception cref="InvalidOperationException">The trajectory contains no states.</exception>
  public override void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
    // materialize once; the sequence is needed several times below
    var trajectory = stateTrajectory as string[] ?? stateTrajectory.ToArray();

    base.UpdateReward(trajectory, reward);
    epsGreedy.UpdateReward(trajectory, reward);

    // the last state could be terminal
    var lastState = trajectory.Last();
    if (problem.Grammar.IsTerminal(lastState)) done.Add(CanonicalState(lastState));

    // the last state is moved towards the observed reward ...
    MoveEstimateTowards(lastState, reward);

    // ... and every earlier state, from back to front, towards its successor's estimate
    for (var i = trajectory.Length - 2; i >= 0; i--) {
      MoveEstimateTowards(trajectory[i], V(trajectory[i + 1]));
    }
  }

  /// <summary>
  /// Returns the current value estimate of the state (0.0 if no estimate is stored).
  /// </summary>
  public override double GetValue(string state) {
    return V(state);
  }

  /// <summary>
  /// Resets the base policy and the internal eps-greedy policy and discards all
  /// value estimates and done markers.
  /// </summary>
  public override void Reset() {
    base.Reset();
    epsGreedy.Reset();
    v.Clear();
    done.Clear();
  }
}
|
---|
98 | }
|
---|