[11770] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Configuration;
|
---|
[11793] | 4 | using System.Diagnostics;
|
---|
[11770] | 5 | using System.Linq;
|
---|
| 6 | using System.Security.Policy;
|
---|
| 7 | using System.Text;
|
---|
| 8 | using System.Threading;
|
---|
| 9 | using System.Threading.Tasks;
|
---|
[11793] | 10 | using HeuristicLab.Algorithms.Bandits.BanditPolicies;
|
---|
[11770] | 11 | using HeuristicLab.Common;
|
---|
| 12 | using HeuristicLab.Problems.GrammaticalOptimization;
|
---|
| 13 |
|
---|
| 14 | namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
|
---|
/// <summary>
/// Grammar policy that keeps a value estimate for every (canonical) state.
/// Selection: among follow states that are not yet done, it first picks any state that has never been
/// tried, otherwise the state with the highest value estimate.
/// Update: after each episode the reward is propagated backwards along the state trajectory.
/// </summary>
public class TDPolicy : GrammarPolicy {

  // Canonical states excluded from selection: evaluated terminal states, and states whose
  // follow states were all done when TrySelect was called for them.
  private readonly HashSet<string> done;
  // Value estimate per canonical state; states without an entry are treated as 0.0 (see V).
  private readonly Dictionary<string, double> v;
  // NOTE(review): updated in UpdateReward and cleared in Reset, but never consulted for selection;
  // it looks like a leftover alternative strategy -- confirm before removing.
  private readonly IGrammarPolicy epsGreedy;

  public TDPolicy(IProblem problem, bool useCanonicalRepresentation = false)
    : base(problem, useCanonicalRepresentation) {
    this.done = new HashSet<string>();
    this.v = new Dictionary<string, double>();
    this.epsGreedy = new GenericGrammarPolicy(problem, new EpsGreedyPolicy(0.1), useCanonicalRepresentation);
  }

  /// <summary>
  /// Selects one of the follow states that is not yet done. The first state that has never been tried
  /// wins. Otherwise the state with the highest value estimate is chosen; on ties the first one wins.
  /// </summary>
  /// <param name="random">Not used by this policy.</param>
  /// <param name="curState">The state the follow states are reached from. It is marked done when all of
  /// its follow states are done.</param>
  /// <param name="afterStates">Candidate follow states.</param>
  /// <param name="selectedStateIdx">Index of the selected state within the ORIGINAL
  /// <paramref name="afterStates"/> sequence, or -1 if nothing was selected.</param>
  /// <returns><c>true</c> if a state was selected; <c>false</c> if every follow state is done.</returns>
  public override bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
    // Only states that are not yet done are candidates. Each candidate keeps its position in the
    // original sequence, so the returned index refers to afterStates and not to the filtered subset.
    var candidates = afterStates
      .Select((state, idx) => new { State = state, Idx = idx })
      .Where(c => !done.Contains(CanonicalState(c.State)))
      .ToArray();

    selectedStateIdx = -1;
    if (candidates.Length == 0) {
      // fail because all follow states have already been visited => also disable the current state
      done.Add(CanonicalState(curState));
      return false;
    }

    var bestQ = double.NegativeInfinity;
    foreach (var c in candidates) {
      // try each state at least once
      if (GetTries(c.State) == 0) {
        selectedStateIdx = c.Idx;
        return true;
      }
      var q = V(c.State);
      if (q > bestQ) {
        bestQ = q;
        selectedStateIdx = c.Idx;
      }
    }

    // If every estimate is NaN or -Infinity, the strict comparison above never succeeds.
    // Fall back to the first candidate so that a 'true' result always comes with a valid index.
    if (selectedStateIdx < 0) selectedStateIdx = candidates[0].Idx;
    return true;
  }

  // Value estimate of the canonical form of 'state'; 0.0 if the state has no estimate yet.
  private double V(string state) {
    double value;
    return v.TryGetValue(CanonicalState(state), out value) ? value : 0.0;
  }

  /// <summary>
  /// Forwards the trajectory to the base class and to the epsilon-greedy policy, then updates the
  /// value estimates backwards along the trajectory, in three steps:
  /// 1. If the last state is terminal, it is marked done.
  /// 2. The last state's estimate moves towards <paramref name="reward"/>.
  /// 3. Each earlier state's estimate moves towards the estimate of its successor.
  /// All moves use step size 1 / tries.
  /// </summary>
  /// <param name="stateTrajectory">States visited in the episode, in visiting order.</param>
  /// <param name="reward">Reward obtained for the episode.</param>
  public override void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
    // Materialize once: the trajectory is read several times below.
    var trajectory = stateTrajectory.ToArray();
    base.UpdateReward(trajectory, reward);
    epsGreedy.UpdateReward(trajectory, reward);
    if (trajectory.Length == 0) return; // no states, nothing to update

    // the last state could be terminal
    var lastState = trajectory[trajectory.Length - 1];
    if (problem.Grammar.IsTerminal(lastState)) done.Add(CanonicalState(lastState));

    // NOTE(review): the step sizes assume base.UpdateReward has already counted this episode,
    // so that GetTries(...) > 0 for every visited state -- confirm against GrammarPolicy.
    v[CanonicalState(lastState)] = V(lastState) + 1.0 / GetTries(lastState) * (reward - V(lastState));

    // Walk the (cur, next) pairs from the end of the trajectory towards its start,
    // so each state sees the already-updated estimate of its successor.
    for (int i = trajectory.Length - 2; i >= 0; i--) {
      var cur = trajectory[i];
      var next = trajectory[i + 1];
      v[CanonicalState(cur)] = V(cur) + 1.0 / GetTries(cur) * (V(next) - V(cur));
    }
  }

  /// <summary>Returns the current value estimate of <paramref name="state"/> (0.0 if none).</summary>
  public override double GetValue(string state) {
    return V(state);
  }

  /// <summary>
  /// Resets the base class and the epsilon-greedy policy, then clears all value estimates and
  /// done marks.
  /// </summary>
  public override void Reset() {
    base.Reset();
    epsGreedy.Reset();
    v.Clear();
    done.Clear();
  }
}
|
---|
| 98 | }
|
---|