Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericFunctionApproximationGrammarPolicy.cs @ 11980

Last change on this file since 11980 was 11980, checked in by gkronber, 9 years ago

#2283: commit before cleanup after EuroCAST

File size: 7.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.ExceptionServices;
6using System.Text;
7using System.Text.RegularExpressions;
8using System.Threading;
9using System.Threading.Tasks;
10using HeuristicLab.Common;
11using HeuristicLab.Problems.GrammaticalOptimization;
12
13namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
14  public sealed class GenericFunctionApproximationGrammarPolicy : IGrammarPolicy {
15    private Dictionary<string, double> featureWeigths; // stores the necessary information for bandit policies for each state (=canonical phrase)
16    private Dictionary<string, int> featureTries;
17    private HashSet<string> done;
18    private readonly bool useCanonicalPhrases;
19    private readonly IProblem problem;
20
21
22
23    public GenericFunctionApproximationGrammarPolicy(IProblem problem, bool useCanonicalPhrases = false) {
24      this.useCanonicalPhrases = useCanonicalPhrases;
25      this.problem = problem;
26      this.featureWeigths = new Dictionary<string, double>();
27      this.featureTries = new Dictionary<string, int>();
28      this.done = new HashSet<string>();
29    }
30
31    private double[] activeAfterStates; // don't allocate each time
32    private int[] actionIndexMap; // don't allocate each time
33
34    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
35      // fail if all states are done (corresponding state infos are disabled)
36      if (afterStates.All(s => Done(s))) {
37        // fail because all follow states have already been visited => also disable the current state (if we can be sure that it has been fully explored)
38        MarkAsDone(curState);
39
40        selectedStateIdx = -1;
41        return false;
42      }
43
44      // determine active actions (not done yet) and create an array to map the selected index back to original actions
45      if (activeAfterStates == null || activeAfterStates.Length < afterStates.Count()) {
46        activeAfterStates = new double[afterStates.Count()];
47        actionIndexMap = new int[afterStates.Count()];
48      }
49      var maxIdx = 0; int originalIdx = 0;
50      foreach (var afterState in afterStates) {
51        if (!Done(afterState)) {
52          activeAfterStates[maxIdx] = 0.0;
53          actionIndexMap[maxIdx] = originalIdx;
54
55
56          activeAfterStates[maxIdx] = GetValue(afterState);
57
58          maxIdx++;
59        }
60        originalIdx++;
61      }
62
63     
64      /*
65      const double beta = 10;
66      var w = from idx in Enumerable.Range(0, maxIdx)
67              let afterStateQ = activeAfterStates[idx]
68              select Math.Exp(beta * afterStateQ);
69
70      var bestAction = Enumerable.Range(0, maxIdx).SampleProportional(random, w);
71      selectedStateIdx = actionIndexMap[bestAction];
72      Debug.Assert(selectedStateIdx >= 0);
73      */
74     
75     
76      if (random.NextDouble() < 0.5) {
77        selectedStateIdx = actionIndexMap[random.Next(maxIdx)];
78      } else {
79        // find max
80        var bestQ = double.NegativeInfinity;
81        var bestIdxs = new List<int>();
82        for (int i = 0; i < maxIdx; i++) {
83          if (activeAfterStates[i] > bestQ) {
84            bestIdxs.Clear();
85            bestIdxs.Add(i);
86            bestQ = activeAfterStates[i];
87          } else if (activeAfterStates[i].IsAlmost(bestQ)) {
88            bestIdxs.Add(i);
89          }
90        }
91        selectedStateIdx = actionIndexMap[bestIdxs[random.Next(bestIdxs.Count)]];
92      }
93     
94      return true;
95    }
96
97
98    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
99      foreach (var state in stateTrajectory) {
100        UpdateWeights(state, reward);
101
102        // only the last state can be terminal
103        if (problem.Grammar.IsTerminal(state)) {
104          MarkAsDone(state);
105        }
106      }
107    }
108
109
110    private IEnumerable<KeyValuePair<string, double>> Values {
111      get { return featureWeigths.OrderByDescending(p => p.Value); }
112    }
113
114    public void Reset() {
115      featureWeigths.Clear();
116      done.Clear();
117    }
118
119    public int GetTries(string state) {
120      return 0;
121    }
122
123    public int GetFeatureTries(string featureId) {
124      int t;
125      if (featureTries.TryGetValue(featureId, out t)) {
126        return t;
127      } else return 0;
128    }
129
130    public double GetValue(string state) {
131      return problem.GetFeatures(state).Sum(feature => GetWeight(feature));
132    }
133
134    private double GetWeight(Feature feature) {
135      double w;
136      if (featureWeigths.TryGetValue(feature.Id, out w)) return w * feature.Value;
137      else return 0.0;
138    }
139    private void UpdateWeights(string state, double reward) {
140      double delta = reward - GetValue(state);
141      // delta /= problem.GetFeatures(state).Count();
142      //const double alpha = 0.001;
143      foreach (var feature in problem.GetFeatures(state)) {
144        featureTries[feature.Id] = GetFeatureTries(feature.Id) + 1;
145        Debug.Assert(GetFeatureTries(feature.Id) >= 1);
146        double alpha = 1.0 / GetFeatureTries(feature.Id);
147        alpha = Math.Max(alpha, 0.001);
148
149        double w;
150        if (!featureWeigths.TryGetValue(feature.Id, out w)) {
151          featureWeigths[feature.Id] = alpha * delta * feature.Value;
152        } else {
153          featureWeigths[feature.Id] += alpha * delta * feature.Value;
154        }
155      }
156    }
157
158
159
160    // the canonical states for the value function (banditInfos) and the done set must be distinguished
161    // sequences of different length could have the same canonical representation and can have the same value (banditInfo)
162    // however, if the canonical representation of a state is shorter than we must not mark the canonical state as done when all possible derivations from the initial state have been explored
163    // eg. in the ant problem the canonical representation for ...lllA is ...rA
164    // even though all possible derivations (of limited length) of lllA have been visited we must not mark the state rA as done
165    private void MarkAsDone(string state) {
166      var s = CanonicalState(state);
167      // when the lengths of the canonical string and the original string are the same we also disable the actions
168      // always disable terminals
169      Debug.Assert(s.Length <= state.Length);
170      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
171        Debug.Assert(!done.Contains(s));
172        done.Add(s);
173      } else {
174        // for non-terminals where the canonical string is shorter than the original string we can only disable the canonical representation for all states in the same level
175        Debug.Assert(!done.Contains(s + state.Length));
176        done.Add(s + state.Length); // encode the original length of the state, states in the same level of the tree are treated as equivalent
177      }
178    }
179
180    // symmetric to MarkDone
181    private bool Done(string state) {
182      var s = CanonicalState(state);
183      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
184        return done.Contains(s);
185      } else {
186        // it is not necessary to visit states if the canonical representation has already been fully explored
187        if (done.Contains(s)) return true;
188        if (done.Contains(s + state.Length)) return true;
189        for (int i = 1; i < state.Length; i++) {
190          if (done.Contains(s + i)) return true;
191        }
192        return false;
193      }
194    }
195
196    private string CanonicalState(string state) {
197      if (useCanonicalPhrases) {
198        return problem.CanonicalRepresentation(state);
199      } else
200        return state;
201    }
202  }
203}
Note: See TracBrowser for help on using the repository browser.