Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/GrammarPolicies/GenericTDPolicy.cs @ 11832

Last change on this file since 11832 was 11806, checked in by gkronber, 9 years ago

#2283: separated value-states from done-states in GenericGrammarPolicy and removed disabling of actions from bandit policies

File size: 6.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Algorithms.Bandits.BanditPolicies;
8using HeuristicLab.Common;
9using HeuristicLab.Problems.GrammaticalOptimization;
10
11namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
12  public sealed class GenericTDPolicy : IGrammarPolicy {
13    private Dictionary<string, IBanditPolicyActionInfo> stateInfo; // stores the necessary information for bandit policies for each state (=canonical phrase)
14    private HashSet<string> done;
15    private readonly bool useCanonicalPhrases;
16    private readonly IProblem problem;
17    private readonly IBanditPolicy banditPolicy;
18
19    public GenericTDPolicy(IProblem problem, bool useCanonicalPhrases = false) {
20      this.useCanonicalPhrases = useCanonicalPhrases;
21      this.problem = problem;
22      this.banditPolicy = new EpsGreedyPolicy(0.1);
23      this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
24      this.done = new HashSet<string>();
25    }
26
27    private IBanditPolicyActionInfo[] activeAfterStates; // don't allocate each time
28    private int[] actionIndexMap; // don't allocate each time
29
30    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
31      // fail if all states are done (corresponding state infos are disabled)
32      if (afterStates.All(s => Done(s))) {
33        // fail because all follow states have already been visited => also disable the current state (if we can be sure that it has been fully explored)
34        MarkAsDone(curState);
35
36        selectedStateIdx = -1;
37        return false;
38      }
39
40      // determine active actions (not done yet) and create an array to map the selected index back to original actions
41      if (activeAfterStates == null || activeAfterStates.Length < afterStates.Count()) {
42        activeAfterStates = new IBanditPolicyActionInfo[afterStates.Count()];
43        actionIndexMap = new int[afterStates.Count()];
44      }
45      var idx = 0; int originalIdx = 0;
46      foreach (var afterState in afterStates) {
47        if (!Done(afterState)) {
48          activeAfterStates[idx] = GetStateInfo(afterState);
49          actionIndexMap[idx] = originalIdx;
50          idx++;
51        }
52        originalIdx++;
53      }
54
55      selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(idx))];
56     
57     
58      GetStateInfo(curState).UpdateReward(afterStates.Select(s=>GetStateInfo(s).Value).Max());
59
60      return true;
61    }
62
63
64
65    private IBanditPolicyActionInfo GetStateInfo(string state) {
66      var s = CanonicalState(state);
67      IBanditPolicyActionInfo info;
68      if (!stateInfo.TryGetValue(s, out info)) {
69        info = banditPolicy.CreateActionInfo();
70        stateInfo[s] = info;
71      }
72      return info;
73    }
74
75    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
76
77      foreach (var state in stateTrajectory.Zip(stateTrajectory.Skip(1), Tuple.Create)) {
78        var parent = state.Item1;
79        var child = state.Item2;
80
81        // only the last state can be terminal
82        if (problem.Grammar.IsTerminal(child)) {
83          MarkAsDone(child);
84        }
85        // GetStateInfo(parent).UpdateReward(GetValue(child));
86
87      }
88
89      GetStateInfo(stateTrajectory.Last()).UpdateReward(reward);
90    }
91
92
93    public void Reset() {
94      stateInfo.Clear();
95      done.Clear();
96    }
97
98    public int GetTries(string state) {
99      var s = CanonicalState(state);
100      if (stateInfo.ContainsKey(s)) return stateInfo[s].Tries;
101      else return 0;
102    }
103
104    public double GetValue(string state) {
105      var s = CanonicalState(state);
106      if (stateInfo.ContainsKey(s)) return stateInfo[s].Value;
107      else return 0.0; // TODO: check alternatives
108    }
109
110    // the canonical states for the value function (banditInfos) and the done set must be distinguished
111    // sequences of different length could have the same canonical representation and can have the same value (banditInfo)
112    // however, if the canonical representation of a state is shorter than we must not mark the canonical state as done when all possible derivations from the initial state have been explored
113    // eg. in the ant problem the canonical representation for ...lllA is ...rA
114    // even though all possible derivations (of limited length) of lllA have been visited we must not mark the state rA as done
115    private void MarkAsDone(string state) {
116      var s = CanonicalState(state);
117      // when the lengths of the canonical string and the original string are the same we also disable the actions
118      // always disable terminals
119      Debug.Assert(s.Length <= state.Length);
120      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
121        Debug.Assert(!done.Contains(s));
122        done.Add(s);
123      } else {
124        // for non-terminals where the canonical string is shorter than the original string we can only disable the canonical representation for all states in the same level
125        Debug.Assert(!done.Contains(s + state.Length));
126        done.Add(s + state.Length); // encode the original length of the state, states in the same level of the tree are treated as equivalent
127      }
128    }
129
130    // symmetric to MarkDone
131    private bool Done(string state) {
132      var s = CanonicalState(state);
133      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
134        return done.Contains(s);
135      } else {
136        // it is not necessary to visit states if the canonical representation has already been fully explored
137        if (done.Contains(s)) return true;
138        if (done.Contains(s + state.Length)) return true;
139        for (int i = 1; i < state.Length; i++) {
140          if (done.Contains(s + i)) return true;
141        }
142        return false;
143      }
144    }
145
146    private string CanonicalState(string state) {
147      if (useCanonicalPhrases) {
148        return problem.CanonicalRepresentation(state);
149      } else
150        return state;
151    }
152  }
153}
Note: See TracBrowser for help on using the repository browser.