Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/GrammarPolicies/GrammarPolicy.cs @ 11806

Last change on this file since 11806 was 11793, checked in by gkronber, 10 years ago

#2283 fixed compile errors and refactoring

File size: 2.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Common;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
9namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
10  // stores: tries, avg reward and max reward for each state (base class for RandomPolicy and TDPolicy
11  public abstract class GrammarPolicy : IGrammarPolicy {
12    protected Dictionary<string, double> avgReward;
13    protected Dictionary<string, int> tries;
14    protected Dictionary<string, double> maxReward;
15    protected readonly bool useCanonicalState;
16    protected readonly IProblem problem;
17
18    protected GrammarPolicy(IProblem problem, bool useCanonicalState = false) {
19      this.useCanonicalState = useCanonicalState;
20      this.problem = problem;
21      this.tries = new Dictionary<string, int>();
22      this.avgReward = new Dictionary<string, double>();
23      this.maxReward = new Dictionary<string, double>();
24    }
25
26    public abstract bool TrySelect(Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx);
27
28    public virtual void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
29      foreach (var state in stateTrajectory) {
30        var s = CanonicalState(state);
31
32        if (!tries.ContainsKey(s)) tries.Add(s, 0);
33        if (!avgReward.ContainsKey(s)) avgReward.Add(s, 0);
34        if (!maxReward.ContainsKey(s)) maxReward.Add(s, 0);
35
36        tries[s]++;
37        double alpha = 1.0 / tries[s];
38        avgReward[s] += alpha * (reward - avgReward[s]);
39        maxReward[s] = Math.Max(maxReward[s], reward);
40      }
41    }
42
43    public virtual void Reset() {
44      avgReward.Clear();
45      maxReward.Clear();
46      tries.Clear();
47    }
48
49    public double AvgReward(string state) {
50      var s = CanonicalState(state);
51      if (avgReward.ContainsKey(s)) return avgReward[s];
52      else return 0.0;
53    }
54
55    public double MaxReward(string state) {
56      var s = CanonicalState(state);
57      if (maxReward.ContainsKey(s)) return maxReward[s];
58      else return 0.0;
59    }
60
61    public virtual int GetTries(string state) {
62      var s = CanonicalState(state);
63      if (tries.ContainsKey(s)) return tries[s];
64      else return 0;
65    }
66
67    public virtual double GetValue(string state) {
68      return AvgReward(state);
69    }
70
71    protected string CanonicalState(string state) {
72      if (useCanonicalState) return problem.CanonicalRepresentation(state);
73      else return state;
74    }
75  }
76}
Note: See TracBrowser for help on using the repository browser.