Changeset 11744


Ignore:
Timestamp:
01/09/15 16:54:05 (7 years ago)
Author:
gkronber
Message:

#2283 worked on TD, and models for MCTS

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization
Files:
3 added
7 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ModelPolicyActionInfo.cs

    r11742 r11744  
    2020    public void UpdateReward(double reward) {
    2121      Debug.Assert(!Disabled);
     22      Tries++;
    2223      model.Update(reward);
    2324    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs

    r11742 r11744  
    7777    }
    7878
    79     private double U(double mu, int totalTries, int n, int k) {
     79    private double U(double mu, double totalTries, int n, int k) {
    8080      //var alpha = Math.Log(2.0 * totalTries * k / delta);
    8181      double alpha = Math.Log(2) + Math.Log(totalTries) + Math.Log(k) - Math.Log(delta);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11742 r11744  
    6969    <Compile Include="IBanditPolicy.cs" />
    7070    <Compile Include="IBanditPolicyActionInfo.cs" />
     71    <Compile Include="Models\GaussianMixtureModel.cs" />
     72    <Compile Include="Models\LogitNormalModel.cs" />
    7173    <Compile Include="OnlineMeanAndVarianceEstimator.cs" />
    7274    <Compile Include="Models\BernoulliModel.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs

    r11730 r11744  
    99namespace HeuristicLab.Algorithms.Bandits.Models {
    1010  public class GaussianMixtureModel : IModel {
    11     private readonly int numActions;
    12     private readonly double[][] meanMean; // mean of mean for each arm and component
    13     private readonly double[][] meanVariance; // variance of mean for each arm and component
    14     private readonly double[][] componentProb;
     11    private readonly double[] componentMeans;
     12    private readonly double[] componentVars;
     13    private readonly double[] componentProbs;
    1514
    16     // parameters of beta prior distribution
    1715    private int numComponents;
    18     private double priorMean;
    1916
    20     public GaussianMixtureModel(int numActions, double priorMean = 0.5, int nComponents = 5) {
    21       this.numActions = numActions;
     17    public GaussianMixtureModel(int nComponents = 5) {
    2218      this.numComponents = nComponents;
    23       this.priorMean = priorMean;
    24       this.meanMean = new double[numActions][];
    25       this.meanVariance = new double[numActions][];
    26       this.componentProb = new double[numActions][];
    27       for (int a = 0; a < numActions; a++) {
    28         // TODO: probably need to initizalize this randomly to allow learning
    29         meanMean[a] = Enumerable.Repeat(priorMean, nComponents).ToArray();
    30         meanVariance[a] = Enumerable.Repeat(1.0, nComponents).ToArray(); // prior variance of mean variance = 1
    31         componentProb[a] = Enumerable.Repeat(1.0 / nComponents, nComponents).ToArray(); // uniform prior for component probabilities
    32       }
     19      this.componentProbs = new double[nComponents];
     20      this.componentMeans = new double[nComponents];
     21      this.componentVars = new double[nComponents];
    3322    }
    3423
    3524
    36     public double[] SampleExpectedRewards(Random random) {
    37       // sample mean foreach action and component from the prior
    38       var exp = new double[numActions];
    39       for (int a = 0; a < numActions; a++) {
    40         var sumReward = 0.0;
    41         var numSamples = 10000;
    42         var sampledComponents = Enumerable.Range(0, numComponents).SampleProportional(random, componentProb[a]).Take(numSamples);
    43         foreach (var k in sampledComponents) {
    44           sumReward += Rand.RandNormal(random) * Math.Sqrt(meanVariance[a][k]) + meanMean[a][k];
    45         }
    46         exp[a] = sumReward / (double)numSamples;
    47       }
    48 
    49       return exp;
     25    public double SampleExpectedReward(Random random) {
     26      var k = Enumerable.Range(0, numComponents).SampleProportional(random, componentProbs).First();
     27      return alglib.invnormaldistribution(random.NextDouble()) * Math.Sqrt(componentVars[k]) + componentMeans[k];
    5028    }
    5129
    52     public void Update(int action, double reward) {
     30    public void Update(double reward) {
    5331      // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
    5432      throw new NotImplementedException();
    5533    }
    5634
    57     public void Disable(int action) {
    58       Array.Clear(meanMean[action], 0, meanMean[action].Length);
    59       Array.Clear(meanVariance[action], 0, meanVariance[action].Length);
     35    public void Disable() {
     36      Array.Clear(componentMeans, 0, numComponents);
     37      for (int i = 0; i < numComponents; i++)
     38        componentVars[i] = 0.0;
     39    }
     40
     41    public object Clone() {
     42      return new GaussianMixtureModel(numComponents);
    6043    }
    6144
    6245    public void Reset() {
    63       Array.Clear(meanMean, 0, meanMean.Length);
    64       Array.Clear(meanVariance, 0, meanVariance.Length);
     46      Array.Clear(componentMeans, 0, numComponents);
     47      Array.Clear(componentVars, 0, numComponents);
     48      Array.Clear(componentProbs, 0, numComponents);
    6549    }
    6650
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization.csproj

    r11742 r11744  
    4545    <Compile Include="AlternativesSampler.cs" />
    4646    <Compile Include="AlternativesContextSampler.cs" />
     47    <Compile Include="TemporalDifferenceTreeSearchSampler.cs" />
    4748    <Compile Include="ExhaustiveRandomFirstSearch.cs" />
    4849    <Compile Include="MctsContextualSampler.cs">
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11742 r11744  
    4040    public int treeDepth;
    4141    public int treeSize;
     42    private double bestQuality;
    4243
    4344    // public MctsSampler(IProblem problem, int maxLen, Random random) :
     
    5556
    5657    public void Run(int maxIterations) {
    57       double bestQuality = double.MinValue;
     58      bestQuality = double.MinValue;
    5859      InitPolicies(problem.Grammar);
    5960      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
     
    7778    public void PrintStats() {
    7879      var n = rootNode;
    79       Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}", treeDepth, treeSize, n.actionInfo.Tries);
     80      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality);
    8081      while (n.children != null) {
    8182        Console.WriteLine();
     
    8687        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First();
    8788      }
    88       Console.ReadLine();
    8989    }
    9090
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs

    r11742 r11744  
    175175      //var problem = new RoyalPairProblem();
    176176      //var problem = new EvenParityProblem();
    177       var alg = new MctsSampler(problem, 25, random, 0, new GaussianThompsonSamplingPolicy(true));
     177      var alg = new MctsSampler(problem, 25, random, 0, new GenericThompsonSamplingPolicy(new LogitNormalModel()));
     178      //var alg = new TemporalDifferenceTreeSearchSampler(problem, 23, random, 0, new RandomPolicy());
    178179      //var alg = new ExhaustiveBreadthFirstSearch(problem, 17);
    179180      //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions));
     
    191192        iterations++;
    192193        globalStatistics.AddSentence(sentence, quality);
    193         if (iterations % 1000 == 0) {
     194        if (iterations % 100 == 0) {
     195          Console.Clear();
    194196          alg.PrintStats();
    195197        }
Note: See TracChangeset for help on using the changeset viewer.