Ignore:
Timestamp:
12/29/14 11:02:36 (8 years ago)
Author:
gkronber
Message:

#2283: worked on grammatical optimization problem solvers (simple MCTS done)

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
Files:
3 added
8 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11711 r11727  
    4343    <Compile Include="Bandits\TruncatedNormalBandit.cs" />
    4444    <Compile Include="Policies\BanditPolicy.cs" />
     45    <Compile Include="Policies\BernoulliThompsonSamplingPolicy.cs" />
     46    <Compile Include="Policies\GaussianThompsonSamplingPolicy.cs" />
     47    <Compile Include="Policies\Exp3Policy.cs" />
    4548    <Compile Include="Policies\EpsGreedyPolicy.cs" />
    4649    <Compile Include="Policies\RandomPolicy.cs" />
     
    5154    <Compile Include="Properties\AssemblyInfo.cs" />
    5255  </ItemGroup>
    53   <ItemGroup />
     56  <ItemGroup>
     57    <ProjectReference Include="..\HeuristicLab.Common\HeuristicLab.Common.csproj">
     58      <Project>{3A2FBBCB-F9DF-4970-87F3-F13337D941AD}</Project>
     59      <Name>HeuristicLab.Common</Name>
     60    </ProjectReference>
     61  </ItemGroup>
    5462  <Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />
    5563  <!-- To modify your build process, add your task inside one of the targets below and uncomment it.
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IPolicy.cs

    r11708 r11727  
    66
    77namespace HeuristicLab.Algorithms.Bandits {
     8  // this interface represents a policy for reinforcement learning
    89  public interface IPolicy {
    9     int SelectAction();
    10     void UpdateReward(int action, double reward);
     10    IEnumerable<int> Actions { get; }
     11    int SelectAction(); // action selection ...
     12    void UpdateReward(int action, double reward); // ... and reward update are defined as usual
     13
     14    // policies must also support disabling of potential actions
     15    // for instance, if we know that an action in a state has a deterministic
     16    // reward, it is necessary to sample that action only once;
     17    // further samples would provide no additional information
     18    void DisableAction(int action);
     19
     20    // reset causes the policy to be reinitialized to its initial state (as after constructor-call)
    1121    void Reset();
    1222  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/BanditPolicy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    78namespace HeuristicLab.Algorithms.Bandits {
    89  public abstract class BanditPolicy : IPolicy {
    9     public int NumActions { get; private set; }
    10     public BanditPolicy(int numActions) {
    11       this.NumActions = numActions;
     10    public IEnumerable<int> Actions { get; private set; }
     11    private readonly int numInitialActions;
     12
     13    protected BanditPolicy(int numActions) {
     14      this.numInitialActions = numActions;
     15      Actions = Enumerable.Range(0, numActions).ToArray();
    1216    }
    1317
    1418    public abstract int SelectAction();
    1519    public abstract void UpdateReward(int action, double reward);
    16     public abstract void Reset();
     20
     21    public virtual void DisableAction(int action) {
     22      Debug.Assert(Actions.Contains(action));
     23
     24      Actions = Actions.Where(a => a != action).ToArray();
     25    }
     26
     27    public virtual void Reset() {
     28      Actions = Enumerable.Range(0, numInitialActions).ToArray();
     29    }
    1730  }
    1831}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/EpsGreedyPolicy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    1112    private readonly int[] tries;
    1213    private readonly double[] sumReward;
     14    private readonly RandomPolicy randomPolicy;
     15
    1316    public EpsGreedyPolicy(Random random, int numActions, double eps)
    1417      : base(numActions) {
    1518      this.random = random;
    1619      this.eps = eps;
    17       this.tries = new int[NumActions];
    18       this.sumReward = new double[NumActions];
     20      this.randomPolicy = new RandomPolicy(random, numActions);
     21      this.tries = new int[numActions];
     22      this.sumReward = new double[numActions];
    1923    }
    2024
    2125    public override int SelectAction() {
     26      Debug.Assert(Actions.Any());
    2227      if (random.NextDouble() > eps) {
    2328        // select best
    2429        var maxReward = double.NegativeInfinity;
    2530        int bestAction = -1;
    26         for (int i = 0; i < NumActions; i++) {
    27           if (tries[i] == 0) return i;
    28           var avgReward = sumReward[i] / tries[i];
     31        foreach (var a in Actions) {
     32          if (tries[a] == 0) return a;
     33          var avgReward = sumReward[a] / tries[a];
    2934          if (maxReward < avgReward) {
    3035            maxReward = avgReward;
    31             bestAction = i;
     36            bestAction = a;
    3237          }
    3338        }
     39        Debug.Assert(bestAction >= 0);
    3440        return bestAction;
    3541      } else {
    3642        // select random
    37         return random.Next(NumActions);
     43        return randomPolicy.SelectAction();
    3844      }
    3945    }
    4046    public override void UpdateReward(int action, double reward) {
     47      Debug.Assert(Actions.Contains(action));
     48
     49      randomPolicy.UpdateReward(action, reward); // does nothing
    4150      tries[action]++;
    4251      sumReward[action] += reward;
    4352    }
     53
     54    public override void DisableAction(int action) {
     55      base.DisableAction(action);
     56      randomPolicy.DisableAction(action);
     57      sumReward[action] = 0;
     58      tries[action] = -1;
     59    }
     60
    4461    public override void Reset() {
     62      base.Reset();
     63      randomPolicy.Reset();
    4564      Array.Clear(tries, 0, tries.Length);
    4665      Array.Clear(sumReward, 0, sumReward.Length);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/RandomPolicy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
    56using System.Threading.Tasks;
     7using HeuristicLab.Common;
    68
    79namespace HeuristicLab.Algorithms.Bandits {
    810  public class RandomPolicy : BanditPolicy {
    911    private readonly Random random;
     12
    1013    public RandomPolicy(Random random, int numActions)
    1114      : base(numActions) {
     
    1417
    1518    public override int SelectAction() {
    16       return random.Next(NumActions);
     19      Debug.Assert(Actions.Any());
     20      return Actions.SelectRandom(random);
    1721    }
    1822    public override void UpdateReward(int action, double reward) {
    1923      // do nothing
    2024    }
    21     public override void Reset() {
    22       // do nothing
    23     }
     25
    2426  }
    2527}
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1Policy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    1213    public UCB1Policy(int numActions)
    1314      : base(numActions) {
    14       this.tries = new int[NumActions];
    15       this.sumReward = new double[NumActions];
     15      this.tries = new int[numActions];
     16      this.sumReward = new double[numActions];
    1617    }
    1718
     
    1920      int bestAction = -1;
    2021      double bestQ = double.NegativeInfinity;
    21       for (int i = 0; i < NumActions; i++) {
    22         if (tries[i] == 0) return i;
    23         var q = sumReward[i] / tries[i] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[i]);
     22      foreach (var a in Actions) {
     23        if (tries[a] == 0) return a;
     24        var q = sumReward[a] / tries[a] + Math.Sqrt((2 * Math.Log(totalTries)) / tries[a]);
    2425        if (q > bestQ) {
    2526          bestQ = q;
    26           bestAction = i;
     27          bestAction = a;
    2728        }
    2829      }
     
    3031    }
    3132    public override void UpdateReward(int action, double reward) {
     33      Debug.Assert(Actions.Contains(action));
    3234      totalTries++;
    3335      tries[action]++;
    3436      sumReward[action] += reward;
    3537    }
     38
     39    public override void DisableAction(int action) {
     40      base.DisableAction(action);
     41      totalTries -= tries[action];
     42      tries[action] = -1;
     43      sumReward[action] = 0;
     44    }
     45
    3646    public override void Reset() {
     47      base.Reset();
    3748      totalTries = 0;
    3849      Array.Clear(tries, 0, tries.Length);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCB1TunedPolicy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    1314    public UCB1TunedPolicy(int numActions)
    1415      : base(numActions) {
    15       this.tries = new int[NumActions];
    16       this.sumReward = new double[NumActions];
    17       this.sumSqrReward = new double[NumActions];
     16      this.tries = new int[numActions];
     17      this.sumReward = new double[numActions];
     18      this.sumSqrReward = new double[numActions];
    1819    }
    1920
     
    2526
    2627    public override int SelectAction() {
     28      Debug.Assert(Actions.Any());
    2729      int bestAction = -1;
    2830      double bestQ = double.NegativeInfinity;
    29       for (int i = 0; i < NumActions; i++) {
    30         if (tries[i] == 0) return i;
    31         var q = sumReward[i] / tries[i] + Math.Sqrt((Math.Log(totalTries) / tries[i]) * Math.Min(1.0 / 4, V(i))); // 1/4 is upper bound of bernoulli distributed variable
     31      foreach (var a in Actions) {
     32        if (tries[a] == 0) return a;
     33        var q = sumReward[a] / tries[a] + Math.Sqrt((Math.Log(totalTries) / tries[a]) * Math.Min(1.0 / 4, V(a))); // 1/4 is upper bound of bernoulli distributed variable
    3234        if (q > bestQ) {
    3335          bestQ = q;
    34           bestAction = i;
     36          bestAction = a;
    3537        }
    3638      }
     
    3840    }
    3941    public override void UpdateReward(int action, double reward) {
     42      Debug.Assert(Actions.Contains(action));
    4043      totalTries++;
    4144      tries[action]++;
     
    4346      sumSqrReward[action] += reward * reward;
    4447    }
     48
     49    public override void DisableAction(int action) {
     50      base.DisableAction(action);
     51      totalTries -= tries[action];
     52      tries[action] = -1;
     53      sumReward[action] = 0;
     54      sumSqrReward[action] = 0;
     55    }
     56
    4557    public override void Reset() {
     58      base.Reset();
    4659      totalTries = 0;
    4760      Array.Clear(tries, 0, tries.Length);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Policies/UCBNormalPolicy.cs

    r11711 r11727  
    11using System;
    22using System.Collections.Generic;
     3using System.Diagnostics;
    34using System.Linq;
    45using System.Text;
     
    1314    public UCBNormalPolicy(int numActions)
    1415      : base(numActions) {
    15       this.tries = new int[NumActions];
    16       this.sumReward = new double[NumActions];
    17       this.sumSqrReward = new double[NumActions];
     16      this.tries = new int[numActions];
     17      this.sumReward = new double[numActions];
     18      this.sumSqrReward = new double[numActions];
    1819    }
    1920
    20     private double V(int arm) {
    21       var s = tries[arm];
    22       return sumSqrReward[arm] / s - Math.Pow(sumReward[arm] / s, 2) + Math.Sqrt(2 * Math.Log(totalTries) / s);
    23     }
    24 
    25 
    2621    public override int SelectAction() {
     22      Debug.Assert(Actions.Any());
    2723      int bestAction = -1;
    2824      double bestQ = double.NegativeInfinity;
    29       for (int i = 0; i < NumActions; i++) {
    30         if (totalTries == 0 || tries[i] == 0 || tries[i] < Math.Ceiling(8 * Math.Log(totalTries))) return i;
    31         var avgReward = sumReward[i] / tries[i];
     25      foreach (var a in Actions) {
     26        if (totalTries == 0 || tries[a] == 0 || tries[a] < Math.Ceiling(8 * Math.Log(totalTries))) return a;
     27        var avgReward = sumReward[a] / tries[a];
    3228        var q = avgReward
    33           + Math.Sqrt(16 * ((sumSqrReward[i] - tries[i] * Math.Pow(avgReward, 2)) / (tries[i] - 1)) * (Math.Log(totalTries - 1) / tries[i]));
     29          + Math.Sqrt(16 * ((sumSqrReward[a] - tries[a] * Math.Pow(avgReward, 2)) / (tries[a] - 1)) * (Math.Log(totalTries - 1) / tries[a]));
    3430        if (q > bestQ) {
    3531          bestQ = q;
    36           bestAction = i;
     32          bestAction = a;
    3733        }
    3834      }
     
    4036    }
    4137    public override void UpdateReward(int action, double reward) {
     38      Debug.Assert(Actions.Contains(action));
    4239      totalTries++;
    4340      tries[action]++;
     
    4542      sumSqrReward[action] += reward * reward;
    4643    }
     44
     45    public override void DisableAction(int action) {
     46      base.DisableAction(action);
     47      totalTries -= tries[action];
     48      tries[action] = -1;
     49      sumReward[action] = 0;
     50      sumSqrReward[action] = 0;
     51    }
     52
    4753    public override void Reset() {
     54      base.Reset();
    4855      totalTries = 0;
    4956      Array.Clear(tries, 0, tries.Length);
Note: See TracChangeset for help on using the changeset viewer.