Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericContextualGrammarPolicy.cs @ 11850

Last change on this file since 11850 was 11832, checked in by gkronber, 10 years ago

linear value function approximation and good results for poly-10 benchmark

File size: 6.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.ExceptionServices;
6using System.Text;
7using System.Text.RegularExpressions;
8using System.Threading;
9using System.Threading.Tasks;
10using HeuristicLab.Common;
11using HeuristicLab.Problems.GrammaticalOptimization;
12
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that chooses among the follow-up states (after-states) of a state by delegating to an
  /// <see cref="IBanditPolicy"/>.
  /// Bandit statistics are shared by all states that map to the same context key: the trailing characters
  /// of the state followed by the number of (possibly overlapping) occurrences of that suffix in the state.
  /// Fully explored states are tracked separately in a done set keyed by the (optionally canonical) state.
  /// </summary>
  public sealed class GenericContextualGrammarPolicy : IGrammarPolicy {
    // default number of trailing characters of a state that form its context
    private const int DefaultContextLength = 5;

    private readonly Dictionary<string, IBanditPolicyActionInfo> stateInfo; // bandit policy information per context key (see Context)
    private readonly HashSet<string> done; // fully explored canonical states, or level keys (canonical state + original length)
    private readonly bool useCanonicalPhrases;
    private readonly IProblem problem;
    private readonly IBanditPolicy banditPolicy;
    private readonly int contextLength;

    /// <summary>
    /// Creates a new policy.
    /// </summary>
    /// <param name="problem">Problem that provides the grammar (terminal check) and canonical representations.</param>
    /// <param name="banditPolicy">Bandit policy that creates action infos and selects among active after-states.</param>
    /// <param name="useCanonicalPhrases">If true, the done set is keyed by <c>problem.CanonicalRepresentation(state)</c>
    /// instead of the state itself.</param>
    /// <param name="contextLength">Number of trailing characters of a state used for its context key (default 5).
    /// Must be at least 1.</param>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="contextLength"/> is less than 1.</exception>
    public GenericContextualGrammarPolicy(IProblem problem, IBanditPolicy banditPolicy, bool useCanonicalPhrases = false, int contextLength = DefaultContextLength) {
      if (contextLength < 1) throw new ArgumentOutOfRangeException("contextLength", "The context length must be at least 1.");
      this.useCanonicalPhrases = useCanonicalPhrases;
      this.problem = problem;
      this.banditPolicy = banditPolicy;
      this.contextLength = contextLength;
      this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
      this.done = new HashSet<string>();
    }

    private IBanditPolicyActionInfo[] activeAfterStates; // reused between calls to avoid allocating each time
    private int[] actionIndexMap; // maps an index among active after-states back to the original index; reused

    /// <summary>
    /// Selects one of the after-states that has not been fully explored yet.
    /// </summary>
    /// <param name="random">Random number generator passed on to the bandit policy.</param>
    /// <param name="curState">The current state; it is marked as done when none of its after-states is selectable.</param>
    /// <param name="afterStates">Candidate follow-up states. Materialized once, so lazy sequences are evaluated a single time.</param>
    /// <param name="selectedStateIdx">Index of the selected state within <paramref name="afterStates"/>, or -1 if selection failed.</param>
    /// <returns><c>false</c> if every after-state is done (or there are none); otherwise <c>true</c>.</returns>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // materialize once: the sequence is needed for its count and for indexed access below
      IList<string> candidates = afterStates as IList<string> ?? afterStates.ToList();
      int numCandidates = candidates.Count;

      // grow the reusable buffers if necessary
      if (activeAfterStates == null || activeAfterStates.Length < numCandidates) {
        activeAfterStates = new IBanditPolicyActionInfo[numCandidates];
        actionIndexMap = new int[numCandidates];
      }

      // collect active actions (not done yet) and remember their original positions
      int numActive = 0;
      for (int i = 0; i < numCandidates; i++) {
        var afterState = candidates[i];
        if (Done(afterState)) continue;
        activeAfterStates[numActive] = GetStateInfo(afterState);
        actionIndexMap[numActive] = i;
        numActive++;
      }

      if (numActive == 0) {
        // fail because all follow states have already been visited => also disable the current state
        // (MarkAsDone decides whether we can be sure that it has been fully explored)
        MarkAsDone(curState);
        selectedStateIdx = -1;
        return false;
      }

      // the bandit policy returns an index relative to the active after-states; map it back
      selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(numActive))];
      return true;
    }

    // Returns the bandit information for the context of the state, creating it on first access.
    private IBanditPolicyActionInfo GetStateInfo(string state) {
      var key = Context(state);
      IBanditPolicyActionInfo info;
      if (!stateInfo.TryGetValue(key, out info)) {
        info = banditPolicy.CreateActionInfo();
        stateInfo[key] = info;
      }
      return info;
    }

    // Context key of a state: its last contextLength characters (the whole state if it is shorter)
    // followed by the number of (possibly overlapping) occurrences of that suffix in the state.
    private string Context(string state) {
      var firstPos = Math.Max(0, state.Length - contextLength);
      var context = state.Substring(firstPos);
      return context + CountOverlappingOccurrences(state, context);
    }

    // Counts the positions in text at which pattern starts (ordinal comparison, overlaps counted).
    // An empty pattern matches at every position, i.e. text.Length + 1 times, which is consistent
    // with counting matches of the zero-width regex "(?=)".
    private static int CountOverlappingOccurrences(string text, string pattern) {
      if (pattern.Length == 0) return text.Length + 1;
      int count = 0;
      int pos = text.IndexOf(pattern, StringComparison.Ordinal);
      while (pos >= 0) {
        count++;
        pos = text.IndexOf(pattern, pos + 1, StringComparison.Ordinal);
      }
      return count;
    }

    /// <summary>
    /// Passes <paramref name="reward"/> to the bandit information of every state in the trajectory
    /// and marks terminal states as done.
    /// </summary>
    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      foreach (var state in stateTrajectory) {
        GetStateInfo(state).UpdateReward(reward);

        // only the last state can be terminal
        if (problem.Grammar.IsTerminal(state)) {
          MarkAsDone(state);
        }
      }
    }

    /// <summary>
    /// Forgets all bandit information and all done markers.
    /// </summary>
    public void Reset() {
      stateInfo.Clear();
      done.Clear();
    }

    /// <summary>
    /// Returns the number of tries recorded for the context of <paramref name="state"/>, or 0 if it has not been seen.
    /// </summary>
    public int GetTries(string state) {
      IBanditPolicyActionInfo info;
      return stateInfo.TryGetValue(Context(state), out info) ? info.Tries : 0;
    }

    /// <summary>
    /// Returns the value recorded for the context of <paramref name="state"/>, or 0.0 if it has not been seen.
    /// </summary>
    public double GetValue(string state) {
      IBanditPolicyActionInfo info;
      return stateInfo.TryGetValue(Context(state), out info) ? info.Value : 0.0; // TODO: check alternatives
    }

    // The canonical states for the value function (bandit infos) and the done set must be distinguished.
    // Sequences of different length can have the same canonical representation and can share the same
    // value (bandit info). However, if the canonical representation of a state is shorter, then we must
    // not mark the canonical state as done when all possible derivations from the original state have
    // been explored. E.g. in the ant problem the canonical representation of ...lllA is ...rA; even though
    // all possible derivations (of limited length) of lllA have been visited, we must not mark rA as done.
    private void MarkAsDone(string state) {
      var s = CanonicalState(state);
      Debug.Assert(s.Length <= state.Length);
      // when the canonical string and the original string have the same length we disable the canonical
      // state itself; terminals are always disabled
      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
        Debug.Assert(!done.Contains(s));
        done.Add(s);
      } else {
        // for non-terminals whose canonical string is shorter than the original string we can only disable
        // the canonical representation for all states in the same level
        var key = LevelKey(s, state.Length);
        Debug.Assert(!done.Contains(key));
        done.Add(key);
      }
    }

    // Symmetric to MarkAsDone.
    private bool Done(string state) {
      var s = CanonicalState(state);
      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
        return done.Contains(s);
      }
      // it is not necessary to visit states if the canonical representation has already been fully explored
      if (done.Contains(s)) return true;
      // ... or if a state with the same canonical representation and the same or a smaller original
      // length has been fully explored
      for (int len = 1; len <= state.Length; len++) {
        if (done.Contains(LevelKey(s, len))) return true;
      }
      return false;
    }

    // Done-set key for a canonical state reached from an original state of the given length;
    // states in the same level of the tree (same original length) are treated as equivalent.
    private static string LevelKey(string canonicalState, int originalLength) {
      return canonicalState + originalLength;
    }

    private string CanonicalState(string state) {
      return useCanonicalPhrases ? problem.CanonicalRepresentation(state) : state;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.