Ignore:
Timestamp:
01/02/15 16:08:21 (8 years ago)
Author:
gkronber
Message:

#2283: several major extensions for grammatical optimization

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
Files:
12 added
13 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Bandits/BernoulliBandit.cs

    r11711 r11730  
    66
    77namespace HeuristicLab.Algorithms.Bandits {
    8   public class BernoulliBandit {
     8  public class BernoulliBandit : IBandit {
    99    public int NumArms { get; private set; }
    1010    public double OptimalExpectedReward { get; private set; } // reward of the best arm, for calculating regret
     11    public int OptimalExpectedRewardArm { get; private set; }
     12    // the arm with highest expected reward also has the highest probability of return a reward of 1.0
     13    public int OptimalMaximalRewardArm { get { return OptimalExpectedRewardArm; } }
     14
    1115    private readonly Random random;
    1216    private readonly double[] expReward;
     
    1923      for (int i = 0; i < nArms; i++) {
    2024        expReward[i] = random.NextDouble();
    21         if (expReward[i] > OptimalExpectedReward) OptimalExpectedReward = expReward[i];
     25        if (expReward[i] > OptimalExpectedReward) {
     26          OptimalExpectedReward = expReward[i];
     27          OptimalExpectedRewardArm = i;
     28        }
    2229      }
    2330    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Bandits/TruncatedNormalBandit.cs

    r11711 r11730  
    44using System.Text;
    55using System.Threading.Tasks;
     6using HeuristicLab.Common;
    67
    78namespace HeuristicLab.Algorithms.Bandits {
    8   public class TruncatedNormalBandit {
     9  public class TruncatedNormalBandit : IBandit {
    910    public int NumArms { get; private set; }
    1011    public double OptimalExpectedReward { get; private set; } // reward of the best arm, for calculating regret
     12    public int OptimalExpectedRewardArm { get; private set; }
     13    // the arm with highest expected reward also has the highest probability of return a reward of 1.0
     14    public int OptimalMaximalRewardArm { get { return OptimalExpectedRewardArm; } }
     15
    1116    private readonly Random random;
    1217    private readonly double[] expReward;
     
    1823      OptimalExpectedReward = double.NegativeInfinity;
    1924      for (int i = 0; i < nArms; i++) {
    20         expReward[i] = random.NextDouble();
    21         if (expReward[i] > OptimalExpectedReward) OptimalExpectedReward = expReward[i];
     25        expReward[i] = random.NextDouble() * 0.7;
     26        if (expReward[i] > OptimalExpectedReward) {
     27          OptimalExpectedReward = expReward[i];
     28          OptimalExpectedRewardArm = i;
     29        }
    2230      }
    2331    }
     
    2836      double x = 0;
    2937      do {
    30         var z = Transform(random.NextDouble(), random.NextDouble());
     38        var z = Rand.RandNormal(random);
    3139        x = z * 0.1 + expReward[arm];
    3240      }
     
    3442      return x;
    3543    }
    36 
    37     // box muller transform
    38     private double Transform(double u1, double u2) {
    39       return Math.Sqrt(-2 * Math.Log(u1)) * Math.Cos(2 * Math.PI * u2);
    40     }
    4144  }
    4245}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11727 r11730  
    4040  </ItemGroup>
    4141  <ItemGroup>
     42    <Compile Include="BanditHelper.cs" />
    4243    <Compile Include="Bandits\BernoulliBandit.cs" />
     44    <Compile Include="Bandits\GaussianMixtureBandit.cs" />
     45    <Compile Include="Bandits\IBandit.cs" />
    4346    <Compile Include="Bandits\TruncatedNormalBandit.cs" />
     47    <Compile Include="Models\BernoulliModel.cs" />
     48    <Compile Include="Models\GaussianModel.cs" />
     49    <Compile Include="Models\GaussianMixtureModel.cs" />
     50    <Compile Include="Models\IModel.cs" />
    4451    <Compile Include="Policies\BanditPolicy.cs" />
    4552    <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs" />
     53    <Compile Include="Policies\BoltzmannExplorationPolicy.cs" />
     54    <Compile Include="Policies\ChernoffIntervalEstimationPolicy.cs" />
     55    <Compile Include="Policies\GenericThompsonSamplingPolicy.cs" />
     56    <Compile Include="Policies\ThresholdAscentPolicy.cs" />
     57    <Compile Include="Policies\UCTPolicy.cs" />
    4658    <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs" />
    4759    <Compile Include="Policies\Exp3Policy.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IPolicy.cs

    r11727 r11730  
    2020    // reset causes the policy to be reinitialized to it's initial state (as after constructor-call)
    2121    void Reset();
     22
     23    void PrintStats();
    2224  }
    2325}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs

    r11727 r11730  
    2828      Actions = Enumerable.Range(0, numInitialActions).ToArray();
    2929    }
     30
     31    public abstract void PrintStats();
    3032  }
    3133}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BernoulliThompsonSamplingPolicy.cs

    r11727 r11730  
    5555      Array.Clear(failure, 0, failure.Length);
    5656    }
     57
     58    public override void PrintStats() {
     59      for (int i = 0; i < success.Length; i++) {
     60        if (success[i] >= 0) {
     61          Console.Write("{0,5:F2}", success[i] / failure[i]);
     62        } else {
     63          Console.Write("{0,5}", "");
     64        }
     65      }
     66      Console.WriteLine();
     67    }
     68
     69    public override string ToString() {
     70      return "BernoulliThompsonSamplingPolicy";
     71    }
    5772  }
    5873}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs

    r11727 r11730  
    2727      if (random.NextDouble() > eps) {
    2828        // select best
    29         var maxReward = double.NegativeInfinity;
     29        var bestQ = double.NegativeInfinity;
    3030        int bestAction = -1;
    3131        foreach (var a in Actions) {
    3232          if (tries[a] == 0) return a;
    33           var avgReward = sumReward[a] / tries[a];
    34           if (maxReward < avgReward) {
    35             maxReward = avgReward;
     33          var q = sumReward[a] / tries[a];
     34          if (bestQ < q) {
     35            bestQ = q;
    3636            bestAction = a;
    3737          }
     
    6565      Array.Clear(sumReward, 0, sumReward.Length);
    6666    }
     67    public override void PrintStats() {
     68      for (int i = 0; i < sumReward.Length; i++) {
     69        if (tries[i] >= 0) {
     70          Console.Write(" {0,5:F2} {1}", sumReward[i] / tries[i], tries[i]);
     71        } else {
     72          Console.Write("-", "");
     73        }
     74      }
     75      Console.WriteLine();
     76    }
     77    public override string ToString() {
     78      return string.Format("EpsGreedyPolicy({0:F2})", eps);
     79    }
    6780  }
    6881}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/Exp3Policy.cs

    r11727 r11730  
    5252      foreach (var a in Actions) w[a] = 1.0;
    5353    }
     54    public override void PrintStats() {
     55      for (int i = 0; i < w.Length; i++) {
     56        if (w[i] > 0) {
     57          Console.Write("{0,5:F2}", w[i]);
     58        } else {
     59          Console.Write("{0,5}", "");
     60        }
     61      }
     62      Console.WriteLine();
     63    }
     64    public override string ToString() {
     65      return "Exp3Policy";
     66    }
    5467  }
    5568}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/GaussianThompsonSamplingPolicy.cs

    r11727 r11730  
    55
    66namespace HeuristicLab.Algorithms.Bandits {
     7 
    78  public class GaussianThompsonSamplingPolicy : BanditPolicy {
    89    private readonly Random random;
    9     private readonly double[] sumRewards;
    10     private readonly double[] sumSqrRewards;
     10    private readonly double[] sampleMean;
     11    private readonly double[] sampleM2;
    1112    private readonly int[] tries;
    12     public GaussianThompsonSamplingPolicy(Random random, int numActions)
     13    private bool compatibility;
     14
     15    // assumes a Gaussian reward distribution with different means but the same variances for each action
     16    // the prior for the mean is also Gaussian with the following parameters
     17    private readonly double rewardVariance = 0.1; // we assume a known variance
     18
     19    private readonly double priorMean = 0.5;
     20    private readonly double priorVariance = 1;
     21
     22
     23    public GaussianThompsonSamplingPolicy(Random random, int numActions, bool compatibility = false)
    1324      : base(numActions) {
    1425      this.random = random;
    15       this.sumRewards = new double[numActions];
    16       this.sumSqrRewards = new double[numActions];
     26      this.sampleMean = new double[numActions];
     27      this.sampleM2 = new double[numActions];
    1728      this.tries = new int[numActions];
     29      this.compatibility = compatibility;
    1830    }
    1931
     
    2436      int bestAction = -1;
    2537      foreach (var a in Actions) {
    26         if (tries[a] == 0) return a;
    27         var mu = sumRewards[a] / tries[a];
    28         var stdDev = Math.Sqrt(sumSqrRewards[a] / tries[a] - Math.Pow(mu, 2));
    29         var theta = Rand.RandNormal(random) * stdDev + mu;
     38        if(tries[a] == -1) continue; // skip disabled actions
     39        double theta;
     40        if (compatibility) {
     41          if (tries[a] < 2) return a;
     42          var mu = sampleMean[a];
     43          var variance = sampleM2[a] / tries[a];
     44          var stdDev = Math.Sqrt(variance);
     45          theta = Rand.RandNormal(random) * stdDev + mu;
     46        } else {
     47          // calculate posterior mean and variance (for mean reward)
     48
     49          // see Murphy 2007: Conjugate Bayesian analysis of the Gaussian distribution (http://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf)
     50          var posteriorVariance = 1.0 / (tries[a] / rewardVariance + 1.0 / priorVariance);
     51          var posteriorMean = posteriorVariance * (priorMean / priorVariance + tries[a] * sampleMean[a] / rewardVariance);
     52
     53          // sample a mean from the posterior
     54          theta = Rand.RandNormal(random) * Math.Sqrt(posteriorVariance) + posteriorMean;
     55
     56          // theta already represents the expected reward value => nothing else to do
     57        }
    3058        if (theta > maxTheta) {
    3159          maxTheta = theta;
     
    3361        }
    3462      }
     63      Debug.Assert(Actions.Contains(bestAction));
    3564      return bestAction;
    3665    }
     
    3867    public override void UpdateReward(int action, double reward) {
    3968      Debug.Assert(Actions.Contains(action));
    40 
    41       sumRewards[action] += reward;
    42       sumSqrRewards[action] += reward * reward;
    4369      tries[action]++;
     70      var delta = reward - sampleMean[action];
     71      sampleMean[action] += delta / tries[action];
     72      sampleM2[action] += sampleM2[action] + delta * (reward - sampleMean[action]);
    4473    }
    4574
    4675    public override void DisableAction(int action) {
    4776      base.DisableAction(action);
    48       sumRewards[action] = 0;
    49       sumSqrRewards[action] = 0;
     77      sampleMean[action] = 0;
     78      sampleM2[action] = 0;
    5079      tries[action] = -1;
    5180    }
     
    5382    public override void Reset() {
    5483      base.Reset();
    55       Array.Clear(sumRewards, 0, sumRewards.Length);
    56       Array.Clear(sumSqrRewards, 0, sumSqrRewards.Length);
     84      Array.Clear(sampleMean, 0, sampleMean.Length);
     85      Array.Clear(sampleM2, 0, sampleM2.Length);
    5786      Array.Clear(tries, 0, tries.Length);
     87    }
     88
     89    public override void PrintStats() {
     90      for (int i = 0; i < sampleMean.Length; i++) {
     91        if (tries[i] >= 0) {
     92          Console.Write(" {0,5:F2} {1}", sampleMean[i] / tries[i], tries[i]);
     93        } else {
     94          Console.Write("{0,5}", "");
     95        }
     96      }
     97      Console.WriteLine();
     98    }
     99    public override string ToString() {
     100      return "GaussianThompsonSamplingPolicy";
    58101    }
    59102  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs

    r11727 r11730  
    2323      // do nothing
    2424    }
    25 
     25    public override void PrintStats() {
     26      Console.WriteLine("Random");
     27    }
     28    public override string ToString() {
     29      return "RandomPolicy";
     30    }
    2631  }
    2732}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs

    r11727 r11730  
    5050      Array.Clear(sumReward, 0, sumReward.Length);
    5151    }
     52    public override void PrintStats() {
     53      for (int i = 0; i < sumReward.Length; i++) {
     54        if (tries[i] >= 0) {
     55          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     56        } else {
     57          Console.Write("{0,5}", "");
     58        }
     59      }
     60      Console.WriteLine();
     61    }
     62    public override string ToString() {
     63      return "UCB1Policy";
     64    }
    5265  }
    5366}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs

    r11727 r11730  
    6262      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    6363    }
     64    public override void PrintStats() {
     65      for (int i = 0; i < sumReward.Length; i++) {
     66        if (tries[i] >= 0) {
     67          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     68        } else {
     69          Console.Write("{0,5}", "");
     70        }
     71      }
     72      Console.WriteLine();
     73    }
     74    public override string ToString() {
     75      return "UCB1TunedPolicy";
     76    }
    6477  }
    6578}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs

    r11727 r11730  
    2424      double bestQ = double.NegativeInfinity;
    2525      foreach (var a in Actions) {
    26         if (totalTries == 0 || tries[a] == 0 || tries[a] < Math.Ceiling(8 * Math.Log(totalTries))) return a;
     26        if (totalTries <= 1 || tries[a] <= 1 || tries[a] <= Math.Ceiling(8 * Math.Log(totalTries))) return a;
    2727        var avgReward = sumReward[a] / tries[a];
     28        var estVariance = 16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]);
     29        if (estVariance < 0) estVariance = 0; // numerical problems
    2830        var q = avgReward
    29           + Math.Sqrt(16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]));
     31          + Math.Sqrt(estVariance);
    3032        if (q > bestQ) {
    3133          bestQ = q;
     
    3335        }
    3436      }
     37      Debug.Assert(Actions.Contains(bestAction));
    3538      return bestAction;
    3639    }
     
    5861      Array.Clear(sumSqrReward, 0, sumSqrReward.Length);
    5962    }
     63    public override void PrintStats() {
     64      for (int i = 0; i < sumReward.Length; i++) {
     65        if (tries[i] >= 0) {
     66          Console.Write("{0,5:F2}", sumReward[i] / tries[i]);
     67        } else {
     68          Console.Write("{0,5}", "");
     69        }
     70      }
     71      Console.WriteLine();
     72    }
     73    public override string ToString() {
     74      return "UCBNormalPolicy";
     75    }
    6076  }
    6177}
Note: See TracChangeset for help on using the changeset viewer.