Changeset 11730


Ignore:
Timestamp:
01/02/15 16:08:21 (6 years ago)
Author:
gkronber
Message:

#2283: several major extensions for grammatical optimization

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization
Files:
14 added
31 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Bandits/BernoulliBandit.cs

    r11711 r11730  
    66
    77namespace HeuristicLab.Algorithms.Bandits {
    8   public class BernoulliBandit {
     8  public class BernoulliBandit : IBandit {
    99    public int NumArms { get; private set; }
    1010    public double OptimalExpectedReward { get; private set; } // reward of the best arm, for calculating regret
     11    public int OptimalExpectedRewardArm { get; private set; }
     12    // the arm with highest expected reward also has the highest probability of return a reward of 1.0
     13    public int OptimalMaximalRewardArm { get { return OptimalExpectedRewardArm; } }
     14
    1115    private readonly Random random;
    1216    private readonly double[] expReward;
     
    1923      for (int i = 0; i < nArms; i++) {
    2024        expReward[i] = random.NextDouble();
    21         if (expReward[i] > OptimalExpectedReward) OptimalExpectedReward = expReward[i];
     25        if (expReward[i] > OptimalExpectedReward) {
     26          OptimalExpectedReward = expReward[i];
     27          OptimalExpectedRewardArm = i;
     28        }
    2229      }
    2330    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Bandits/TruncatedNormalBandit.cs

    r11711 r11730  
    44using System.Text;
    55using System.Threading.Tasks;
     6using HeuristicLab.Common;
    67
    78namespace HeuristicLab.Algorithms.Bandits {
    8   public class TruncatedNormalBandit {
     9  public class TruncatedNormalBandit : IBandit {
    910    public int NumArms { get; private set; }
    1011    public double OptimalExpectedReward { get; private set; } // reward of the best arm, for calculating regret
     12    public int OptimalExpectedRewardArm { get; private set; }
     13    // the arm with highest expected reward also has the highest probability of return a reward of 1.0
     14    public int OptimalMaximalRewardArm { get { return OptimalExpectedRewardArm; } }
     15
    1116    private readonly Random random;
    1217    private readonly double[] expReward;
     
    1823      OptimalExpectedReward = double.NegativeInfinity;
    1924      for (int i = 0; i < nArms; i++) {
    20         expReward[i] = random.NextDouble();
    21         if (expReward[i] > OptimalExpectedReward) OptimalExpectedReward = expReward[i];
     25        expReward[i] = random.NextDouble() * 0.7;
     26        if (expReward[i] > OptimalExpectedReward) {
     27          OptimalExpectedReward = expReward[i];
     28          OptimalExpectedRewardArm = i;
     29        }
    2230      }
    2331    }
     
    2836      double x = 0;
    2937      do {
    30         var z = Transform(random.NextDouble(), random.NextDouble());
     38        var z = Rand.RandNormal(random);
    3139        x = z * 0.1 + expReward[arm];
    3240      }
     
    3442      return x;
    3543    }
    36 
    37     // box muller transform
    38     private double Transform(double u1, double u2) {
    39       return Math.Sqrt(-2 * Math.Log(u1)) * Math.Cos(2 * Math.PI * u2);
    40     }
    4144  }
    4245}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11727 r11730  
    4040  </ItemGroup>
    4141  <ItemGroup>
     42    <Compile Include="BanditHelper.cs" />
    4243    <Compile Include="Bandits\BernoulliBandit.cs" />
     44    <Compile Include="Bandits\GaussianMixtureBandit.cs" />
     45    <Compile Include="Bandits\IBandit.cs" />
    4346    <Compile Include="Bandits\TruncatedNormalBandit.cs" />
     47    <Compile Include="Models\BernoulliModel.cs" />
     48    <Compile Include="Models\GaussianModel.cs" />
     49    <Compile Include="Models\GaussianMixtureModel.cs" />
     50    <Compile Include="Models\IModel.cs" />
    4451    <Compile Include="Policies\BanditPolicy.cs" />
    4552    <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs" />
     53    <Compile Include="Policies\BoltzmannExplorationPolicy.cs" />
     54    <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs" />
     55    <Compile Include="Policies\GenericThompsonSamplingPolicy.cs" />
     56    <Compile Include="Policies\ThresholdAscentPolicy.cs" />
     57    <Compile Include="Policies\UCTPolicy.cs" />
    4658    <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs" />
    4759    <Compile Include="Policies\Exp3Policy.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IPolicy.cs

    r11727 r11730  
    2020    // reset causes the policy to be reinitialized to it's initial state (as after constructor-call)
    2121    void Reset();
     22
     23    void PrintStats();
    2224  }
    2325}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs

    r11727 r11730  
    2828      Actions = Enumerable.Range(0, numInitialActions).ToArray();
    2929    }
     30
     31    public abstract void PrintStats();
    3032  }
    3133}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs

    r11727 r11730  
    5555      Array.Clear(failure, 0, failure.Length);
    5656    }
     57
     58    public override void PrintStats() {
     59      for (int i = 0; i < success.Length; i++) {
     60        if (success[i] >= 0) {
     61          Console.Write("{0,5:F2}", success[i] / failure[i]);
     62        } else {
     63          Console.Write("{0,5}", "");
     64        }
     65      }
     66      Console.WriteLine();
     67    }
     68
     69    public override string ToString() {
     70      return "BernoulliThompsonSamplingPolicy";
     71    }
    5772  }
    5873}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs

    r11727 r11730  
    2727      if (random.NextDouble() > eps) {
    2828        // select best
    29         var maxReward = double.NegativeInfinity;
     29        var bestQ = double.NegativeInfinity;
    3030        int bestAction = -1;
    3131        foreach (var a in Actions) {
    3232          if (tries[a] == 0) return a;
    33           var avgReward = sumReward[a] / tries[a];
    34           if (maxReward < avgReward) {
    35             maxReward = avgReward;
     33          var q = sumReward[a] / tries[a];
     34          if (bestQ < q) {
     35            bestQ = q;
    3636            bestAction = a;
    3737          }
     
    6565      Array.Clear(sumReward, 0, sumReward.Length);
    6666    }
     67    public override void PrintStats() {
     68      for (int i = 0; i < sumReward.Length; i++) {
     69        if (tries[i] >= 0) {
     70          Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);
     71        } else {
     72          Console.Write("-", "");
     73        }
     74      }
     75      Console.WriteLine();
     76    }
     77    public override string ToString() {
     78      return string.Format("EpsGreedyPolicy({0:F2})", eps);
     79    }
    6780  }
    6881}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/Exp3Policy.cs

    r11727 r11730  
    5252      foreach (var a in Actions) w[a] = 1.0;
    5353    }
     54    public override void PrintStats() {
     55      for (int i = 0; i < w.Length; i++) {
     56        if (w[i] > 0) {
     57          Console.Write("{0,5:F2}", w[i]);
     58        } else {
     59          Console.Write("{0,5}", "");
     60        }
     61      }
     62      Console.WriteLine();
     63    }
     64    public override string ToString() {
     65      return "Exp3Policy";
     66    }
    5467  }
    5568}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs

    r11727 r11730  
    55
    66namespace HeuristicLab.Algorithms.Bandits {
     7 
    78  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    89    private readonly Random random;
    9     private readonly double[] sumRewards;
    10     private readonly double[] sumSqrRewards;
     10    private readonly double[] sampleMean;
     11    private readonly double[] sampleM2;
    1112    private readonly int[] tries;
    12     public GaussianThompsonSamplingPolicy(Random random, int numActions)
     13    private bool compatibility;
     14
     15    // assumes a Gaussian reward distribution with different means but the same variances for each action
     16    // the prior for the mean is also Gaussian with the following parameters
     17    private readonly double rewardVariance = 0.1; // we assume a known variance
     18
     19    private readonly double priorMean = 0.5;
     20    private readonly double priorVariance = 1;
     21
     22
     23    public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
    1324      : base(numActions) {
    1425      this.random = random;
    15       this.sumRewards = new double[numActions];
    16       this.sumSqrRewards = new double[numActions];
     26      this.sampleMean = new double[numActions];
     27      this.sampleM2 = new double[numActions];
    1728      this.tries = new int[numActions];
     29      this.compatibility = compatibility;
    1830    }
    1931
     
    2436      int bestAction = -1;
    2537      foreach (var a in Actions) {
    26         if (tries[a] == 0) return a;
    27         var mu = sumRewards[a] / tries[a];
    28         var stdDev = Math.Sqrt(sumSqrRewards[a] / tries[a] - Math.Pow(mu, 2));
    29         var theta = Rand.RandNormal(random) * stdDev + mu;
     38        if(tries[a] == -1) continue; // skip disabled actions
     39        double theta;
     40        if (compatibility) {
     41          if (tries[a] < 2) return a;
     42          var mu = sampleMean[a];
     43          var variance = sampleM2[a] / tries[a];
     44          var stdDev = Math.Sqrt(variance);
     45          theta = Rand.RandNormal(random) * stdDev + mu;
     46        } else {
     47          // calculate posterior mean and variance (for mean reward)
     48
     49          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
     50          var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
     51          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);
     52
     53          // sample a mean from the posterior
     54          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;
     55
     56          // theta already represents the expected reward value => nothing else to do
     57        }
    3058        if (theta > maxTheta) {
    3159          maxTheta = theta;
     
    3361        }
    3462      }
     63      Debug.Assert(Actions.Contains(bestAction));
    3564      return bestAction;
    3665    }
     
    3867    public override void UpdateReward(int action, double reward) {
    3968      Debug.Assert(Actions.Contains(action));
    40 
    41       sumRewards[action] += reward;
    42       sumSqrRewards[action] += reward * reward;
    4369      tries[action]++;
     70      var delta = reward - sampleMean[action];
     71      sampleMean[action] += delta / tries[action];
     72      sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]);
    4473    }
    4574
    4675    public override void DisableAction(int action) {
    4776      base.DisableAction(action);
    48       sumRewards[action] = 0;
    49       sumSqrRewards[action] = 0;
     77      sampleMean[action] = 0;
     78      sampleM2[action] = 0;
    5079      tries[action] = -1;
    5180    }
     
    5382    public override void Reset() {
    5483      base.Reset();
    55       Array.Clear(sumRewards, 0, sumRewards.Length);
    56       Array.Clear(sumSqrRewards, 0, sumSqrRewards.Length);
     84      Array.Clear(sampleMean, 0, sampleMean.Length);
     85      Array.Clear(sampleM2, 0, sampleM2.Length);
    5786      Array.Clear(tries, 0, tries.Length);
     87    }
     88
     89    public override void PrintStats() {
     90      for (int i = 0; i < sampleMean.Length; i++) {
     91        if (tries[i] >= 0) {
     92          Console.Write(" {0,5:F2} {1}", sampleMean[i] / tries[i], tries[i]);
     93        } else {
     94          Console.Write("{0,5}", "");
     95        }
     96      }
     97      Console.WriteLine();
     98    }
     99    public override string ToString() {
     100      return "GaussianThompsonSamplingPolicy";
    58101    }
    59102  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs

    r11727 r11730  
    2323      // do nothing
    2424    }
    25 
     25    public override void PrintStats() {
     26      Console.WriteLine("Random");
     27    }
     28    public override string ToString() {
     29      return "RandomPolicy";
     30    }
    2631  }
    2732}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs

    r11727 r11730  
    5050      Array.Clear(sumReward, 0, sumReward.Length);
    5151    }
     52    public override void PrintStats() {
     53      for (int i = 0; i < sumReward.Length; i++) {
     54        if (tries[i] >= 0) {
     55          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     56        } else {
     57          Console.Write("{0,5}", "");
     58        }
     59      }
     60      Console.WriteLine();
     61    }
     62    public override string ToString() {
     63      return "UCB1Policy";
     64    }
    5265  }
    5366}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs

    r11727 r11730  
    6262      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    6363    }
     64    public override void PrintStats() {
     65      for (int i = 0; i < sumReward.Length; i++) {
     66        if (tries[i] >= 0) {
     67          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     68        } else {
     69          Console.Write("{0,5}", "");
     70        }
     71      }
     72      Console.WriteLine();
     73    }
     74    public override string ToString() {
     75      return "UCB1TunedPolicy";
     76    }
    6477  }
    6578}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs

    r11727 r11730  
    2424      double bestQ = double.NegativeInfinity;
    2525      foreach (var a in Actions) {
    26         if (totalTries == 0 || tries[a] == 0 || tries[a] < Math.Ceiling(8 * Math.Log(totalTries))) return a;
     26        if (totalTries <= 1 || tries[a] <= 1 || tries[a] <= Math.Ceiling(8 * Math.Log(totalTries))) return a;
    2727        var avgReward = sumReward[a] / tries[a];
     28        var estVariance = 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]);
     29        if (estVariance < 0) estVariance = 0; // numerical problems
    2830        var q = avgReward
    29           + Math.Sqrt(16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]));
     31          + Math.Sqrt(estVariance);
    3032        if (q > bestQ) {
    3133          bestQ = q;
     
    3335        }
    3436      }
     37      Debug.Assert(Actions.Contains(bestAction));
    3538      return bestAction;
    3639    }
     
    5861      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    5962    }
     63    public override void PrintStats() {
     64      for (int i = 0; i < sumReward.Length; i++) {
     65        if (tries[i] >= 0) {
     66          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     67        } else {
     68          Console.Write("{0,5}", "");
     69        }
     70      }
     71      Console.WriteLine();
     72    }
     73    public override string ToString() {
     74      return "UCBNormalPolicy";
     75    }
    6076  }
    6177}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/AlternativesContextSampler.cs

    r11727 r11730  
    1717    private readonly Random random;
    1818    private readonly int contextLen;
     19    private readonly Func<Random, int, IPolicy> policyFactory;
    1920
    20     public AlternativesContextSampler(IProblem problem, int maxLen) {
     21    public AlternativesContextSampler(IProblem problem, Random random, int maxLen, int contextLen, Func<Random, int, IPolicy> policyFactory) {
    2122      this.maxLen = maxLen;
    2223      this.problem = problem;
    23       this.random = new Random(31415);
    24       this.contextLen = 25;
     24      this.random = random;
     25      this.contextLen = contextLen;
     26      this.policyFactory = policyFactory;
    2527    }
    2628
     
    2931      InitPolicies(problem.Grammar);
    3032      for (int i = 0; i < maxIterations; i++) {
    31         var sentence = SampleSentence(problem.Grammar);
    32         var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen); 
     33        var sentence = SampleSentence(problem.Grammar).ToString();
     34        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    3335        DistributeReward(quality);
    3436
     
    4547    private Dictionary<string, IPolicy> ntPolicy;
    4648    private List<Tuple<string, int>> updateChain;
     49
    4750    private void InitPolicies(IGrammar grammar) {
    4851      this.ntPolicy = new Dictionary<string, IPolicy>();
     
    5053    }
    5154
    52     private string SampleSentence(IGrammar grammar) {
     55    private Sequence SampleSentence(IGrammar grammar) {
    5356      updateChain.Clear();
    54       return CompleteSentence(grammar, grammar.SentenceSymbol.ToString());
     57      return CompleteSentence(grammar, new Sequence(grammar.SentenceSymbol));
    5558    }
    5659
    57     public string CompleteSentence(IGrammar g, string phrase) {
     60    public Sequence CompleteSentence(IGrammar g, Sequence phrase) {
    5861      if (phrase.Length > maxLen) throw new ArgumentException();
    5962      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    60       bool done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     63      bool done = phrase.IsTerminal; // terminal phrase means we are done
    6164      while (!done) {
    62         int ntIdx; char nt;
    63         Grammar.FindFirstNonTerminal(g, phrase, out nt, out ntIdx);
     65        char nt = phrase.FirstNonTerminal;
    6466
    6567        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
     
    6769
    6870        var alts = g.GetAlternatives(nt);
    69         string selectedAlt;
     71        Sequence selectedAlt;
    7072        // if the choice is restricted then one of the allowed alternatives is selected randomly
    7173        if (alts.Any(alt => g.MinPhraseLength(alt) > maxLenOfReplacement)) {
     
    7678        } else {
    7779          // all alts are allowed => select using bandit policy
     80          var ntIdx = phrase.FirstNonTerminalIndex;
    7881          var startIdx = Math.Max(0, ntIdx - contextLen);
    7982          var endIdx = Math.Min(startIdx + contextLen, ntIdx);
    80           var lft = phrase.Substring(startIdx, endIdx - startIdx + 1);
     83          var lft = phrase.Subsequence(startIdx, endIdx - startIdx + 1).ToString();
    8184          lft = problem.Hash(lft);
    8285          if (!ntPolicy.ContainsKey(lft)) {
    83             ntPolicy.Add(lft, new UCB1TunedPolicy(g.GetAlternatives(nt).Count()));
     86            ntPolicy.Add(lft, policyFactory(random, g.GetAlternatives(nt).Count()));
    8487          }
    8588          var selectedAltIdx = ntPolicy[lft].SelectAction();
     
    8992
    9093        // replace nt with alt
    91         phrase = phrase.Remove(ntIdx, 1);
    92         phrase = phrase.Insert(ntIdx, selectedAlt);
     94        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
    9395
    94         done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     96        done = phrase.IsTerminal; // terminal phrase means we are done
    9597      }
    9698      return phrase;
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/AlternativesSampler.cs

    r11727 r11730  
    2727      InitPolicies(problem.Grammar);
    2828      for (int i = 0; i < maxIterations; i++) {
    29         var sentence = SampleSentence(problem.Grammar);
    30         var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen); 
     29        var sentence = SampleSentence(problem.Grammar).ToString();
     30        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    3131        DistributeReward(quality);
    3232
     
    5252    }
    5353
    54     private string SampleSentence(IGrammar grammar) {
     54    private Sequence SampleSentence(IGrammar grammar) {
    5555      updateChain.Clear();
    56       return CompleteSentence(grammar, grammar.SentenceSymbol.ToString());
     56      return CompleteSentence(grammar, new Sequence(grammar.SentenceSymbol));
    5757    }
    5858
    59     public string CompleteSentence(IGrammar g, string phrase) {
     59    public Sequence CompleteSentence(IGrammar g, Sequence phrase) {
    6060      if (phrase.Length > maxLen) throw new ArgumentException();
    6161      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    62       bool done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     62      bool done = phrase.IsTerminal; // terminal phrase means we are done
    6363      while (!done) {
    64         int ntIdx; char nt;
    65         Grammar.FindFirstNonTerminal(g, phrase, out nt, out ntIdx);
     64        char nt = phrase.FirstNonTerminal;
    6665
    6766        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
     
    6968
    7069        var alts = g.GetAlternatives(nt);
    71         string selectedAlt;
     70        Sequence selectedAlt;
    7271        // if the choice is restricted then one of the allowed alternatives is selected randomly
    7372        if (alts.Any(alt => g.MinPhraseLength(alt) > maxLenOfReplacement)) {
     
    8483
    8584        // replace nt with alt
    86         phrase = phrase.Remove(ntIdx, 1);
    87         phrase = phrase.Insert(ntIdx, selectedAlt);
     85        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
    8886
    89         done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     87        done = phrase.IsTerminal; // terminal phrase means we are done
    9088      }
    9189      return phrase;
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/ExhaustiveBreadthFirstSearch.cs

    r11727 r11730  
    1111
    1212    private readonly int maxLen;
    13     private readonly Queue<string> bfsQueue = new Queue<string>();
     13    private readonly Queue<Sequence> bfsQueue = new Queue<Sequence>();
    1414    private readonly IProblem problem;
    1515
     
    2121    public void Run(int maxIterations) {
    2222      double bestQuality = double.MinValue;
    23       bfsQueue.Enqueue(problem.Grammar.SentenceSymbol.ToString());
     23      bfsQueue.Enqueue(new Sequence(problem.Grammar.SentenceSymbol));
    2424      var sentences = GenerateLanguage(problem.Grammar);
    2525      var sentenceEnumerator = sentences.GetEnumerator();
    2626      for (int i = 0; sentenceEnumerator.MoveNext() && i < maxIterations; i++) {
    27         var sentence = sentenceEnumerator.Current;
    28         var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen); 
     27        var sentence = sentenceEnumerator.Current.ToString();
     28        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    2929        RaiseSolutionEvaluated(sentence, quality);
    3030
     
    3737
    3838    // create sentences lazily
    39     private IEnumerable<string> GenerateLanguage(IGrammar grammar) {
     39    private IEnumerable<Sequence> GenerateLanguage(IGrammar grammar) {
    4040      while (bfsQueue.Any()) {
    4141        var phrase = bfsQueue.Dequeue();
    4242
    43         char nt;
     43        char nt = phrase.FirstNonTerminal;
    4444        int ntIdx;
    45         Grammar.FindFirstNonTerminal(grammar, phrase, out nt, out ntIdx);
     45
    4646        var alts = grammar.GetAlternatives(nt);
    4747        foreach (var alt in alts) {
    48           var newPhrase = phrase.Remove(ntIdx, 1).Insert(ntIdx, alt);
    49           if (newPhrase.All(grammar.IsTerminal) && newPhrase.Length <= maxLen) {
     48          var newPhrase = new Sequence(phrase);
     49          newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
     50          if (newPhrase.IsTerminal && newPhrase.Length <= maxLen) {
    5051            yield return newPhrase;
    5152          } else if (grammar.MinPhraseLength(newPhrase) <= maxLen) {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/ExhaustiveDepthFirstSearch.cs

    r11727 r11730  
    1111
    1212    private readonly int maxLen;
    13     private readonly Stack<string> stack = new Stack<string>();
     13    private readonly Stack<Sequence> stack = new Stack<Sequence>();
     14    private readonly IProblem problem;
    1415
    15     public ExhaustiveDepthFirstSearch(int maxLen) {
     16    public ExhaustiveDepthFirstSearch(IProblem problem, int maxLen) {
    1617      this.maxLen = maxLen;
     18      this.problem = problem;
    1719    }
    1820
    19     public void Run(IProblem problem, int maxIterations) {
     21    public void Run(int maxIterations) {
    2022      double bestQuality = double.MinValue;
    21       stack.Push(problem.Grammar.SentenceSymbol.ToString());
     23      stack.Push(new Sequence(problem.Grammar.SentenceSymbol));
    2224      var sentences = GenerateLanguage(problem.Grammar);
    2325      var sentenceEnumerator = sentences.GetEnumerator();
    2426      for (int i = 0; sentenceEnumerator.MoveNext() && i < maxIterations; i++) {
    25         var sentence = sentenceEnumerator.Current;
     27        var sentence = sentenceEnumerator.Current.ToString();
    2628        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    2729        RaiseSolutionEvaluated(sentence, quality);
     
    3537
    3638    // create sentences lazily
    37     private IEnumerable<string> GenerateLanguage(IGrammar grammar) {
     39    private IEnumerable<Sequence> GenerateLanguage(IGrammar grammar) {
    3840      while (stack.Any()) {
    3941        var phrase = stack.Pop();
    4042
    41         char nt;
    42         int ntIdx;
    43         Grammar.FindFirstNonTerminal(grammar, phrase, out nt, out ntIdx);
     43        char nt = phrase.FirstNonTerminal;
    4444        var alts = grammar.GetAlternatives(nt);
    4545        foreach (var alt in alts) {
    46           var newPhrase = phrase.Remove(ntIdx, 1).Insert(ntIdx, alt);
    47           if (newPhrase.All(grammar.IsTerminal) && newPhrase.Length <= maxLen) {
     46          var newPhrase = new Sequence(phrase);
     47          newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
     48
     49          if (newPhrase.IsTerminal && newPhrase.Length <= maxLen) {
    4850            yield return newPhrase;
    4951          } else if (grammar.MinPhraseLength(newPhrase) <= maxLen) {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11727 r11730  
    1010  public class MctsSampler {
    1111    private class TreeNode {
     12      public string ident;
    1213      public int randomTries;
     14      public int policyTries;
    1315      public IPolicy policy;
    1416      public TreeNode[] children;
    1517      public bool done = false;
    1618
     19      public TreeNode(string id) {
     20        this.ident = id;
     21      }
     22
    1723      public override string ToString() {
    18         return string.Format("Node(random-tries: {0}, done: {1}, policy: {2})", randomTries, done, policy);
     24        return string.Format("Node({0} tries: {1}, done: {2}, policy: {3})", ident, randomTries + policyTries, done, policy);
    1925      }
    2026    }
     27
    2128
    2229    public event Action<string, double> FoundNewBestSolution;
     
    2734    private readonly Random random;
    2835    private readonly int randomTries;
    29     private readonly Func<int, IPolicy> policyFactory;
     36    private readonly Func<Random, int, IPolicy> policyFactory;
    3037
    3138    private List<Tuple<TreeNode, int>> updateChain;
    3239    private TreeNode rootNode;
    3340
     41    public int treeDepth;
     42    public int treeSize;
     43
    3444    public MctsSampler(IProblem problem, int maxLen, Random random) :
    35       this(problem, maxLen, random, 10, (numActions) => new EpsGreedyPolicy(random, numActions, 0.1)) {
     45      this(problem, maxLen, random, 10, (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1)) {
    3646
    3747    }
    3848
    39     public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<int, IPolicy> policyFactory) {
     49    public MctsSampler(IProblem problem, int maxLen, Random random, int randomTries, Func<Random, int, IPolicy> policyFactory) {
    4050      this.maxLen = maxLen;
    4151      this.problem = problem;
     
    4757    public void Run(int maxIterations) {
    4858      double bestQuality = double.MinValue;
    49       InitPolicies();
     59      InitPolicies(problem.Grammar);
    5060      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
    51         var sentence = SampleSentence(problem.Grammar);
     61        var sentence = SampleSentence(problem.Grammar).ToString();
    5262        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    5363        Debug.Assert(quality >= 0 && quality <= 1.0);
     
    6171        }
    6272      }
     73
     74      // clean up
     75      InitPolicies(problem.Grammar); GC.Collect();
    6376    }
    6477
    65     private void InitPolicies() {
    66       this.updateChain = new List<Tuple<TreeNode, int>>();
    67       rootNode = new TreeNode();
     78    public void PrintStats() {
     79      var n = rootNode;
     80      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, rootNode.policyTries + rootNode.randomTries);
     81      while (n.policy != null) {
     82        Console.WriteLine();
     83        Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
     84        Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.randomTries + ch.policyTries))));
     85        //n.policy.PrintStats();
     86        n = n.children.OrderByDescending(c => c.policyTries).First();
     87      }
     88      Console.ReadLine();
    6889    }
    6990
    70     private string SampleSentence(IGrammar grammar) {
    71       updateChain.Clear();
    72       return CompleteSentence(grammar, grammar.SentenceSymbol.ToString());
     91    private void InitPolicies(IGrammar grammar) {
     92      this.updateChain = new List<Tuple<TreeNode, int>>();
     93
     94      rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
     95      treeDepth = 0;
     96      treeSize = 0;
    7397    }
    7498
    75     public string CompleteSentence(IGrammar g, string phrase) {
     99    private Sequence SampleSentence(IGrammar grammar) {
     100      updateChain.Clear();
     101      var startPhrase = new Sequence(grammar.SentenceSymbol);
     102      return CompleteSentence(grammar, startPhrase);
     103    }
     104
     105    private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
    76106      if (phrase.Length > maxLen) throw new ArgumentException();
    77107      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    78108      TreeNode n = rootNode;
    79       bool done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     109      bool done = phrase.IsTerminal;
    80110      int selectedAltIdx = -1;
     111      var curDepth = 0;
    81112      while (!done) {
    82         int ntIdx; char nt;
    83         Grammar.FindFirstNonTerminal(g, phrase, out nt, out ntIdx);
     113        char nt = phrase.FirstNonTerminal;
    84114
    85115        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
     
    90120        if (n.randomTries < randomTries) {
    91121          n.randomTries++;
     122
     123          treeDepth = Math.Max(treeDepth, curDepth);
     124
    92125          return g.CompleteSentenceRandomly(random, phrase, maxLen);
    93126        } else if (n.randomTries == randomTries && n.policy == null) {
    94           n.policy = policyFactory(alts.Count());
    95           n.children = alts.Select(_ => new TreeNode()).ToArray(); // create a new node for each alternative
     127          n.policy = policyFactory(random, alts.Count());
     128          //n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
     129          n.children = alts.Select(alt => new TreeNode(string.Empty)).ToArray(); // create a new node for each alternative
     130
     131          treeSize += n.children.Length;
    96132        }
    97 
     133        n.policyTries++;
    98134        // => select using bandit policy
    99135        selectedAltIdx = n.policy.SelectAction();
    100         string selectedAlt = alts.ElementAt(selectedAltIdx);
     136        Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
     137
    101138        // replace nt with alt
    102         phrase = phrase.Remove(ntIdx, 1);
    103         phrase = phrase.Insert(ntIdx, selectedAlt);
     139        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
    104140
    105141        updateChain.Add(Tuple.Create(n, selectedAltIdx));
    106142
    107         done = phrase.All(g.IsTerminal); // terminal phrase means we are done
     143        curDepth++;
     144
     145        done = phrase.IsTerminal;
    108146        if (!done) {
    109147          // prepare for next iteration
     
    116154      n.children[selectedAltIdx].done = true;
    117155
     156      treeDepth = Math.Max(treeDepth, curDepth);
    118157      return phrase;
    119158    }
     
    127166        var policy = node.policy;
    128167        var action = e.Item2;
     168        //policy.UpdateReward(action, reward / updateChain.Count);
    129169        policy.UpdateReward(action, reward);
    130170
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/RandomSearch.cs

    r11690 r11730  
    1313    private readonly int maxLen;
    1414    private readonly Random random;
     15    private readonly IProblem problem;
    1516
    16     public RandomSearch(int maxLen) {
     17    public RandomSearch(IProblem problem, Random random, int maxLen) {
    1718      this.maxLen = maxLen;
    18       this.random = new Random(31415);
     19      this.random = random;
     20      this.problem = problem;
    1921    }
    2022
    21     public void Run(IProblem problem, int maxIterations) {
     23    public void Run(int maxIterations) {
    2224      double bestQuality = double.MinValue;
    2325      for (int i = 0; i < maxIterations; i++) {
    24         var sentence = CreateSentence(problem.Grammar);
    25         var quality = problem.Evaluate(sentence);
     26        var sentence = CreateSentence(problem.Grammar).ToString();
     27        var quality = problem.Evaluate(sentence) / problem.GetBestKnownQuality(maxLen);
    2628        RaiseSolutionEvaluated(sentence, quality);
    2729
     
    3335    }
    3436
    35     private string CreateSentence(IGrammar grammar) {
    36       var sentence = grammar.SentenceSymbol.ToString();
     37    private Sequence CreateSentence(IGrammar grammar) {
     38      var sentence = new Sequence(grammar.SentenceSymbol);
    3739      return grammar.CompleteSentenceRandomly(random, sentence, maxLen);
    3840    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Common/Extensions.cs

    r11727 r11730  
    1515      return xsArr[rand.Next(xsArr.Length)];
    1616    }
     17
     18    public static IEnumerable<T> SampleProportional<T>(this IEnumerable<T> source, Random random, IEnumerable<double> weights) {
     19      var sourceArray = source.ToArray();
     20      var valueArray = weights.ToArray();
     21      double total = valueArray.Sum();
     22
     23      while (true) {
     24        int index = 0;
     25        double ball = valueArray[index], sum = random.NextDouble() * total;
     26        while (ball < sum)
     27          ball += valueArray[++index];
     28        yield return sourceArray[index];
     29      }
     30    }
    1731  }
    1832}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/HeuristicLab.Problems.GrammaticalOptimization.Test.csproj

    r11708 r11730  
    5757  </Choose>
    5858  <ItemGroup>
     59    <Compile Include="TestSequence.cs" />
    5960    <Compile Include="TestBanditPolicies.cs" />
    6061    <Compile Include="TestInstances.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestBanditPolicies.cs

    r11727 r11730  
    44using System.Globalization;
    55using HeuristicLab.Algorithms.Bandits;
     6using HeuristicLab.Algorithms.Bandits.Models;
    67using Microsoft.VisualStudio.TestTools.UnitTesting;
    78
     
    910  [TestClass]
    1011  public class TestBanditPolicies {
     12
     13
    1114    [TestMethod]
    1215    public void ComparePoliciesForBernoulliBandit() {
    13       System.Threading.Thread.CurrentThread.CurrentCulture = CultureInfo.InvariantCulture;
     16      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
     17
    1418      var globalRand = new Random(31415);
    1519      var seedForPolicy = globalRand.Next();
    16       var nArms = 10;
     20      var nArms = 20;
    1721      //Console.WriteLine("Exp3 (gamma=0.01)");
    1822      //TestPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
    1923      //Console.WriteLine("Exp3 (gamma=0.05)");
    2024      //estPolicyBernoulli(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 1));
    21       Console.WriteLine("Thompson (Bernoulli)");
    22       TestPolicyBernoulli(globalRand, nArms, new BernoulliThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
     25      Console.WriteLine("Thompson (Bernoulli)"); TestPolicyBernoulli(globalRand, nArms, new BernoulliThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
     26      Console.WriteLine("Generic Thompson (Bernoulli)"); TestPolicyBernoulli(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new BernoulliModel(nArms)));
    2327      Console.WriteLine("Random");
    2428      TestPolicyBernoulli(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
     
    3943      //Console.WriteLine("Eps(0.5)");
    4044      //TestPolicyBernoulli(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
    41     }
     45      Console.WriteLine("UCT(0.1)"); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 0.1));
     46      Console.WriteLine("UCT(0.5)"); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 0.5));
     47      Console.WriteLine("UCT(1)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 1));
     48      Console.WriteLine("UCT(2)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 2));
     49      Console.WriteLine("UCT(5)  "); TestPolicyBernoulli(globalRand, nArms, new UCTPolicy(nArms, 5));
     50      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
     51      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
     52      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
     53      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
     54      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyBernoulli(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
     55      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
     56      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
     57      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyBernoulli(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));
     58
     59      // not applicable to bernoulli rewards
     60      //Console.WriteLine("ThresholdAscent(10, 0.01)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
     61      //Console.WriteLine("ThresholdAscent(10, 0.05)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
     62      //Console.WriteLine("ThresholdAscent(10, 0.1)   "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
     63      //Console.WriteLine("ThresholdAscent(100, 0.01) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
     64      //Console.WriteLine("ThresholdAscent(100, 0.05) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
     65      //Console.WriteLine("ThresholdAscent(100, 0.1)  "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
     66      //Console.WriteLine("ThresholdAscent(1000, 0.01)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
     67      //Console.WriteLine("ThresholdAscent(1000, 0.05)"); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
     68      //Console.WriteLine("ThresholdAscent(1000, 0.1) "); TestPolicyBernoulli(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
     69    }
     70
    4271    [TestMethod]
    4372    public void ComparePoliciesForNormalBandit() {
    44       System.Threading.Thread.CurrentThread.CurrentCulture = CultureInfo.InvariantCulture;
     73      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
     74
    4575      var globalRand = new Random(31415);
    4676      var seedForPolicy = globalRand.Next();
    47       var nArms = 10;
    48       Console.WriteLine("Thompson (Gaussian)");
    49       TestPolicyNormal(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
    50       Console.WriteLine("Random");
    51       TestPolicyNormal(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
    52       Console.WriteLine("UCB1");
    53       TestPolicyNormal(globalRand, nArms, new UCB1Policy(nArms));
    54       Console.WriteLine("UCB1Tuned");
    55       TestPolicyNormal(globalRand, nArms, new UCB1TunedPolicy(nArms));
    56       Console.WriteLine("UCB1Normal");
    57       TestPolicyNormal(globalRand, nArms, new UCBNormalPolicy(nArms));
     77      var nArms = 20;
     78      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyNormal(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms, true));
     79      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyNormal(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
     80      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyNormal(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new GaussianModel(nArms, 0.5, 1)));
     81      /*
     82      Console.WriteLine("Random"); TestPolicyNormal(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
     83      Console.WriteLine("UCB1"); TestPolicyNormal(globalRand, nArms, new UCB1Policy(nArms));
     84      Console.WriteLine("UCB1Tuned"); TestPolicyNormal(globalRand, nArms, new UCB1TunedPolicy(nArms));
     85      Console.WriteLine("UCB1Normal"); TestPolicyNormal(globalRand, nArms, new UCBNormalPolicy(nArms));
    5886      //Console.WriteLine("Exp3 (gamma=0.01)");
    5987      //TestPolicyNormal(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.01));
    6088      //Console.WriteLine("Exp3 (gamma=0.05)");
    6189      //TestPolicyNormal(globalRand, nArms, new Exp3Policy(new Random(seedForPolicy), nArms, 0.05));
    62       Console.WriteLine("Eps(0.01)");
    63       TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
    64       Console.WriteLine("Eps(0.05)");
    65       TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
     90      Console.WriteLine("Eps(0.01)"); TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
     91      Console.WriteLine("Eps(0.05)"); TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
    6692      //Console.WriteLine("Eps(0.1)");
    6793      //TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.1));
     
    7096      //Console.WriteLine("Eps(0.5)");
    7197      //TestPolicyNormal(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.5));
    72     }
     98      Console.WriteLine("UCT(0.1)"); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 0.1));
     99      Console.WriteLine("UCT(0.5)"); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 0.5));
     100      Console.WriteLine("UCT(1)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 1));
     101      Console.WriteLine("UCT(2)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 2));
     102      Console.WriteLine("UCT(5)  "); TestPolicyNormal(globalRand, nArms, new UCTPolicy(nArms, 5));
     103      Console.WriteLine("BoltzmannExploration(0.1)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.1));
     104      Console.WriteLine("BoltzmannExploration(0.5)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 0.5));
     105      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
     106      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
     107      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyNormal(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
     108      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.01)"); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.01));
     109      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.05)"); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.05));
     110      Console.WriteLine("ChernoffIntervalEstimationPolicy(0.1) "); TestPolicyNormal(globalRand, nArms, new ChernoffIntervalEstimationPolicy(nArms, 0.1));
     111      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
     112      Console.WriteLine("ThresholdAscent(10,0.05)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.05));
     113      Console.WriteLine("ThresholdAscent(10,0.1)   "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.1));
     114      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
     115      Console.WriteLine("ThresholdAscent(100,0.05) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.05));
     116      Console.WriteLine("ThresholdAscent(100,0.1)  "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.1));
     117      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
     118      Console.WriteLine("ThresholdAscent(1000,0.05)"); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.05));
     119      Console.WriteLine("ThresholdAscent(1000,0.1) "); TestPolicyNormal(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.1));
     120       */
     121    }
     122
     123    [TestMethod]
     124    public void ComparePoliciesForGaussianMixtureBandit() {
     125      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
     126
     127      var globalRand = new Random(31415);
     128      var seedForPolicy = globalRand.Next();
     129      var nArms = 20;
     130      Console.WriteLine("Thompson (Gaussian orig)"); TestPolicyGaussianMixture(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms, true));
     131      Console.WriteLine("Thompson (Gaussian new)"); TestPolicyGaussianMixture(globalRand, nArms, new GaussianThompsonSamplingPolicy(new Random(seedForPolicy), nArms));
     132      Console.WriteLine("Generic Thompson (Gaussian)"); TestPolicyGaussianMixture(globalRand, nArms, new GenericThompsonSamplingPolicy(new Random(seedForPolicy), nArms, new GaussianModel(nArms, 0.5, 1)));
     133
     134      /*
     135      Console.WriteLine("Random"); TestPolicyGaussianMixture(globalRand, nArms, new RandomPolicy(new Random(seedForPolicy), nArms));
     136      Console.WriteLine("UCB1"); TestPolicyGaussianMixture(globalRand, nArms, new UCB1Policy(nArms));
     137      Console.WriteLine("UCB1Tuned "); TestPolicyGaussianMixture(globalRand, nArms, new UCB1TunedPolicy(nArms));
     138      Console.WriteLine("UCB1Normal"); TestPolicyGaussianMixture(globalRand, nArms, new UCBNormalPolicy(nArms));
     139      Console.WriteLine("Eps(0.01) "); TestPolicyGaussianMixture(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.01));
     140      Console.WriteLine("Eps(0.05) "); TestPolicyGaussianMixture(globalRand, nArms, new EpsGreedyPolicy(new Random(seedForPolicy), nArms, 0.05));
     141      Console.WriteLine("UCT(1)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 1));
     142      Console.WriteLine("UCT(2)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 2));
     143      Console.WriteLine("UCT(5)  "); TestPolicyGaussianMixture(globalRand, nArms, new UCTPolicy(nArms, 5));
     144      Console.WriteLine("BoltzmannExploration(1)  "); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 1));
     145      Console.WriteLine("BoltzmannExploration(10) "); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 10));
     146      Console.WriteLine("BoltzmannExploration(100)"); TestPolicyGaussianMixture(globalRand, nArms, new BoltzmannExplorationPolicy(new Random(seedForPolicy), nArms, 100));
     147
     148      Console.WriteLine("ThresholdAscent(10,0.01)  "); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10, 0.01));
     149      Console.WriteLine("ThresholdAscent(100,0.01) "); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 100, 0.01));
     150      Console.WriteLine("ThresholdAscent(1000,0.01)"); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 1000, 0.01));
     151      Console.WriteLine("ThresholdAscent(10000,0.01)"); TestPolicyGaussianMixture(globalRand, nArms, new ThresholdAscentPolicy(nArms, 10000, 0.01));
     152       */
     153    }
     154
    73155
    74156    private void TestPolicyBernoulli(Random globalRand, int nArms, IPolicy policy) {
    75       var maxIt = 1E6;
    76       var reps = 10; // 10 independent runs
    77       var avgRegretForIteration = new Dictionary<int, double>();
     157      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new BernoulliBandit(banditRandom, nActions));
     158    }
     159    private void TestPolicyNormal(Random globalRand, int nArms, IPolicy policy) {
     160      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new TruncatedNormalBandit(banditRandom, nActions));
     161    }
     162    private void TestPolicyGaussianMixture(Random globalRand, int nArms, IPolicy policy) {
     163      TestPolicy(globalRand, nArms, policy, (banditRandom, nActions) => new GaussianMixtureBandit(banditRandom, nActions));
     164    }
     165
     166
     167    private void TestPolicy(Random globalRand, int nArms, IPolicy policy, Func<Random, int, IBandit> banditFactory) {
     168      var maxIt = 1E5;
     169      var reps = 30; // independent runs
     170      var regretForIteration = new Dictionary<int, List<double>>();
     171      var numberOfPullsOfSuboptimalArmsForExp = new Dictionary<int, double>();
     172      var numberOfPullsOfSuboptimalArmsForMax = new Dictionary<int, double>();
    78173      // calculate statistics
    79174      for (int r = 0; r < reps; r++) {
    80175        var nextLogStep = 1;
    81         var b = new BernoulliBandit(new Random(globalRand.Next()), 10);
     176        var b = banditFactory(new Random(globalRand.Next()), nArms);
    82177        policy.Reset();
    83178        var totalRegret = 0.0;
    84 
     179        var totalPullsOfSuboptimalArmsExp = 0.0;
     180        var totalPullsOfSuboptimalArmsMax = 0.0;
    85181        for (int i = 0; i <= maxIt; i++) {
    86182          var selectedAction = policy.SelectAction();
    87183          var reward = b.Pull(selectedAction);
     184          policy.UpdateReward(selectedAction, reward);
     185
     186          // collect stats
     187          if (selectedAction != b.OptimalExpectedRewardArm) totalPullsOfSuboptimalArmsExp++;
     188          if (selectedAction != b.OptimalMaximalRewardArm) totalPullsOfSuboptimalArmsMax++;
    88189          totalRegret += b.OptimalExpectedReward - reward;
    89           policy.UpdateReward(selectedAction, reward);
     190
    90191          if (i == nextLogStep) {
    91             nextLogStep *= 10;
    92             if (!avgRegretForIteration.ContainsKey(i)) {
    93               avgRegretForIteration.Add(i, 0.0);
     192            nextLogStep *= 2;
     193            if (!regretForIteration.ContainsKey(i)) {
     194              regretForIteration.Add(i, new List<double>());
    94195            }
    95             avgRegretForIteration[i] += totalRegret / i;
     196            regretForIteration[i].Add(totalRegret / i);
     197
     198            if (!numberOfPullsOfSuboptimalArmsForExp.ContainsKey(i)) {
     199              numberOfPullsOfSuboptimalArmsForExp.Add(i, 0.0);
     200            }
     201            numberOfPullsOfSuboptimalArmsForExp[i] += totalPullsOfSuboptimalArmsExp;
     202
     203            if (!numberOfPullsOfSuboptimalArmsForMax.ContainsKey(i)) {
     204              numberOfPullsOfSuboptimalArmsForMax.Add(i, 0.0);
     205            }
     206            numberOfPullsOfSuboptimalArmsForMax[i] += totalPullsOfSuboptimalArmsMax;
    96207          }
    97208        }
    98209      }
    99210      // print
    100       foreach (var p in avgRegretForIteration.Keys.OrderBy(k => k)) {
    101         Console.WriteLine("{0} {1}", p, avgRegretForIteration[p] / reps); // print avg. of avg. regret
    102       }
    103     }
    104     private void TestPolicyNormal(Random globalRand, int nArms, IPolicy policy) {
    105       var maxIt = 1E6;
    106       var reps = 10; // 10 independent runs
    107       var avgRegretForIteration = new Dictionary<int, double>();
    108       // calculate statistics
    109       for (int r = 0; r < reps; r++) {
    110         var nextLogStep = 1;
    111         var b = new TruncatedNormalBandit(new Random(globalRand.Next()), 10);
    112         policy.Reset();
    113         var totalRegret = 0.0;
    114 
    115         for (int i = 0; i <= maxIt; i++) {
    116           var selectedAction = policy.SelectAction();
    117           var reward = b.Pull(selectedAction);
    118           totalRegret += b.OptimalExpectedReward - reward;
    119           policy.UpdateReward(selectedAction, reward);
    120           if (i == nextLogStep) {
    121             nextLogStep *= 10;
    122             if (!avgRegretForIteration.ContainsKey(i)) {
    123               avgRegretForIteration.Add(i, 0.0);
    124             }
    125             avgRegretForIteration[i] += totalRegret / i;
    126           }
    127         }
    128       }
    129       // print
    130       foreach (var p in avgRegretForIteration.Keys.OrderBy(k => k)) {
    131         Console.WriteLine("{0} {1}", p, avgRegretForIteration[p] / reps); // print avg. of avg. regret
     211      foreach (var p in regretForIteration.Keys.OrderBy(k => k)) {
     212        Console.WriteLine("iter {0,8} regret avg {1,7:F5} min {2,7:F5} max {3,7:F5} suboptimal pulls (exp) {4,7:F2} suboptimal pulls (max) {5,7:F2}",
     213          p,
     214          regretForIteration[p].Average(),
     215          regretForIteration[p].Min(),
     216          regretForIteration[p].Max(),
     217          numberOfPullsOfSuboptimalArmsForExp[p] / (double)reps,
     218          numberOfPullsOfSuboptimalArmsForMax[p] / (double)reps
     219          );
    132220      }
    133221    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestInstances.cs

    r11659 r11730  
    2727        Assert.AreEqual(2, g.GetAlternatives('B').Count());
    2828
    29         Assert.IsTrue(g.GetAlternatives('S').Contains("aA"));
    30         Assert.IsTrue(g.GetAlternatives('S').Contains("bB"));
    31         Assert.IsTrue(g.GetAlternatives('A').Contains("aA"));
    32         Assert.IsTrue(g.GetAlternatives('A').Contains("a"));
    33         Assert.IsTrue(g.GetAlternatives('B').Contains("Bb"));
    34         Assert.IsTrue(g.GetAlternatives('B').Contains("b"));
    35 
    36         Assert.AreEqual(2, g.MinPhraseLength("S"));
    37         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("S"));
    38         Assert.AreEqual(1, g.MinPhraseLength("A"));
    39         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("A"));
    40         Assert.AreEqual(1, g.MinPhraseLength("B"));
    41         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("B"));
     29        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "aA"));
     30        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "bB"));
     31        Assert.IsTrue(g.GetAlternatives('A').Any(s => s.ToString() == "aA"));
     32        Assert.IsTrue(g.GetAlternatives('A').Any(s => s.ToString() == "a"));
     33        Assert.IsTrue(g.GetAlternatives('B').Any(s => s.ToString() == "Bb"));
     34        Assert.IsTrue(g.GetAlternatives('B').Any(s => s.ToString() == "b"));
     35
     36        Assert.AreEqual(2, g.MinPhraseLength(new Sequence("S")));
     37        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("S")));
     38        Assert.AreEqual(1, g.MinPhraseLength(new Sequence("A")));
     39        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("A")));
     40        Assert.AreEqual(1, g.MinPhraseLength(new Sequence("B")));
     41        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("B")));
    4242      }
    4343
     
    5656        Assert.AreEqual(1, g.GetAlternatives('S').Count());
    5757
    58         Assert.IsTrue(g.GetAlternatives('S').Contains("sS"));
    59 
    60         Assert.AreEqual(short.MaxValue, g.MinPhraseLength("S"));
    61         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("S"));
     58        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "sS"));
     59
     60        Assert.AreEqual(short.MaxValue, g.MinPhraseLength(new Sequence("S")));
     61        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("S")));
    6262      }
    6363
     
    7575        Assert.AreEqual(2, g.GetAlternatives('S').Count());
    7676
    77         Assert.IsTrue(g.GetAlternatives('S').Contains("sss"));
    78         Assert.IsTrue(g.GetAlternatives('S').Contains("sS"));
    79 
    80         Assert.AreEqual(3, g.MinPhraseLength("S"));
    81         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("S"));
    82         Assert.AreEqual(4, g.MinPhraseLength("sS"));
    83         Assert.AreEqual(7, g.MinPhraseLength("sSS"));
    84         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("sSS"));
    85         Assert.AreEqual(3, g.MaxPhraseLength("sss"));
    86         Assert.AreEqual(3, g.MinPhraseLength("sss"));
     77        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "sss"));
     78        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "sS"));
     79
     80        Assert.AreEqual(3, g.MinPhraseLength(new Sequence("S")));
     81        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("S")));
     82        Assert.AreEqual(4, g.MinPhraseLength(new Sequence("sS")));
     83        Assert.AreEqual(7, g.MinPhraseLength(new Sequence("sSS")));
     84        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("sSS")));
     85        Assert.AreEqual(3, g.MaxPhraseLength(new Sequence("sss")));
     86        Assert.AreEqual(3, g.MinPhraseLength(new Sequence("sss")));
    8787      }
    8888
     
    101101        Assert.AreEqual(2, g.GetAlternatives('S').Count());
    102102
    103         Assert.IsTrue(g.GetAlternatives('S').Contains("T"));
    104         Assert.IsTrue(g.GetAlternatives('S').Contains("TS"));
    105 
    106         Assert.AreEqual(1, g.MinPhraseLength("S"));
    107         Assert.AreEqual(short.MaxValue, g.MaxPhraseLength("S"));
    108         Assert.AreEqual(1, g.MinPhraseLength("T"));
    109         Assert.AreEqual(1, g.MaxPhraseLength("T"));
     103        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "T"));
     104        Assert.IsTrue(g.GetAlternatives('S').Any(s => s.ToString() == "TS"));
     105
     106        Assert.AreEqual(1, g.MinPhraseLength(new Sequence("S")));
     107        Assert.AreEqual(short.MaxValue, g.MaxPhraseLength(new Sequence("S")));
     108        Assert.AreEqual(1, g.MinPhraseLength(new Sequence("T")));
     109        Assert.AreEqual(1, g.MaxPhraseLength(new Sequence("T")));
    110110      }
    111111
     
    247247
    248248      Assert.AreEqual(0.252718466940018, p.Evaluate("a*b"), 1.0E-7);
     249      Assert.AreEqual(0.290635611728845, p.Evaluate("c*d"), 1.0E-7);
     250      Assert.AreEqual(0.25737325167716, p.Evaluate("e*f"), 1.0E-7);
     251
    249252      Assert.AreEqual(0.00173739472363473, p.Evaluate("b*c"), 1.0E-7);
    250253      Assert.AreEqual(3.15450564064128E-05, p.Evaluate("d*e"), 1.0E-7);
    251254
     255      Assert.AreEqual(0.0943358163760454, p.Evaluate("a*g*i"), 1.0E-7);
     256      Assert.AreEqual(0.116199534934045, p.Evaluate("c*f*j"), 1.0E-7);
     257
     258
    252259      Assert.AreEqual(1.0, p.Evaluate("a*b+c*d+e*f+a*g*i+c*f*j"), 1.0E-7);
    253260    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestSolvers.cs

    r11727 r11730  
    3636        // E -> V | V+E | V-E | V*E | V/E | (E)
    3737        // V -> a .. j
     38        /* grammar has been change ... unit test not yet adapted
    3839        var prob = new SymbolicRegressionPoly10Problem();
    3940        var comb = 10;
    4041        TestDFS(prob, 1, comb);
    4142        TestDFS(prob, 2, comb);
    42 
     43       
    4344        comb = comb + 10 * 4 * comb + comb;
    4445        TestDFS(prob, 3, comb);
    4546        TestDFS(prob, 4, comb);
    46 
     47       
    4748        comb = comb + 10 * 4 * comb + 10; // ((E))
    4849        TestDFS(prob, 5, comb);
    4950        TestDFS(prob, 6, comb);
    50 
    51         comb = comb + 10 * 4 * comb + 10; // (((E)))
     51       
     52        comb = comb + 10 * 4 * comb + 10; // (((E)))  */
    5253        // takes too long
    5354        //TestDFS(prob, 7, comb);
     
    9697        // E -> V | V+E | V-E | V*E | V/E | (E)
    9798        // V -> a .. j
     99        /* grammar has been change ... unit test not yet adapted
    98100        var prob = new SymbolicRegressionPoly10Problem();
    99101        var comb = 10;
     
    109111        TestDFS(prob, 6, comb);
    110112
    111         comb = comb + 10 * 4 * comb + 10; // (((E)))
     113        comb = comb + 10 * 4 * comb + 10; // (((E))) */
    112114        // takes too long
    113115        //TestDFS(prob, 7, comb);
     
    117119
    118120    private void TestDFS(IProblem prob, int len, int numExpectedSols) {
    119       var solver = new ExhaustiveDepthFirstSearch(len);
     121      var solver = new ExhaustiveDepthFirstSearch(prob, len);
    120122      int numSols = 0;
    121123
    122       solver.SolutionEvaluated += (s, d) => { numSols++; };
     124      solver.SolutionEvaluated += (s, d) => { numSols++; Console.WriteLine(s); };
    123125
    124       solver.Run(prob, int.MaxValue);
     126      solver.Run(int.MaxValue);
    125127      Assert.AreEqual(numExpectedSols, numSols);
    126128    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/Grammar.cs

    r11727 r11730  
    1919  public class Grammar : IGrammar {
    2020
    21     private readonly Dictionary<char, List<string>> rules;
     21    private readonly Dictionary<char, List<Sequence>> rules;
    2222    private readonly HashSet<char> terminalSymbols;
    2323    private readonly char sentenceSymbol;
    2424    private readonly HashSet<char> nonTerminalSymbols;
    25     private readonly Dictionary<string, int> maxPhraseLength = new Dictionary<string, int>();
    26     private readonly Dictionary<string, int> minPhraseLength = new Dictionary<string, int>();
     25    private readonly Dictionary<Sequence, int> maxPhraseLength = new Dictionary<Sequence, int>();
     26    private readonly Dictionary<Sequence, int> minPhraseLength = new Dictionary<Sequence, int>();
    2727
    2828    public char SentenceSymbol { get { return sentenceSymbol; } }
     
    3333    // cloning ctor
    3434    public Grammar(Grammar orig) {
    35       this.rules = new Dictionary<char, List<string>>(orig.rules);
     35      this.rules = new Dictionary<char, List<Sequence>>();
     36      foreach (var r in orig.rules)
     37        this.rules.Add(r.Key, new List<Sequence>(r.Value.Select(v => new Sequence(v)))); // clone sequences
    3638      this.terminalSymbols = new HashSet<char>(orig.terminalSymbols);
    3739      this.sentenceSymbol = orig.sentenceSymbol;
    3840      this.nonTerminalSymbols = new HashSet<char>(orig.nonTerminalSymbols);
    39       this.maxPhraseLength = new Dictionary<string, int>(orig.maxPhraseLength);
    40       this.minPhraseLength = new Dictionary<string, int>(orig.minPhraseLength);
     41      this.maxPhraseLength = new Dictionary<Sequence, int>();
     42      foreach (var p in orig.maxPhraseLength) this.maxPhraseLength.Add(new Sequence(p.Key), p.Value);
     43      this.minPhraseLength = new Dictionary<Sequence, int>();
     44      foreach (var p in orig.minPhraseLength) this.minPhraseLength.Add(new Sequence(p.Key), p.Value);
    4145    }
    4246
     
    4751      this.terminalSymbols = new HashSet<char>(terminalSymbols);
    4852      this.nonTerminalSymbols = new HashSet<char>(nonTerminalSymbols);
    49       this.rules = new Dictionary<char, List<string>>();
     53      this.rules = new Dictionary<char, List<Sequence>>();
    5054      foreach (var r in rules) {
    51         if (!this.rules.ContainsKey(r.Item1)) this.rules.Add(r.Item1, new List<string>());
    52         this.rules[r.Item1].Add(r.Item2); // here we store an array of symbols for a phase
     55        if (!this.rules.ContainsKey(r.Item1)) this.rules.Add(r.Item1, new List<Sequence>());
     56        this.rules[r.Item1].Add(new Sequence(r.Item2)); // here we store an array of symbols for a phase
    5357      }
    5458
     
    8690          max = Math.Max(max, maxPhraseLength[alt]);
    8791        }
    88         minPhraseLength[nt.ToString()] = min;
    89         maxPhraseLength[nt.ToString()] = max;
    90       }
    91     }
    92 
    93 
    94     public IEnumerable<string> GetAlternatives(char nt) {
     92        minPhraseLength[new Sequence(nt)] = min;
     93        maxPhraseLength[new Sequence(nt)] = max;
     94      }
     95    }
     96
     97
     98    public IEnumerable<Sequence> GetAlternatives(char nt) {
    9599      return rules[nt];
    96100    }
    97101
    98     public IEnumerable<string> GetTerminalAlternatives(char nt) {
     102    public IEnumerable<Sequence> GetTerminalAlternatives(char nt) {
    99103      return GetAlternatives(nt).Where(alt => alt.All(IsTerminal));
    100104    }
    101105
    102     public IEnumerable<string> GetNonTerminalAlternatives(char nt) {
     106    public IEnumerable<Sequence> GetNonTerminalAlternatives(char nt) {
    103107      return GetAlternatives(nt).Where(alt => alt.Any(IsNonTerminal));
    104108    }
    105109
    106110    // caches for this are build in construction of object
    107     public int MinPhraseLength(string phrase) {
     111    public int MinPhraseLength(Sequence phrase) {
    108112      int l;
    109113      if (minPhraseLength.TryGetValue(phrase, out l)) return l;
     
    125129
    126130    // caches for this are build in construction of object
    127     public int MaxPhraseLength(string phrase) {
     131    public int MaxPhraseLength(Sequence phrase) {
    128132      int l;
    129133      if (maxPhraseLength.TryGetValue(phrase, out l)) return l;
     
    152156    }
    153157
    154     public string CompleteSentenceRandomly(Random random, string phrase, int maxLen) {
     158    public Sequence CompleteSentenceRandomly(Random random, Sequence phrase, int maxLen) {
    155159      if (phrase.Length > maxLen) throw new ArgumentException();
    156160      if (MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
    157       bool done = phrase.All(IsTerminal); // terminal phrase means we are done
     161      bool done = phrase.IsTerminal; // terminal phrase means we are done
    158162      while (!done) {
    159         int ntIdx; char nt;
    160         FindFirstNonTerminal(this, phrase, out nt, out ntIdx);
     163        char nt = phrase.FirstNonTerminal;
    161164
    162165        int maxLenOfReplacement = maxLen - (phrase.Length - 1); // replacing aAb with maxLen 4 means we can only use alternatives with a minPhraseLen <= 2
     
    168171        // replace nt with random alternative
    169172        var selectedAlt = alts.SelectRandom(random);
    170         phrase = phrase.Remove(ntIdx, 1);
    171         phrase = phrase.Insert(ntIdx, selectedAlt);
     173        phrase.ReplaceAt(phrase.FirstNonTerminalIndex, 1, selectedAlt);
    172174
    173175        done = phrase.All(IsTerminal); // terminal phrase means we are done
    174176      }
    175177      return phrase;
    176     }
    177 
    178     public static void FindFirstNonTerminal(IGrammar g, string phrase, out char nt, out int ntIdx) {
    179       ntIdx = 0;
    180       while (ntIdx < phrase.Length && g.IsTerminal(phrase[ntIdx])) ntIdx++;
    181       if (ntIdx >= phrase.Length) {
    182         ntIdx = -1;
    183         nt = '\0';
    184       } else {
    185         nt = phrase[ntIdx];
    186       }
    187178    }
    188179
     
    262253      foreach (var r in rules) {
    263254        foreach (var alt in r.Value) {
    264           var phrase = string.Join(" ", alt);
    265           sb.AppendFormat("  {0} -> {1} (min: {2}, max {3})", r.Key, phrase, MinPhraseLength(phrase), MaxPhraseLength(phrase))
     255          sb.AppendFormat("  {0} -> {1} (min: {2}, max {3})", r.Key, alt, MinPhraseLength(alt), MaxPhraseLength(alt))
    266256            .AppendLine();
    267257        }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.csproj

    r11727 r11730  
    4747    <Compile Include="EvenParityProblem.cs" />
    4848    <Compile Include="SentenceSetStatistics.cs" />
     49    <Compile Include="Sequence.cs" />
    4950    <Compile Include="SymbolicRegressionPoly10Problem.cs" />
    5051    <Compile Include="SantaFeAntProblem.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/IGrammar.cs

    r11659 r11730  
    1414    IEnumerable<char> Symbols { get; }
    1515
    16     IEnumerable<string> GetAlternatives(char nt);
    17     IEnumerable<string> GetTerminalAlternatives(char nt);
    18     IEnumerable<string> GetNonTerminalAlternatives(char nt);
     16    IEnumerable<Sequence> GetAlternatives(char nt);
     17    IEnumerable<Sequence> GetTerminalAlternatives(char nt);
     18    IEnumerable<Sequence> GetNonTerminalAlternatives(char nt);
    1919
    20     int MinPhraseLength(string phrase);
    21     int MaxPhraseLength(string phrase);
    22     string CompleteSentenceRandomly(Random random, string phrase, int maxLen);
     20    int MinPhraseLength(Sequence phrase);
     21    int MaxPhraseLength(Sequence phrase);
     22    Sequence CompleteSentenceRandomly(Random random, Sequence phrase, int maxLen);
    2323
    2424    bool IsTerminal(char symbol);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SantaFeAntProblem.cs

    r11727 r11730  
    112112    private int steps;
    113113    private HeadingEnum heading;
     114
     115
    114116
    115117    public Ant() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SentenceSetStatistics.cs

    r11727 r11730  
    44using System.Text;
    55using System.Threading.Tasks;
     6using HeuristicLab.Common;
    67
    78namespace HeuristicLab.Problems.GrammaticalOptimization {
     
    1213    public string LastSentence { get; private set; }
    1314    public double BestSentenceQuality { get; private set; }
     15    public double BestSentenceIndex { get; private set; }
    1416    public double FirstSentenceQuality { get; private set; }
    1517    public double LastSentenceQuality { get; private set; }
     
    1921
    2022    public void AddSentence(string sentence, double quality) {
    21       if (NumberOfSentences == 0) {
     23      sumQualities += quality;
     24      NumberOfSentences++;
     25
     26      if (NumberOfSentences == 1) {
    2227        FirstSentence = sentence;
    2328        FirstSentenceQuality = quality;
     
    2732        BestSentence = sentence;
    2833        BestSentenceQuality = quality;
     34        BestSentenceIndex = NumberOfSentences;
    2935      }
    30 
    31       sumQualities += quality;
    32       NumberOfSentences++;
    3336
    3437      LastSentence = sentence;
     
    3841    public override string ToString() {
    3942      return
    40         string.Format("Sentences: {0,10} avg.-quality {1,7:F5} best {2,7:F5} {3} first {4,7:F5} {5} last {6,7:F5} {7}",
     43        string.Format("Sentences: {0,10} avg.-quality {1,7:F5} best {2,7:F5} {3,2} {4,10} {5} first {6,7:F5} {7} last {8,7:F5} {9}",
    4144      NumberOfSentences, AverageQuality,
    42       BestSentenceQuality, BestSentence,
     45      BestSentenceQuality, BestSentenceQuality.IsAlmost(1.0)?1.0:0.0,
     46      BestSentenceIndex, BestSentence,
    4347      FirstSentenceQuality, FirstSentence,
    4448      LastSentenceQuality, LastSentence
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SymbolicRegressionPoly10Problem.cs

    r11727 r11730  
    1616    private const string grammarString = @"
    1717    G(E):
    18     E -> a | b | c | d | e | f | g | h | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | j*E
     18    E -> a | b | c | d | e | f | g | h | i | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | i+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | i*E | j*E
    1919    ";
    2020
     
    4747        }
    4848        // poly-10 no noise
     49        /* a*b + c*d + e*f + a*g*i + c*f*j */
    4950        y[i] = x[i][0] * x[i][1] +
    5051               x[i][2] * x[i][3] +
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs

    r11727 r11730  
    33using System.Data;
    44using System.Diagnostics;
     5using System.Globalization;
    56using System.Linq;
    67using System.Text;
    78using System.Threading.Tasks;
    89using HeuristicLab.Algorithms.Bandits;
     10using HeuristicLab.Algorithms.Bandits.Models;
    911using HeuristicLab.Algorithms.GrammaticalOptimization;
    1012using HeuristicLab.Problems.GrammaticalOptimization;
     
    1315  class Program {
    1416    static void Main(string[] args) {
    15       // RunDemo();
    16       RunGridTest();
     17      CultureInfo.DefaultThreadCurrentCulture = CultureInfo.InvariantCulture;
     18
     19      RunDemo();
     20      //RunGridTest();
    1721    }
    1822
    1923    private static void RunGridTest() {
    20       int maxIterations = 150000;
    21       var globalRandom = new Random(31415);
    22       var reps = 10;
    23       Parallel.ForEach(new int[] { 1, 5, 10, 100, 500, 1000 }, (randomTries) => {
    24         Random localRand;
    25         lock (globalRandom) {
    26           localRand = new Random(globalRandom.Next());
    27         }
    28         var policyFactories = new Func<int, IPolicy>[]
     24      int maxIterations = 100000; // for poly-10 with 50000 evaluations no successful try with hl yet
     25      // var globalRandom = new Random(31415);
     26      var localRandSeed = 31415;
     27      var reps = 20;
     28
     29      var policyFactories = new Func<Random, int, IPolicy>[]
    2930        {
    30           (numActions) => new RandomPolicy(localRand, numActions),
    31           (numActions) => new UCB1Policy(numActions),
    32           (numActions) => new UCB1TunedPolicy(numActions),
    33           (numActions) => new UCBNormalPolicy(numActions),
    34           (numActions) => new EpsGreedyPolicy(localRand, numActions, 0.01),
    35           (numActions) => new EpsGreedyPolicy(localRand, numActions, 0.05),
    36           (numActions) => new EpsGreedyPolicy(localRand, numActions, 0.1),
    37           (numActions) => new EpsGreedyPolicy(localRand, numActions, 0.2),
    38           (numActions) => new EpsGreedyPolicy(localRand, numActions, 0.5),
    39           (numActions) => new GaussianThompsonSamplingPolicy(localRand, numActions),
    40           (numActions) => new BernoulliThompsonSamplingPolicy(localRand, numActions)
     31          (rand, numActions) => new GaussianThompsonSamplingPolicy(rand, numActions),
     32          (rand, numActions) => new BernoulliThompsonSamplingPolicy(rand, numActions),
     33          (rand, numActions) => new RandomPolicy(rand, numActions),
     34          (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.01),
     35          (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.05),
     36          (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.1),
     37          (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.2),
     38          (rand, numActions) => new EpsGreedyPolicy(rand, numActions, 0.5),
     39          (rand, numActions) => new UCTPolicy(numActions, 0.1),
     40          (rand, numActions) => new UCTPolicy(numActions, 0.5),
     41          (rand, numActions) => new UCTPolicy(numActions, 1),
     42          (rand, numActions) => new UCTPolicy(numActions, 2),
     43          (rand, numActions) => new UCTPolicy(numActions, 5),
     44          (rand, numActions) => new UCTPolicy(numActions, 10),
     45          (rand, numActions) => new UCB1Policy(numActions),
     46          (rand, numActions) => new UCB1TunedPolicy(numActions),
     47          (rand, numActions) => new UCBNormalPolicy(numActions),
     48          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 0.1),
     49          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 0.5),
     50          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 1),
     51          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 5),
     52          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 10),
     53          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 20),
     54          (rand, numActions) => new BoltzmannExplorationPolicy(rand, numActions, 100),
     55          (rand, numActions) => new ChernoffIntervalEstimationPolicy(numActions, 0.01),
     56          (rand, numActions) => new ChernoffIntervalEstimationPolicy(numActions, 0.05),
     57          (rand, numActions) => new ChernoffIntervalEstimationPolicy(numActions, 0.1),
     58          (rand, numActions) => new ChernoffIntervalEstimationPolicy(numActions, 0.2),
     59          (rand, numActions) => new ThresholdAscentPolicy(numActions, 10, 0.01),
     60          (rand, numActions) => new ThresholdAscentPolicy(numActions, 10, 0.05),
     61          (rand, numActions) => new ThresholdAscentPolicy(numActions, 10, 0.1),
     62          (rand, numActions) => new ThresholdAscentPolicy(numActions, 10, 0.2),
     63          (rand, numActions) => new ThresholdAscentPolicy(numActions, 100, 0.01),
     64          (rand, numActions) => new ThresholdAscentPolicy(numActions, 100, 0.05),
     65          (rand, numActions) => new ThresholdAscentPolicy(numActions, 100, 0.1),
     66          (rand, numActions) => new ThresholdAscentPolicy(numActions, 100, 0.2),
     67          (rand, numActions) => new ThresholdAscentPolicy(numActions, 1000, 0.01),
     68          (rand, numActions) => new ThresholdAscentPolicy(numActions, 1000, 0.05),
     69          (rand, numActions) => new ThresholdAscentPolicy(numActions, 1000, 0.1),
     70          (rand, numActions) => new ThresholdAscentPolicy(numActions, 1000, 0.2),
     71          (rand, numActions) => new ThresholdAscentPolicy(numActions, 5000, 0.01),
     72          (rand, numActions) => new ThresholdAscentPolicy(numActions, 10000, 0.01),
    4173        };
    4274
    43         foreach (var policyFactory in policyFactories)
    44           for (int i = 0; i < reps; i++) {
     75      var tasks = new List<Task>();
     76      foreach (var randomTries in new int[] { 1, 10, /* 5, 100 /*, 500, 1000 */}) {
     77        foreach (var policyFactory in policyFactories) {
     78          var myPolicyFactory = policyFactory;
     79          var myRandomTries = randomTries;
     80          var localRand = new Random(localRandSeed);
     81          var options = new ParallelOptions();
     82          options.MaxDegreeOfParallelism = 1;
     83          Parallel.For(0, reps, options, (i) => {
     84            //var t = Task.Run(() => {
     85            Random myLocalRand;
     86            lock (localRand)
     87              myLocalRand = new Random(localRand.Next());
     88
     89            //for (int i = 0; i < reps; i++) {
     90
    4591            int iterations = 0;
    4692            var sw = new Stopwatch();
    4793            var globalStatistics = new SentenceSetStatistics();
    4894
    49             // var problem = new SymbolicRegressionPoly10Problem();
    50             var problem = new SantaFeAntProblem();
     95            var problem = new SymbolicRegressionPoly10Problem();
     96            //var problem = new SantaFeAntProblem();
    5197            //var problem = new PalindromeProblem();
    5298            //var problem = new HardPalindromeProblem();
    5399            //var problem = new RoyalPairProblem();
    54100            //var problem = new EvenParityProblem();
    55             var alg = new MctsSampler(problem, 17, localRand, randomTries, policyFactory);
     101            var alg = new MctsSampler(problem, 25, myLocalRand, myRandomTries, myPolicyFactory);
    56102            //var alg = new ExhaustiveBreadthFirstSearch(problem, 25);
    57103            //var alg = new AlternativesContextSampler(problem, 25);
     
    61107              globalStatistics.AddSentence(sentence, quality);
    62108              if (iterations % 10000 == 0) {
    63                 Console.WriteLine("{0} {1} {2}", randomTries, policyFactory(1), globalStatistics);
     109                Console.WriteLine("{0,4} {1,7} {2,5} {3,25} {4}", alg.treeDepth, alg.treeSize, myRandomTries, myPolicyFactory(myLocalRand, 1), globalStatistics);
    64110              }
    65111            };
     
    70116
    71117            sw.Stop();
    72           }
    73       });
     118            //Console.WriteLine("{0,5} {1} {2}", randomTries, policyFactory(1), globalStatistics);
     119            //}
     120            //});
     121            //tasks.Add(t);
     122          });
     123        }
     124      }
     125      //Task.WaitAll(tasks.ToArray());
    74126    }
    75127
    76128    private static void RunDemo() {
    77       // TODO: implement threshold ascent
    78       // TODO: implement inspection for MCTS
     129      // TODO: warum funktioniert die alte Implementierung von GaussianThompson besser für SantaFe als alte? Siehe Vergleich: alte vs. neue implementierung GaussianThompsonSampling
     130      // TODO: why does GaussianThompsonSampling work so well with MCTS for the artificial ant problem?
     131      // TODO: wie kann ich sampler noch vergleichen bzw. was kann man messen um die qualität des samplers abzuschätzen (bis auf qualität und iterationen bis zur besten lösung) => ziel schnellere iterationen zu gutem ergebnis
     132      // TODO: likelihood für R=1 bei Gaussian oder GaussianMixture einfach berechenbar?
     133      // TODO: research thompson sampling for max bandit?
     134      // TODO: ausführlicher test von strategien für k-armed max bandit
     135      // TODO: verify TA implementation using example from the original paper
     136      // TODO: reference HL.ProblemInstances and try on tower dataset
     137      // TODO: compare results for different policies also for the symb-reg problem
     138      // TODO: separate policy from MCTS tree data structure to allow sharing of information over disconnected parts of the tree (semantic equivalence)
     139      // TODO: implement thompson sampling for gaussian mixture models
     140      // TODO: implement inspection for MCTS (eventuell interactive command line für statistiken aus dem baum anzeigen)
     141      // TODO: implement ACO-style bandit policy
     142      // TODO: implement sequences that can be manipulated in-place (instead of strings), alternatives are also stored as sequences, for a sequence the index of the first NT-symb can be stored
     143      // TODO: gleichzeitige modellierung von transformierter zielvariable (y, 1/y, log(y), exp(y), sqrt(y), ...)
     144      // TODO: vergleich bei complete-randomly möglichst kurze sätze generieren vs. einfach zufällig alternativen wählen
     145      // TODO: reward discounting (für veränderliche reward distributions über zeit). speziellen unit-test dafür erstellen
     146
    79147
    80148      int maxIterations = 10000000;
     
    84152      string bestSentence = "";
    85153      var globalStatistics = new SentenceSetStatistics();
    86       var random = new Random(31415);
    87 
    88       // var problem = new SymbolicRegressionPoly10Problem();
     154      var random = new Random();
     155
     156      //var problem = new SymbolicRegressionPoly10Problem();
    89157      var problem = new SantaFeAntProblem();
    90158      //var problem = new PalindromeProblem();
     
    92160      //var problem = new RoyalPairProblem();
    93161      //var problem = new EvenParityProblem();
    94       var alg = new MctsSampler(problem, 17, random);
    95       //var alg = new ExhaustiveBreadthFirstSearch(problem, 25);
    96       //var alg = new AlternativesContextSampler(problem, 25);
     162      //var alg = new MctsSampler(problem, 17, random, 1, (rand, numActions) => new GenericThompsonSamplingPolicy(rand, numActions, new GaussianModel(numActions, 0.5, 10)));
     163      //var alg = new ExhaustiveBreadthFirstSearch(problem, 17);
     164      //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions));
     165      //var alg = new ExhaustiveDepthFirstSearch(problem, 17);
     166      // var alg = new AlternativesSampler(problem, 17);
     167      var alg = new RandomSearch(problem, random, 17);
    97168
    98169      alg.FoundNewBestSolution += (sentence, quality) => {
     
    104175        iterations++;
    105176        globalStatistics.AddSentence(sentence, quality);
     177        if (iterations % 1000 == 0) {
     178          //alg.PrintStats();
     179        }
    106180        if (iterations % 10000 == 0) {
    107181          //Console.WriteLine("{0,10} {1,10:F5} {2,10:F5} {3}", iterations, bestQuality, quality, sentence);
    108           Console.WriteLine(globalStatistics.ToString());
     182          //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
     183          Console.WriteLine(globalStatistics);
    109184        }
    110185      };
Note: See TracChangeset for help on using the changeset viewer.