Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/GrammarPolicies/GenericFunctionApproximationGrammarPolicy.cs @ 11832

Last change on this file since 11832 was 11832, checked in by gkronber, 9 years ago

linear value function approximation and good results for poly-10 benchmark

File size: 6.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.ExceptionServices;
6using System.Text;
7using System.Text.RegularExpressions;
8using System.Threading;
9using System.Threading.Tasks;
10using HeuristicLab.Common;
11using HeuristicLab.Problems.GrammaticalOptimization;
12
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  /// <summary>
  /// Epsilon-greedy grammar policy whose state values are a linear function of problem-defined features:
  /// V(s) = sum_i w_i * x_i(s), where the features x_i(s) are obtained from problem.GetFeatures(s).
  /// After each episode every state of the trajectory is moved towards the observed reward
  /// (gradient Monte-Carlo update). Fully explored states are recorded in a "done" set and are
  /// never selected again.
  /// </summary>
  public sealed class GenericFunctionApproximationGrammarPolicy : IGrammarPolicy {
    // linear weight for each feature, keyed by Feature.Id; features without an entry have weight 0
    private readonly Dictionary<string, double> featureWeights;
    // fully explored states: canonical phrases, or canonical phrases tagged with the original
    // derivation length (see MarkAsDone / LevelKey)
    private readonly HashSet<string> done;
    private readonly bool useCanonicalPhrases;
    private readonly IProblem problem;
    // probability of choosing a uniformly random (not yet done) follow state instead of the greedy one
    private readonly double epsilon;
    // step size of the weight update
    private readonly double alpha;

    /// <summary>
    /// Creates a new policy.
    /// </summary>
    /// <param name="problem">Problem providing the grammar, the features and (optionally) canonical phrases.</param>
    /// <param name="useCanonicalPhrases">If true, the done set is keyed by problem.CanonicalRepresentation(state).</param>
    /// <param name="epsilon">Exploration probability in [0, 1] (default 0.2).</param>
    /// <param name="alpha">Positive step size of the weight update (default 0.01).</param>
    /// <exception cref="ArgumentNullException">problem is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException">epsilon is not in [0, 1] or alpha is not positive.</exception>
    public GenericFunctionApproximationGrammarPolicy(IProblem problem, bool useCanonicalPhrases = false,
      double epsilon = 0.2, double alpha = 0.01) {
      if (problem == null) throw new ArgumentNullException("problem");
      // negated comparisons also reject NaN
      if (!(epsilon >= 0.0 && epsilon <= 1.0))
        throw new ArgumentOutOfRangeException("epsilon", "epsilon must be in the range [0, 1].");
      if (!(alpha > 0.0))
        throw new ArgumentOutOfRangeException("alpha", "alpha must be positive.");

      this.useCanonicalPhrases = useCanonicalPhrases;
      this.problem = problem;
      this.epsilon = epsilon;
      this.alpha = alpha;
      this.featureWeights = new Dictionary<string, double>();
      this.done = new HashSet<string>();
    }

    private double[] activeAfterStates; // values of the not-yet-done follow states; reused across calls to avoid allocations
    private int[] actionIndexMap; // maps an index into activeAfterStates back to the index in afterStates; reused across calls

    /// <summary>
    /// Selects one of the follow states of curState using epsilon-greedy selection over the
    /// approximated state values. Follow states that are already done are never selected.
    /// </summary>
    /// <param name="random">Source of randomness for exploration.</param>
    /// <param name="curState">The current state (marked as done when all of its follow states are done).</param>
    /// <param name="afterStates">Candidate follow states; enumerated exactly once.</param>
    /// <param name="selectedStateIdx">Index into afterStates of the selected state, or -1 on failure.</param>
    /// <returns>False if every follow state is done (or there are none), true otherwise.</returns>
    public bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx) {
      // materialize once: afterStates may be a lazy sequence
      var afterStateList = afterStates as IList<string> ?? afterStates.ToList();

      if (activeAfterStates == null || activeAfterStates.Length < afterStateList.Count) {
        activeAfterStates = new double[afterStateList.Count];
        actionIndexMap = new int[afterStateList.Count];
      }

      // collect active actions (not done yet) and remember their original indexes
      int numActive = 0;
      for (int originalIdx = 0; originalIdx < afterStateList.Count; originalIdx++) {
        var afterState = afterStateList[originalIdx];
        if (Done(afterState)) continue;
        activeAfterStates[numActive] = GetValue(afterState);
        actionIndexMap[numActive] = originalIdx;
        numActive++;
      }

      if (numActive == 0) {
        // fail because all follow states have already been visited => also disable the current state
        // (MarkAsDone only disables it fully when this is safe w.r.t. canonical representations)
        MarkAsDone(curState);
        selectedStateIdx = -1;
        return false;
      }

      if (random.NextDouble() < epsilon) {
        selectedStateIdx = actionIndexMap[random.Next(numActive)];
      } else {
        // greedy: first index with the maximal value; falls back to the first active state when no
        // value exceeds -infinity (all values -infinity or NaN, e.g. after weight divergence)
        var bestQ = double.NegativeInfinity;
        var bestIdx = 0;
        for (int i = 0; i < numActive; i++) {
          if (activeAfterStates[i] > bestQ) {
            bestQ = activeAfterStates[i];
            bestIdx = i;
          }
        }
        selectedStateIdx = actionIndexMap[bestIdx];
      }

      return true;
    }

    /// <summary>
    /// Moves the approximated value of every state in the trajectory towards the observed reward
    /// and marks terminal states as done.
    /// </summary>
    public void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      foreach (var state in stateTrajectory) {
        UpdateWeights(state, reward);

        // per the original author only the last state of a trajectory can be terminal
        if (problem.Grammar.IsTerminal(state)) {
          MarkAsDone(state);
        }
      }
    }

    // feature weights ordered by descending weight (convenience view; not used within this class)
    private IEnumerable<KeyValuePair<string, double>> Values {
      get { return featureWeights.OrderByDescending(p => p.Value); }
    }

    /// <summary>Forgets all learned weights and the done set.</summary>
    public void Reset() {
      featureWeights.Clear();
      done.Clear();
    }

    /// <summary>Visit counts are not tracked by this policy; always returns 1.</summary>
    public int GetTries(string state) {
      return 1;
    }

    /// <summary>Returns the approximated value sum_i w_i * x_i(state) of the given state.</summary>
    public double GetValue(string state) {
      return LinearValue(problem.GetFeatures(state));
    }

    // weighted sum of the given features using the current weights
    private double LinearValue(IEnumerable<Feature> features) {
      return features.Sum(feature => GetWeight(feature));
    }

    // contribution w_i * x_i of a single feature (0 for features without a learned weight)
    private double GetWeight(Feature feature) {
      double w;
      if (featureWeights.TryGetValue(feature.Id, out w)) return w * feature.Value;
      else return 0.0;
    }

    // gradient Monte-Carlo update for a linear approximator: w_i <- w_i + alpha * (reward - V(state)) * x_i
    private void UpdateWeights(string state, double reward) {
      // query the features only once; delta is computed with the weights before this update
      var features = problem.GetFeatures(state).ToList();
      double delta = reward - LinearValue(features);
      foreach (var feature in features) {
        double w;
        featureWeights.TryGetValue(feature.Id, out w); // w stays 0.0 for unseen features
        // dV/dw_i = x_i, so the update is scaled by the feature value
        featureWeights[feature.Id] = w + alpha * delta * feature.Value;
      }
    }

    // The canonical states and the done set must be distinguished:
    // sequences of different length can have the same canonical representation (and therefore share
    // the same done entry). However, if the canonical representation of a state is shorter than the
    // state itself, we must not mark the canonical state as done when all possible derivations from
    // the original state have been explored.
    // E.g. in the ant problem the canonical representation of ...lllA is ...rA.
    // Even though all possible derivations (of limited length) of lllA have been visited, we must not
    // mark the state rA as done.
    private void MarkAsDone(string state) {
      var s = CanonicalState(state);
      Debug.Assert(s.Length <= state.Length);
      // when the canonical string and the original string have the same length we disable the state
      // itself; terminals are always disabled
      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
        Debug.Assert(!done.Contains(s));
        done.Add(s);
      } else {
        // for non-terminals whose canonical string is shorter than the original string we can only
        // disable the canonical representation for states at the same level (same original length)
        var key = LevelKey(s, state.Length);
        Debug.Assert(!done.Contains(key));
        done.Add(key);
      }
    }

    // symmetric to MarkAsDone
    private bool Done(string state) {
      var s = CanonicalState(state);
      if (s.Length == state.Length || problem.Grammar.IsTerminal(state)) {
        return done.Contains(s);
      }
      // it is not necessary to visit states if the canonical representation has already been fully explored
      if (done.Contains(s)) return true;
      // ... or if it has been fully explored at any level up to (and including) the state's own length
      for (int len = 1; len <= state.Length; len++) {
        if (done.Contains(LevelKey(s, len))) return true;
      }
      return false;
    }

    // done-set key for a canonical phrase reached from a state of the given original length.
    // NOTE(review): this is the plain concatenation of the phrase and the decimal length; if phrases can
    // end in digits, different (phrase, length) pairs could collide — confirm the grammar has no digit symbols.
    private static string LevelKey(string canonicalState, int originalLength) {
      return canonicalState + originalLength;
    }

    private string CanonicalState(string state) {
      if (useCanonicalPhrases) {
        return problem.CanonicalRepresentation(state);
      } else
        return state;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.