Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/10/15 16:12:08 (10 years ago)
Author:
gkronber
Message:

#2283: experiments with q-learning

File:
1 edited

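Legend:

Unmodified (old and new line numbers are printed fused together, e.g. "4848" = old line 48 / new line 48, "143174" = old line 143 / new line 174)
Added (only the new-revision line number is shown, shifted one column to the right)
Removed (only the old-revision line number is shown)
A row containing only whitespace marks unchanged lines that the viewer elided between hunks.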
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs

    r12295 r12298  
    4848      foreach (var afterState in afterStates) {
    4949        if (!Done(afterState)) {
    50           if (GetTries(afterState) == 0)
    51             activeAfterStates[idx] = double.PositiveInfinity;
    52           else
    53             activeAfterStates[idx] = GetValue(afterState);
     50          activeAfterStates[idx] = CalculateValue(afterState);
    5451          actionIndexMap[idx] = originalIdx;
    5552          idx++;
     
    5855      }
    5956
     57
    6058      //var eps = Math.Max(500.0 / (GetTries(curState) + 1), 0.01);
    6159      //var eps = 10.0 / Math.Sqrt(GetTries(curState) + 1);
    62       var eps = 0.2;
     60      var eps = 0.01;
    6361      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(idx), eps)];
    6462
     63      UpdateValue(curState, afterStates);
     64
    6565      return true;
    6666    }
     67
     68    private double CalculateValue(string chain) {
     69      var features = problem.GetFeatures(chain);
     70      var sum = 0.0;
     71      foreach (var f in features) {
     72        // if (GetTries(f.Id) == 0)
     73        //   sum = 0.0;
     74        // else
     75        sum += GetValue(f.Id) * f.Value;
     76      }
     77      return sum;
     78    }
     79
     80    private void UpdateValue(string curChain, IEnumerable<string> alternatives) {
     81      const double gamma = 1;
     82      const double alpha = 0.01;
     83      var maxNextQ = alternatives
     84          .Select(CalculateValue).Max();
     85
     86      var delta = gamma * maxNextQ - CalculateValue(curChain);
     87
     88      foreach (var f in problem.GetFeatures(curChain)) {
     89
     90        Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value;
     91      }
     92    }
     93
     94    private void UpdateLastValue(string terminalChain, double reward) {
     95      const double alpha = 0.01;
     96      var delta = reward - CalculateValue(terminalChain);
     97      foreach (var f in problem.GetFeatures(terminalChain)) {
     98        Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value;
     99      }
     100    }
     101
    67102
    68103    private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) {
     
    121156
    122157    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
    123       const double gamma = 0.95;
    124       const double minAlpha = 0.01;
    125       var reverseChains = chainTrajectory.Reverse();
    126       var terminalChain = reverseChains.First();
    127 
    128       var terminalState = CalcState(terminalChain);
    129       T[terminalState] = GetTries(terminalChain) + 1;
    130       double alpha = Math.Max(1.0 / GetTries(terminalChain), minAlpha);
    131       Q[terminalState] = (1 - alpha) * GetValue(terminalChain) + alpha * reward;
    132 
    133       foreach (var chain in reverseChains.Skip(1)) {
    134 
    135         var maxNextQ = followStates[chain]
    136           //.Where(s=>!Done(s))
    137           .Select(GetValue).Max();
    138         T[CalcState(chain)] = GetTries(chain) + 1;
    139 
    140         alpha = Math.Max(1.0 / GetTries(chain), minAlpha);
    141         Q[CalcState(chain)] = (1 - alpha) * GetValue(chain) + gamma * alpha * maxNextQ; // direct contribution is zero
    142       }
     158      // // only updates the last chain because we already update values after each step
     159      // var reverseChains = chainTrajectory.Reverse();
     160      // var terminalChain = reverseChains.First();
     161      //
     162      // UpdateValue(terminalChain, reward);
     163      //
     164      // foreach (var chain in reverseChains.Skip(1)) {
     165      //
     166      //   var maxNextQ = followStates[chain]
     167      //     //.Where(s=>!Done(s))
     168      //     .Select(GetValue).Max();
     169      //
     170      //   UpdateValue(chain, maxNextQ);
     171      // }
     172      var terminalChain = chainTrajectory.Last();
     173      UpdateLastValue(terminalChain, reward);
    143174      if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain);
    144175    }
     176
    145177
    146178    public void Reset() {
    147179      Q.Clear();
     180      T.Clear();
    148181      done.Clear();
    149182      followStates.Clear();
     
    160193
    161194
    162     public int GetTries(string state) {
    163       var s = CalcState(state);
    164       if (T.ContainsKey(s)) return T[s];
     195    public int GetTries(string fId) {
     196      if (T.ContainsKey(fId)) return T[fId];
    165197      else return 0;
    166198    }
    167199
    168     public double GetValue(string chain) {
    169       var s = CalcState(chain);
    170       if (Q.ContainsKey(s)) return Q[s];
     200    public double GetValue(string fId) {
     201      // var s = CalcState(chain);
     202      if (Q.ContainsKey(fId)) return Q[fId];
    171203      else return 0.0; // TODO: check alternatives
    172204    }
    173205
    174     private string CalcState(string chain) {
    175       var f = problem.GetFeatures(chain);
    176       // this policy only works for problems that return exactly one feature (the 'state')
    177       if (f.Skip(1).Any()) throw new ArgumentException();
    178       return f.First().Id;
    179     }
     206    // private string CalcState(string chain) {
     207    //   var f = problem.GetFeatures(chain);
     208    //   // this policy only works for problems that return exactly one feature (the 'state')
     209    //   if (f.Skip(1).Any()) throw new ArgumentException();
     210    //   return f.First().Id;
     211    // }
    180212
    181213    public void PrintStats() {
    182214      Console.WriteLine(Q.Values.Max());
    183       var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
    184       var topQs = Q.Keys.Where(key => key.Contains(",")).OrderByDescending(key => Q[key]).Take(50);
    185       foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
    186         var id1 = t.Item1;
    187         var id2 = t.Item2;
    188         Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
    189       }
    190 
     215      // var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
     216      // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Q[key]).Take(50);
     217      // foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
     218      //   var id1 = t.Item1;
     219      //   var id2 = t.Item2;
     220      //   Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
     221      // }
     222
     223      foreach (var option in new String[]
     224      {
     225        "a*b", "c*d", "a*b+c*d", "e*f", "a*b+c*d+e*f",
     226        "a*b+a*b", "c*d+c*d",
     227        "a*a", "a*b","a*c","a*d","a*e","a*f","a*g","a*h","a*i","a*j",
     228        "a*b","c*d","e*f","a*c","a*f","a*i","a*i*g","c*f","c*f*j",
     229        "b+c","a+c","b+d","a+d",
     230        "a*b+c*d+e*f", "a*b+c*d+e*f+a", "a*b+c*d+e*f+b", "a*b+c*d+e*f+c", "a*b+c*d+e*f+d","a*b+c*d+e*f+e",  "a*b+c*d+e*f+f", "a*b+c*d+e*f+g", "a*b+c*d+e*f+h", "a*b+c*d+e*f+i", "a*b+c*d+e*f+j",
     231        "a*b+c*d+e*f+a*g*i+c*j*f"
     232      }) {
     233        Console.WriteLine("{0,-10} {1:N5}", option, CalculateValue(option));
     234      }
     235
     236      // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Math.Abs(Q[key])).Take(10);
     237      // foreach (var t in topQs) {
     238      //   Console.WriteLine("{0,30} {1:N4}", t, Q[t]);
     239      // }
    191240    }
    192241  }
Note: See TracChangeset for help on using the changeset viewer.