Changeset 12298


Ignore:
Timestamp:
04/10/15 16:12:08 (4 years ago)
Author:
gkronber
Message:

#2283: experiments with q-learning

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization-gkr
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Algorithms.GrammaticalOptimization/SequentialDecisionPolicies/GenericPolicy.cs

    r12295 r12298  
    4848      foreach (var afterState in afterStates) {
    4949        if (!Done(afterState)) {
    50           if (GetTries(afterState) == 0)
    51             activeAfterStates[idx] = double.PositiveInfinity;
    52           else
    53             activeAfterStates[idx] = GetValue(afterState);
     50          activeAfterStates[idx] = CalculateValue(afterState);
    5451          actionIndexMap[idx] = originalIdx;
    5552          idx++;
     
    5855      }
    5956
     57
    6058      //var eps = Math.Max(500.0 / (GetTries(curState) + 1), 0.01);
    6159      //var eps = 10.0 / Math.Sqrt(GetTries(curState) + 1);
    62       var eps = 0.2;
     60      var eps = 0.01;
    6361      selectedStateIdx = actionIndexMap[SelectEpsGreedy(random, activeAfterStates.Take(idx), eps)];
    6462
     63      UpdateValue(curState, afterStates);
     64
    6565      return true;
    6666    }
     67
     68    private double CalculateValue(string chain) {
     69      var features = problem.GetFeatures(chain);
     70      var sum = 0.0;
     71      foreach (var f in features) {
     72        // if (GetTries(f.Id) == 0)
     73        //   sum = 0.0;
     74        // else
     75        sum += GetValue(f.Id) * f.Value;
     76      }
     77      return sum;
     78    }
     79
     80    private void UpdateValue(string curChain, IEnumerable<string> alternatives) {
     81      const double gamma = 1;
     82      const double alpha = 0.01;
     83      var maxNextQ = alternatives
     84          .Select(CalculateValue).Max();
     85
     86      var delta = gamma * maxNextQ - CalculateValue(curChain);
     87
     88      foreach (var f in problem.GetFeatures(curChain)) {
     89
     90        Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value;
     91      }
     92    }
     93
     94    private void UpdateLastValue(string terminalChain, double reward) {
     95      const double alpha = 0.01;
     96      var delta = reward - CalculateValue(terminalChain);
     97      foreach (var f in problem.GetFeatures(terminalChain)) {
     98        Q[f.Id] = GetValue(f.Id) + alpha * delta * f.Value;
     99      }
     100    }
     101
    67102
    68103    private int SelectBoltzmann(Random random, IEnumerable<double> qs, double beta = 10) {
     
    121156
    122157    public void UpdateReward(IEnumerable<string> chainTrajectory, double reward) {
    123       const double gamma = 0.95;
    124       const double minAlpha = 0.01;
    125       var reverseChains = chainTrajectory.Reverse();
    126       var terminalChain = reverseChains.First();
    127 
    128       var terminalState = CalcState(terminalChain);
    129       T[terminalState] = GetTries(terminalChain) + 1;
    130       double alpha = Math.Max(1.0 / GetTries(terminalChain), minAlpha);
    131       Q[terminalState] = (1 - alpha) * GetValue(terminalChain) + alpha * reward;
    132 
    133       foreach (var chain in reverseChains.Skip(1)) {
    134 
    135         var maxNextQ = followStates[chain]
    136           //.Where(s=>!Done(s))
    137           .Select(GetValue).Max();
    138         T[CalcState(chain)] = GetTries(chain) + 1;
    139 
    140         alpha = Math.Max(1.0 / GetTries(chain), minAlpha);
    141         Q[CalcState(chain)] = (1 - alpha) * GetValue(chain) + gamma * alpha * maxNextQ; // direct contribution is zero
    142       }
     158      // // only updates the last chain because we already update values after each step
     159      // var reverseChains = chainTrajectory.Reverse();
     160      // var terminalChain = reverseChains.First();
     161      //
     162      // UpdateValue(terminalChain, reward);
     163      //
     164      // foreach (var chain in reverseChains.Skip(1)) {
     165      //
     166      //   var maxNextQ = followStates[chain]
     167      //     //.Where(s=>!Done(s))
     168      //     .Select(GetValue).Max();
     169      //
     170      //   UpdateValue(chain, maxNextQ);
     171      // }
     172      var terminalChain = chainTrajectory.Last();
     173      UpdateLastValue(terminalChain, reward);
    143174      if (problem.Grammar.IsTerminal(terminalChain)) MarkAsDone(terminalChain);
    144175    }
     176
    145177
    146178    public void Reset() {
    147179      Q.Clear();
     180      T.Clear();
    148181      done.Clear();
    149182      followStates.Clear();
     
    160193
    161194
    162     public int GetTries(string state) {
    163       var s = CalcState(state);
    164       if (T.ContainsKey(s)) return T[s];
     195    public int GetTries(string fId) {
     196      if (T.ContainsKey(fId)) return T[fId];
    165197      else return 0;
    166198    }
    167199
    168     public double GetValue(string chain) {
    169       var s = CalcState(chain);
    170       if (Q.ContainsKey(s)) return Q[s];
     200    public double GetValue(string fId) {
     201      // var s = CalcState(chain);
     202      if (Q.ContainsKey(fId)) return Q[fId];
    171203      else return 0.0; // TODO: check alternatives
    172204    }
    173205
    174     private string CalcState(string chain) {
    175       var f = problem.GetFeatures(chain);
    176       // this policy only works for problems that return exactly one feature (the 'state')
    177       if (f.Skip(1).Any()) throw new ArgumentException();
    178       return f.First().Id;
    179     }
     206    // private string CalcState(string chain) {
     207    //   var f = problem.GetFeatures(chain);
     208    //   // this policy only works for problems that return exactly one feature (the 'state')
     209    //   if (f.Skip(1).Any()) throw new ArgumentException();
     210    //   return f.First().Id;
     211    // }
    180212
    181213    public void PrintStats() {
    182214      Console.WriteLine(Q.Values.Max());
    183       var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
    184       var topQs = Q.Keys.Where(key => key.Contains(",")).OrderByDescending(key => Q[key]).Take(50);
    185       foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
    186         var id1 = t.Item1;
    187         var id2 = t.Item2;
    188         Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
    189       }
    190 
     215      // var topTries = Q.Keys.OrderByDescending(key => T[key]).Take(50);
     216      // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Q[key]).Take(50);
     217      // foreach (var t in topTries.Zip(topQs, Tuple.Create)) {
     218      //   var id1 = t.Item1;
     219      //   var id2 = t.Item2;
     220      //   Console.WriteLine("{0,30} {1,6} {2:N4} {3,30} {4,6} {5:N4}", id1, T[id1], Q[id1], id2, T[id2], Q[id2]);
     221      // }
     222
     223      foreach (var option in new String[]
     224      {
     225        "a*b", "c*d", "a*b+c*d", "e*f", "a*b+c*d+e*f",
     226        "a*b+a*b", "c*d+c*d",
     227        "a*a", "a*b","a*c","a*d","a*e","a*f","a*g","a*h","a*i","a*j",
     228        "a*b","c*d","e*f","a*c","a*f","a*i","a*i*g","c*f","c*f*j",
     229        "b+c","a+c","b+d","a+d",
     230        "a*b+c*d+e*f", "a*b+c*d+e*f+a", "a*b+c*d+e*f+b", "a*b+c*d+e*f+c", "a*b+c*d+e*f+d","a*b+c*d+e*f+e",  "a*b+c*d+e*f+f", "a*b+c*d+e*f+g", "a*b+c*d+e*f+h", "a*b+c*d+e*f+i", "a*b+c*d+e*f+j",
     231        "a*b+c*d+e*f+a*g*i+c*j*f"
     232      }) {
     233        Console.WriteLine("{0,-10} {1:N5}", option, CalculateValue(option));
     234      }
     235
     236      // var topQs = Q.Keys/*.Where(key => key.Contains("E"))*/.OrderByDescending(key => Math.Abs(Q[key])).Take(10);
     237      // foreach (var t in topQs) {
     238      //   Console.WriteLine("{0,30} {1:N4}", t, Q[t]);
     239      // }
    191240    }
    192241  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Problems.GrammaticalOptimization/PartialExpressionInterpreter.cs

    r12291 r12298  
    22using System.Collections.Generic;
    33using System.Linq;
     4using System.Security.Policy;
    45using HeuristicLab.Common;
    56
     
    910    private string sentence;
    1011    private int syIdx;
     12    private HashSet<double> intermediateValues = new HashSet<double>();
    1113    private Stack<double> stack = new Stack<double>();
    12     private Stack<char> opStack = new Stack<char>();
    1314    // interprets sentences from L(G(Expr)):
    1415    // Expr -> Term { ('+' | '-' | '^' ) Term }
     
    2223    // The constant symbols '0' .. '9' are treated as ERC indices
    2324
    24     public Stack<double> Interpret(string sentence, double[] vars) {
     25    public IEnumerable<double> Interpret(string sentence, double[] vars) {
    2526      return Interpret(sentence, vars, emptyErc);
    2627    }
    2728
    28     public Stack<double> Interpret(string sentence, double[] vars, double[] erc) {
     29    public IEnumerable<double> Interpret(string sentence, double[] vars, double[] erc) {
    2930      InitLex(sentence);
    30       stack.Clear(); opStack.Clear();
     31      intermediateValues.Clear();
     32      stack.Clear();
    3133      Expr(vars, erc);
    32       return new Stack<double>(stack);
     34      return intermediateValues;
    3335    }
    3436
     
    5860        if (curSy == '+') {
    5961          NewSy();
    60           if (!Term(d, erc)) { stack.Push(-1.0); return false; }
     62          if (!Term(d, erc)) { return false; }
    6163          stack.Push(stack.Pop() + stack.Pop());
     64          intermediateValues.Add(stack.Peek());
    6265        } else if (curSy == '-') {
    6366          NewSy();
    64           if (!Term(d, erc)) { stack.Push(-2.0); return false; return false; }
     67          if (!Term(d, erc)) { return false; }
    6568          stack.Push(-stack.Pop() + stack.Pop());
     69          intermediateValues.Add(stack.Peek());
    6670        } else {
    6771          NewSy();
    68           if (!Term(d, erc)) { stack.Push(-3.0); return false; }
     72          if (!Term(d, erc)) { return false; }
    6973          var e = stack.Pop();
    7074          var r = stack.Pop();
    7175          stack.Push(Not(r) * e + r * Not(e)); // xor = (!x AND y) OR (x AND !y)
     76          intermediateValues.Add(stack.Peek());
    7277        }
    7378        curSy = CurSy();
     
    8287        if (curSy == '*') {
    8388          NewSy();
    84           if (!Fact(d, erc)) { stack.Push(-4.0); return false; }
     89          if (!Fact(d, erc)) { return false; }
    8590          stack.Push(stack.Pop() * stack.Pop());
     91          intermediateValues.Add(stack.Peek());
    8692        } else {
    8793          NewSy();
    88           if (!Fact(d, erc)) { stack.Push(-5.0); return false; }
     94          if (!Fact(d, erc)) { return false; }
    8995          var nom = stack.Pop();
    9096          var r = stack.Pop();
    9197          if (HeuristicLab.Common.Extensions.IsAlmost(nom, 0.0)) nom = 1.0;
    9298          stack.Push(r / nom);
     99          intermediateValues.Add(stack.Peek());
    93100        }
    94101        curSy = CurSy();
     
    100107      var curSy = CurSy();
    101108      if (curSy == '!') {
    102         NewSy();
    103         if (!Expr(d, erc)) { stack.Push(-7.0); return false; }
    104         stack.Push(Not(stack.Pop()));
     109        //NewSy();
     110        //if (!Expr(d, erc)) { stack.Push(-7.0); return false; }
     111        //stack.Push(Not(stack.Pop()));
    105112      } else if (curSy == '(') {
    106         NewSy();
    107         if (!Expr(d, erc)) { stack.Push(-8.0); return false; }
    108         if (CurSy() != ')') throw new ArgumentException();
    109         NewSy();
     113        //NewSy();
     114        //if (!Expr(d, erc)) { stack.Push(-8.0); return false; }
     115        //if (CurSy() != ')') throw new ArgumentException();
     116        //NewSy();
    110117      } else if (curSy >= 'a' && curSy <= 'z') {
    111118        int o = (byte)curSy - (byte)'a';
     
    113120        if (o < 0 || o >= d.Length) throw new ArgumentException();
    114121        stack.Push(d[o]);
     122        intermediateValues.Add(stack.Peek());
    115123        NewSy();
    116124      } else if (curSy == '/') {
    117125        // /-symbol is used in the expressionextender to represent inverse (1/x).
    118126        // this is necessary because we also use symbols 0..9 as indices for ERCs
    119         NewSy();
    120         if (!Fact(d, erc)) { stack.Push(-9.0); return false; }
    121         stack.Push(1.0 / stack.Pop());
     127        //NewSy();
     128        //if (!Fact(d, erc)) { stack.Push(-9.0); return false; }
     129        //stack.Push(1.0 / stack.Pop());
    122130      } else if (curSy >= '0' && curSy <= '9') {
    123131        int o = (byte)curSy - (byte)'0';
     
    125133        if (o < 0 || o >= 10) throw new ArgumentException();
    126134        stack.Push(erc[o]);
     135        intermediateValues.Add(stack.Peek());
    127136        NewSy();
    128137      } else {
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Problems.GrammaticalOptimization/Problems/SymbolicRegressionPoly10Problem.cs

    r12295 r12298  
    156156    // splits the phrase into terms and creates (sparse) term-occurrance features
    157157    public IEnumerable<Feature> GetFeatures(string phrase) {
    158       // var canonicalTerms = new HashSet<string>();
    159       // foreach (string t in phrase.Split('+')) {
    160       //   canonicalTerms.Add(CanonicalTerm(t));
     158      //if (phrase.EndsWith("E")) phrase = phrase.TrimEnd('*', '+', 'E');
     159      //yield return new Feature("$$$", 1.0); // const
     160      //var canonicalTerms = new HashSet<string>();
     161      //foreach (string t in phrase.Split('+')) {
     162      //  canonicalTerms.Add(CanonicalTerm(t));
     163      //}
     164      //return canonicalTerms.Select(entry => new Feature(entry, 1.0));
     165      //.Concat(new Feature[] { new Feature(CanonicalRepresentation(phrase), 1.0) });
     166
     167
     168      if (phrase.EndsWith("E")) phrase = phrase.TrimEnd('*', '+', 'E');
     169      //var len = 5;
     170      //var start = Math.Max(0, phrase.Length - len);
     171      //var end = Math.Min(phrase.Length, start + len);
     172      //string f = phrase.Substring(start, end - start);
     173      //yield return new Feature(f, 1.0);
     174      //
     175
     176      var terms = phrase.Split('+');
     177      foreach (var t in terms.Distinct()) yield return new Feature(t, 1.0);
     178
     179      for (int i = 0; i < terms.Length; i++) {
     180        for (int j = i + 1; j < terms.Length; j++) {
     181          yield return new Feature(terms[i] + " " + terms[j], 1.0);
     182        }
     183      }
     184
     185      // var substrings = new HashSet<string>();
     186      // for (int start = 0; start <= phrase.Length - 2; start += 2) {
     187      //   var s = phrase.Substring(start, 3);
     188      //   substrings.Add(s);
    161189      // }
    162       // return canonicalTerms.Select(entry => new Feature(entry, 1.0))
    163       //   .Concat(new Feature[] { new Feature(CanonicalRepresentation(phrase), 1.0) });
    164 
    165       return new Feature[] { new Feature(phrase, 1.0), };
    166 
    167       // var partialInterpreter = new PartialExpressionInterpreter();
    168       // var vars = new double[] { 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, };
    169       // var s = partialInterpreter.Interpret(phrase, vars);
    170       // //if (s.Any())
    171       // //  return new Feature[] { new Feature(s.Pop().ToString(), 1.0), };
    172       // //else
    173       // //  return new Feature[] { new Feature("$", 1.0), };
    174       // return new Feature[] { new Feature(string.Join(",", s), 1.0) };
     190      //
     191      // var list = new List<string>(substrings);
     192      //
     193      // for (int i = 0; i < list.Count; i++) {
     194      //   yield return new Feature(list[i], 1.0);
     195      //   //for (int j = i+1; j < list.Count; j++) {
     196      //   //  yield return new Feature(list[i] + " " + list[j], 1.0);
     197      //   //}
     198      // }
     199
     200      //
     201      // for (int len = 1; len <= phrase.Length; len += 2) {
     202      //   var start = Math.Max(0, phrase.Length - len);
     203      //   var end = Math.Min(phrase.Length, start + len);
     204      //   string f = phrase.Substring(start, end - start);
     205      //   yield return new Feature(f, 1.0);
     206      //
     207      // }
     208
     209      //var partialInterpreter = new PartialExpressionInterpreter();
     210      //var vars = new double[] { 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, };
     211      //var s = partialInterpreter.Interpret(phrase, vars);
     212      ////if (s.Any())
     213      ////  return new Feature[] { new Feature(s.Pop().ToString(), 1.0), };
     214      ////else
     215      ////  return new Feature[] { new Feature("$", 1.0), };
     216      //return s.Select(f => new Feature(f.ToString(), 1.0));
    175217    }
    176218
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/HeuristicLab.Problems.GrammaticalOptimization/SentenceSetStatistics.cs

    r11865 r12298  
    5050    public override string ToString() {
    5151      return
    52         string.Format("Sentences: {0,10} avg.-quality {1,7:F5} best {2,7:F5} {3,2} {4,10} {5,30} first {6,7:F5} {7,20} last {8,7:F5} {9,20}",
     52        string.Format("Sentences: {0,10} avg.-quality {1,7:F5} best {2,7:F5} {3,2} {4,10} {5,30} last {6,7:F5} {7,20}",
    5353      NumberOfSentences, AverageQuality,
    5454      BestSentenceQuality, DoubleExtensions.IsAlmost(BestSentenceQuality, bestKnownQuality) ? 1.0 : 0.0,
    5555      BestSentenceIndex, TrimToSize(BestSentence, 30),
    56       FirstSentenceQuality, TrimToSize(FirstSentence, 20),
    5756      LastSentenceQuality, TrimToSize(LastSentence, 20)
     57      //LastSentenceQuality, TrimToSize(LastSentence, 20)
    5858     );
    5959    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization-gkr/Main/Program.cs

    r12295 r12298  
    11using System;
     2using System.Collections.Generic;
    23using System.Diagnostics;
    34using System.Globalization;
     5using System.Linq;
     6using System.Text.RegularExpressions;
    47using HeuristicLab.Algorithms.Bandits.BanditPolicies;
    58using HeuristicLab.Algorithms.Bandits.GrammarPolicies;
     
    3841
    3942        var globalStatistics = new SentenceSetStatistics();
     43        ResetAlleleStatistics();
    4044        var random = new Random();
    4145
     
    6670          iterations++;
    6771          globalStatistics.AddSentence(sentence, quality);
    68 
     72          UpdateAlleleStatistics(sentence);
    6973          // comment this if you don't want to see solver statistics
    7074          if (iterations % 100 == 0) {
    71             if (iterations % 1000 == 0) Console.Clear();
     75            if (iterations % 1000 == 0) {
     76              Console.Clear();
     77            }
    7278            Console.SetCursorPosition(0, 0);
    73              alg.PrintStats();
    74             //policy.PrintStats();
     79            Console.WriteLine(iterations);
     80            WriteAlleleStatistics();
     81            Console.WriteLine(globalStatistics.BestSentenceQuality);
     82            Console.WriteLine(globalStatistics.BestSentence);
     83            Console.WriteLine(globalStatistics);
     84            //alg.PrintStats();
     85            policy.PrintStats();
     86            //ResetAlleleStatistics();
    7587          }
    76 
     88         
    7789          // uncomment this if you want to collect statistics of the generated sentences
    7890          //if (iterations % 100 == 0) {
     
    94106      }
    95107    }
     108
     109    private static void UpdateAlleleStatistics(string sentence) {
     110      for (int i = 0; i < sentence.Length; i++) {
     111        var allele = sentence.Substring(i, 1);
     112        if (alleleStatistics.ContainsKey(allele)) alleleStatistics[allele]++;
     113      }
     114      for (int i = 0; i < sentence.Length - 2; i+=2) {
     115        var allele = sentence.Substring(i, 3);
     116        if (alleleStatistics.ContainsKey(allele)) alleleStatistics[allele]++;
     117      }
     118      for (int i = 0; i < sentence.Length - 4; i+=2) {
     119        var allele = sentence.Substring(i, 5);
     120        if (alleleStatistics.ContainsKey(allele)) alleleStatistics[allele]++;
     121      }
     122    }
     123
     124
     125    private static Dictionary<string, int> alleleStatistics;
     126
     127    private static void ResetAlleleStatistics() {
     128      alleleStatistics = new Dictionary<string, int>()
     129      {
     130        {"a", 0},
     131        {"b", 0},
     132        {"c", 0},
     133        {"d", 0},
     134        {"e", 0},
     135        {"f", 0},
     136        {"g", 0},
     137        {"h", 0},
     138        {"i", 0},
     139        {"j", 0},
     140        {"a*b", 0},
     141        {"b*a", 0},
     142        {"c*d", 0},
     143        {"d*c", 0},
     144        {"e*f", 0},
     145        {"f*e", 0},
     146        {"a*g*i", 0},
     147        {"a*i*g", 0},
     148        {"g*a*i", 0},
     149        {"g*i*a", 0},
     150        {"i*g*a", 0},
     151        {"i*a*g", 0},
     152        {"j*c*f", 0},
     153        {"j*f*c", 0},
     154        {"c*j*f", 0},
     155        {"c*f*j", 0},
     156        {"f*c*j", 0},
     157        {"f*j*c", 0}
     158      };
     159    }
     160
     161
     162    private static void WriteAlleleStatistics() {
     163      double count = alleleStatistics.Sum(e => e.Value);
     164      foreach (var entry in alleleStatistics.OrderByDescending(e=>e.Value)) {
     165        Console.WriteLine("{0,-10} {1,-10}", entry.Key, entry.Value);
     166      }
     167    }
    96168  }
    97169}
Note: See TracChangeset for help on using the changeset viewer.