Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/GrammarPolicies/GenericGrammarPolicy.cs @ 11799

Last change on this file since 11799 was 11799, checked in by gkronber, 9 years ago

#2283: performance tuning and reactivated random-roll-out policy in sequential search

File size: 3.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Common;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Grammar policy that uses one of the available bandit policies for state selection.
  /// The per-state bandit information is kept in a dictionary keyed by the state string
  /// (or by its canonical representation when canonical states are enabled).
  /// </summary>
  public class GenericGrammarPolicy : IGrammarPolicy {
    // stores the necessary information for bandit policies for each state;
    // keys are always the result of CanonicalState(state)
    protected Dictionary<string, IBanditPolicyActionInfo> stateInfo;
    private readonly bool useCanonicalState;
    private readonly IProblem problem;
    private readonly IBanditPolicy banditPolicy;

    /// <summary>
    /// Creates a policy that selects follow states with <paramref name="banditPolicy"/>.
    /// </summary>
    /// <param name="problem">Problem used for canonicalization and for the terminal check on states.</param>
    /// <param name="banditPolicy">Bandit policy that creates the per-state infos and selects among them.</param>
    /// <param name="useCanonicalState">
    /// When true, states are mapped through <c>problem.CanonicalRepresentation</c> before they are used as keys,
    /// so states with the same canonical representation share one info object.
    /// </param>
    public GenericGrammarPolicy(IProblem problem, IBanditPolicy banditPolicy, bool useCanonicalState = false) {
      this.useCanonicalState = useCanonicalState;
      this.problem = problem;
      this.banditPolicy = banditPolicy;
      this.stateInfo = new Dictionary<string, IBanditPolicyActionInfo>();
    }

    /// <summary>
    /// Tries to select one of the follow states of <paramref name="curState"/> using the bandit policy.
    /// </summary>
    /// <param name="random">Random number generator handed to the bandit policy.</param>
    /// <param name="curState">The state whose follow states are considered.</param>
    /// <param name="afterStates">The follow states; enumerated exactly once.</param>
    /// <param name="selectedStateIdx">
    /// The index returned by the bandit policy on success, or -1 when selection fails.
    /// </param>
    /// <returns>
    /// False when all follow states are disabled (the current state is then disabled as well,
    /// with the maximum value of its follow states); true otherwise.
    /// </returns>
    /// <exception cref="InvalidOperationException">
    /// <paramref name="afterStates"/> is empty, so there is no follow-state value to take the maximum of.
    /// </exception>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // Resolve the info of every follow state once. This avoids enumerating afterStates several times
      // and avoids repeating the canonicalization + dictionary lookup each time the infos are enumerated
      // (by the checks below and by the bandit policy).
      var afterStateInfos = afterStates.Select(s => GetStateInfo(s)).ToArray();

      // fail if all states are done (corresponding state infos are disabled)
      if (afterStateInfos.All(info => info.Disabled)) {
        // fail because all follow states have already been visited => also disable the current state
        // (if we can be sure that it has been fully explored)
        GetStateInfo(curState).Disable(afterStateInfos.Max(info => info.Value));
        selectedStateIdx = -1;
        return false;
      }

      selectedStateIdx = banditPolicy.SelectAction(random, afterStateInfos);
      return true;
    }

    // Returns the info stored for the (canonicalized) state; creates and registers a new one on first access.
    private IBanditPolicyActionInfo GetStateInfo(string state) {
      var s = CanonicalState(state);
      IBanditPolicyActionInfo info;
      if (!stateInfo.TryGetValue(s, out info)) {
        info = banditPolicy.CreateActionInfo();
        stateInfo[s] = info;
      }
      return info;
    }

    /// <summary>
    /// Updates the info of every state on the trajectory with the given reward.
    /// States that are terminal according to the problem's grammar are additionally disabled with that reward.
    /// </summary>
    /// <param name="stateTrajectory">The visited states.</param>
    /// <param name="reward">The reward to propagate to all states on the trajectory.</param>
    public virtual void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      foreach (var state in stateTrajectory) {
        // look the info up once and reuse it for both the update and the terminal handling
        var info = GetStateInfo(state);
        info.UpdateReward(reward);

        // only the last state can be terminal
        if (problem.Grammar.IsTerminal(state)) {
          info.Disable(reward);
        }
      }
    }

    /// <summary>
    /// Removes the stored information for all states.
    /// </summary>
    public virtual void Reset() {
      stateInfo.Clear();
    }

    /// <summary>
    /// Returns the <c>Tries</c> value stored for the state, or 0 when no info exists for it.
    /// Does not create an info entry.
    /// </summary>
    public int GetTries(string state) {
      IBanditPolicyActionInfo info;
      if (stateInfo.TryGetValue(CanonicalState(state), out info)) {
        return info.Tries;
      }
      return 0;
    }

    /// <summary>
    /// Returns the <c>Value</c> stored for the state, or 0.0 when no info exists for it.
    /// Does not create an info entry.
    /// </summary>
    public double GetValue(string state) {
      IBanditPolicyActionInfo info;
      if (stateInfo.TryGetValue(CanonicalState(state), out info)) {
        return info.Value;
      }
      return 0.0; // TODO: check alternatives
    }

    /// <summary>
    /// Maps a state to the string used as dictionary key: the problem's canonical representation
    /// when canonical states are enabled, otherwise the state itself.
    /// </summary>
    protected string CanonicalState(string state) {
      if (useCanonicalState) {
        return problem.CanonicalRepresentation(state);
      } else {
        return state;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.