Changeset 11976 for branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies
- Timestamp:
- 02/11/15 02:22:18 (10 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericFunctionApproximationGrammarPolicy.cs
r11974 r11976 14 14 public sealed class GenericFunctionApproximationGrammarPolicy : IGrammarPolicy { 15 15 private Dictionary<string, double> featureWeigths; // stores the necessary information for bandit policies for each state (=canonical phrase) 16 private Dictionary<string, int> featureTries; 16 17 private HashSet<string> done; 17 18 private readonly bool useCanonicalPhrases; 18 19 private readonly IProblem problem; 20 19 21 20 22 … … 23 25 this.problem = problem; 24 26 this.featureWeigths = new Dictionary<string, double>(); 27 this.featureTries = new Dictionary<string, int>(); 25 28 this.done = new HashSet<string>(); 26 29 } … … 57 60 originalIdx++; 58 61 } 59 60 const double beta = 20.0; 61 var w = from q in activeAfterStates 62 select Math.Exp(beta * q); 62 63 64 /* 65 const double beta = 1; 66 var w = from idx in Enumerable.Range(0, maxIdx) 67 let afterStateQ = activeAfterStates[idx] 68 select Math.Exp(beta * afterStateQ); 63 69 64 70 var bestAction = Enumerable.Range(0, maxIdx).SampleProportional(random, w); 65 71 selectedStateIdx = actionIndexMap[bestAction]; 66 72 Debug.Assert(selectedStateIdx >= 0); 67 68 /* 73 */ 74 75 69 76 if (random.NextDouble() < 0.2) { 70 77 selectedStateIdx = actionIndexMap[random.Next(maxIdx)]; … … 84 91 selectedStateIdx = actionIndexMap[bestIdxs[random.Next(bestIdxs.Count)]]; 85 92 } 86 */ 93 87 94 88 95 … … 114 121 115 122 public int GetTries(string state) { 116 return 1; 123 return 0; 124 } 125 126 public int GetFeatureTries(string featureId) { 127 int t; 128 if (featureTries.TryGetValue(featureId, out t)) { 129 return t; 130 } else return 0; 117 131 } 118 132 119 133 public double GetValue(string state) { 120 return problem.GetFeatures(state). Sum(feature => GetWeight(feature));134 return problem.GetFeatures(state).Average(feature => GetWeight(feature)); 121 135 } 122 136 … … 124 138 double w; 125 139 if (featureWeigths.TryGetValue(feature.Id, out w)) return w * feature.Value; 126 else return 0.0; // TODO: alternatives?140 else return 0.0; 127 141 } 128 142 private void UpdateWeights(string state, double reward) { 129 const double alpha = 0.01;130 143 double delta = reward - GetValue(state); 144 delta /= problem.GetFeatures(state).Count(); 145 const double alpha = 0.001; 131 146 foreach (var feature in problem.GetFeatures(state)) { 147 featureTries[feature.Id] = GetFeatureTries(feature.Id) + 1; 148 Debug.Assert(GetFeatureTries(feature.Id) >= 1); 149 //double alpha = 1.0 / GetFeatureTries(feature.Id); 150 //alpha = Math.Max(alpha, 0.01); 151 132 152 double w; 133 153 if (!featureWeigths.TryGetValue(feature.Id, out w)) { 134 featureWeigths[feature.Id] = alpha * delta ;154 featureWeigths[feature.Id] = alpha * delta * feature.Value; 135 155 } else { 136 featureWeigths[feature.Id] += alpha * delta ;156 featureWeigths[feature.Id] += alpha * delta * feature.Value; 137 157 } 138 158 }
Note: See TracChangeset
for help on using the changeset viewer.