Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs @ 12291

Last change on this file since 12291 was 12291, checked in by gkronber, 10 years ago

#2283 added missing files (forgotten in the branch)

File size: 3.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Text;
6using System.Threading.Tasks;
7using HeuristicLab.Common;
8using HeuristicLab.Problems.GrammaticalOptimization;
9
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that lets a bandit policy choose among the follow states of the current state.
  /// Per-state bandit information is keyed by the id of the single feature that the problem
  /// returns for a chain (see <see cref="CalcState"/>); chains that are marked as done are
  /// excluded from selection.
  /// </summary>
  // resampling is not prevented
  public sealed class GenericPolicy : IGrammarPolicy {
    // stores the necessary information for bandit policies for each state
    private readonly Dictionary<string, IBanditPolicyActionInfo> stateInfo;
    private readonly IProblem problem;
    private readonly IBanditPolicy banditPolicy;
    // contains all visited chains
    private readonly HashSet<string> done;

    // scratch buffers reused between TrySelect calls (don't allocate each time);
    // only the first entries written by the current call are meaningful
    private IBanditPolicyActionInfo[] activeAfterStates;
    private int[] actionIndexMap;

    /// <summary>
    /// Creates a policy for the given problem that delegates action selection to <paramref name="banditPolicy"/>.
    /// </summary>
    /// <param name="problem">Supplies the grammar (terminal check) and the features used as state ids.</param>
    /// <param name="banditPolicy">Creates the per-state action infos and selects among them.</param>
    public GenericPolicy(IProblem problem, IBanditPolicy banditPolicy) {
      this.problem = problem;
      this.banditPolicy = banditPolicy;
      this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
      this.done = new HashSet<string>();
    }

    /// <summary>
    /// Selects one of the follow states that has not been marked as done.
    /// </summary>
    /// <param name="random">Random number generator handed through to the bandit policy.</param>
    /// <param name="curState">The current chain; it is marked as done when no follow state is selectable.</param>
    /// <param name="afterStates">The candidate follow states.</param>
    /// <param name="selectedStateIdx">
    /// On success the position of the selected state within <paramref name="afterStates"/>; otherwise -1.
    /// </param>
    /// <returns>
    /// <c>false</c> if every follow state is already done (this includes an empty sequence), otherwise <c>true</c>.
    /// </returns>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // Materialize once: the count and a full pass are both needed, and re-enumerating a lazy
      // sequence would recompute it every time. Collections passed by the caller are used as they are.
      var candidates = afterStates as ICollection<string> ?? afterStates.ToList();

      if (activeAfterStates == null || activeAfterStates.Length < candidates.Count) {
        activeAfterStates = new IBanditPolicyActionInfo[candidates.Count];
        actionIndexMap = new int[candidates.Count];
      }

      // collect the states that are still selectable and remember their original positions
      var activeCount = 0;
      var originalIdx = 0;
      foreach (var afterState in candidates) {
        if (!Done(afterState)) {
          activeAfterStates[activeCount] = GetStateInfo(afterState);
          actionIndexMap[activeCount] = originalIdx;
          activeCount++;
        }
        originalIdx++;
      }

      if (activeCount == 0) {
        // fail because all follow states have already been visited
        // => also disable the current state (if we can be sure that it has been fully explored)
        MarkAsDone(curState);

        selectedStateIdx = -1;
        return false;
      }

      // the bandit policy's result is used as a position in the filtered sequence and mapped back
      selectedStateIdx = actionIndexMap[banditPolicy.SelectAction(random, activeAfterStates.Take(activeCount))];
      return true;
    }

    // Returns the bandit information for the state of the given chain, creating it on first use.
    private IBanditPolicyActionInfo GetStateInfo(string state) {
      var s = CalcState(state);
      IBanditPolicyActionInfo info;
      if (!stateInfo.TryGetValue(s, out info)) {
        info = banditPolicy.CreateActionInfo();
        stateInfo[s] = info;
      }
      return info;
    }

    /// <summary>
    /// Reports <paramref name="reward"/> to the bandit information of every state in the trajectory
    /// (visited in reverse order) and marks states that the grammar considers terminal as done.
    /// </summary>
    /// <param name="stateTrajectory">The chains visited during one run.</param>
    /// <param name="reward">The reward that is passed unchanged to each state's action info.</param>
    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      foreach (var state in stateTrajectory.Reverse()) {
        GetStateInfo(state).UpdateReward(reward);

        // actually only the last state can be terminal
        if (problem.Grammar.IsTerminal(state)) {
          MarkAsDone(state);
        }
      }
    }

    /// <summary>
    /// Forgets all per-state bandit information and all done markers.
    /// </summary>
    public void Reset() {
      stateInfo.Clear();
      done.Clear();
    }

    private bool Done(string chain) {
      return done.Contains(chain);
    }

    private void MarkAsDone(string chain) {
      done.Add(chain);
    }

    /// <summary>
    /// Returns the number of tries recorded for the state of the given chain, or 0 if the state is unknown.
    /// </summary>
    public int GetTries(string state) {
      IBanditPolicyActionInfo info;
      if (stateInfo.TryGetValue(CalcState(state), out info)) return info.Tries;
      else return 0;
    }

    /// <summary>
    /// Returns the value recorded for the state of the given chain, or 0.0 if the state is unknown.
    /// </summary>
    public double GetValue(string state) {
      IBanditPolicyActionInfo info;
      if (stateInfo.TryGetValue(CalcState(state), out info)) return info.Value;
      else return 0.0; // TODO: check alternatives
    }

    // Maps a chain to the key used in stateInfo: the id of the problem's single feature for that chain.
    private string CalcState(string chain) {
      var f = problem.GetFeatures(chain);
      // this policy only works for problems that return exactly one feature (the 'state')
      if (f.Skip(1).Any()) {
        throw new ArgumentException("GenericPolicy requires a problem that returns exactly one feature (the state) per chain.");
      }
      return f.First().Id;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.