Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GrammarPolicy.cs @ 13348

Last change on this file since 13348 was 12893, checked in by gkronber, 9 years ago

#2283: experiments on grammatical optimization algorithms (maxreward instead of avg reward, ...)

File size: 2.5 KB
Rev Line
[11770]1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using System.Threading.Tasks;
6using HeuristicLab.Common;
7using HeuristicLab.Problems.GrammaticalOptimization;
8
namespace HeuristicLab.Algorithms.Bandits.GrammarPolicies {
  // Base class for grammar policies (e.g. RandomPolicy and TDPolicy). For every state it stores
  // the number of tries, the average reward and the maximum reward observed so far.
  public abstract class GrammarPolicy : IGrammarPolicy {
    // state -> incremental mean of all rewards passed to UpdateReward for that state
    protected Dictionary<string, double> avgReward;
    // state -> number of reward updates for that state
    protected Dictionary<string, int> tries;
    // state -> largest reward passed to UpdateReward for that state
    protected Dictionary<string, double> maxReward;
    // if true, states are mapped through problem.CanonicalRepresentation before being used as keys
    protected readonly bool useCanonicalState;
    protected readonly IProblem problem;

    /// <summary>Creates a policy with empty per-state statistics.</summary>
    /// <param name="problem">Problem used to compute canonical state representations; only accessed when
    /// <paramref name="useCanonicalState"/> is true.</param>
    /// <param name="useCanonicalState">Whether states are canonicalized before statistics are read or stored.</param>
    protected GrammarPolicy(IProblem problem, bool useCanonicalState = false) {
      this.useCanonicalState = useCanonicalState;
      this.problem = problem;
      this.tries = new Dictionary<string, int>();
      this.avgReward = new Dictionary<string, double>();
      this.maxReward = new Dictionary<string, double>();
    }

    /// <summary>
    /// Implemented by derived policies. Presumably chooses one of <paramref name="afterStates"/> and reports its
    /// index through <paramref name="selectedStateIdx"/>; the meaning of the return value is defined by the
    /// implementation / IGrammarPolicy (NOTE(review): confirm).
    /// </summary>
    public abstract bool TrySelect(System.Random random, string curState, IEnumerable<string> afterStates, out int selectedStateIdx);

    /// <summary>
    /// Adds <paramref name="reward"/> to the statistics of every state in <paramref name="stateTrajectory"/>.
    /// It increments the try count, updates the running mean and updates the maximum.
    /// A state occurring several times in the trajectory is updated once per occurrence.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="stateTrajectory"/> is null.</exception>
    public virtual void UpdateReward(IEnumerable<string> stateTrajectory, double reward) {
      if (stateTrajectory == null) throw new ArgumentNullException("stateTrajectory");
      foreach (var state in stateTrajectory) {
        var s = CanonicalState(state);

        int n;
        tries.TryGetValue(s, out n); // n == 0 for a state seen for the first time
        n++;
        tries[s] = n;

        // incremental mean: avg_n = avg_{n-1} + (r - avg_{n-1}) / n  (avg_0 = 0)
        double avg;
        avgReward.TryGetValue(s, out avg);
        double alpha = 1.0 / n;
        avgReward[s] = avg + alpha * (reward - avg);

        // The first reward of a state is its maximum. Seeding with 0 instead would report a
        // maximum of 0 for states that only ever received negative rewards.
        double max;
        maxReward[s] = maxReward.TryGetValue(s, out max) ? Math.Max(max, reward) : reward;
      }
    }

    /// <summary>Discards all collected statistics (tries, average and maximum rewards).</summary>
    public virtual void Reset() {
      avgReward.Clear();
      maxReward.Clear();
      tries.Clear();
    }

    /// <summary>Average reward of <paramref name="state"/>, or 0.0 if it has no statistics yet.</summary>
    public double AvgReward(string state) {
      double value;
      return avgReward.TryGetValue(CanonicalState(state), out value) ? value : 0.0;
    }

    /// <summary>Maximum reward of <paramref name="state"/>, or 0.0 if it has no statistics yet.</summary>
    public double MaxReward(string state) {
      double value;
      return maxReward.TryGetValue(CanonicalState(state), out value) ? value : 0.0;
    }

    /// <summary>Number of reward updates for <paramref name="state"/>, or 0 if it has none.</summary>
    public virtual int GetTries(string state) {
      int value;
      return tries.TryGetValue(CanonicalState(state), out value) ? value : 0;
    }

    /// <summary>Value of <paramref name="state"/>; the average reward unless overridden by a derived class.</summary>
    public virtual double GetValue(string state) {
      return AvgReward(state);
    }

    // Maps a state to the key used in the statistics dictionaries.
    protected string CanonicalState(string state) {
      if (useCanonicalState) return problem.CanonicalRepresentation(state);
      else return state;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.