
Changeset r11747

Timestamp: 01/12/15 21:23:01
Author:    gkronber
Message:   #2283: implemented test problems for MCTS

Location:  branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits
Files:     1 added, 11 edited

  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BernoulliPolicyActionInfo.cs (r11742 → r11747)

    @@ -9,9 +9,16 @@
     namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
       public class BernoulliPolicyActionInfo : IBanditPolicyActionInfo {
    +    private double knownValue;
         public bool Disabled { get { return NumSuccess == -1; } }
         public int NumSuccess { get; private set; }
         public int NumFailure { get; private set; }
         public int Tries { get { return NumSuccess + NumFailure; } }
    -    public double Value { get { return NumSuccess / (double)(Tries); } }
    +    public double Value {
    +      get {
    +        if (Disabled) return knownValue;
    +        else
    +          return NumSuccess / (double)(Tries);
    +      }
    +    }
         public void UpdateReward(double reward) {
           Debug.Assert(!Disabled);
    @@ -22,11 +29,13 @@
           else NumFailure++;
         }
    -    public void Disable() {
    +    public void Disable(double reward) {
           this.NumSuccess = -1;
           this.NumFailure = -1;
    +      this.knownValue = reward;
         }
         public void Reset() {
           NumSuccess = 0;
           NumFailure = 0;
    +      knownValue = 0.0;
         }
         public void PrintStats() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BoltzmannExplorationPolicy.cs (r11742 → r11747)

    @@ -13,5 +13,5 @@
         private readonly Func<DefaultPolicyActionInfo, double> valueFunction;
     
    -    public BoltzmannExplorationPolicy(double eps) : this(eps, DefaultPolicyActionInfo.AverageReward) { }
    +    public BoltzmannExplorationPolicy(double beta) : this(beta, DefaultPolicyActionInfo.AverageReward) { }
     
         public BoltzmannExplorationPolicy(double beta, Func<DefaultPolicyActionInfo, double> valueFunction) {
    @@ -25,5 +25,14 @@
           // select best
           var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    -      Debug.Assert(myActionInfos.Any(a => !a.Disabled));
    +
    +      // try any of the untries actions randomly
    +      // for RoyalSequence it is much better to select the actions in the order of occurrence (all terminal alternatives first)
    +      //if (myActionInfos.Any(aInfo => !aInfo.Disabled && aInfo.Tries == 0)) {
    +      //  return myActionInfos
    +      //  .Select((aInfo, idx) => new { aInfo, idx })
    +      //  .Where(p => !p.aInfo.Disabled)
    +      //  .Where(p => p.aInfo.Tries == 0)
    +      //  .SelectRandom(random).idx;
    +      //}
     
           var w = from aInfo in myActionInfos
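
    For context: Boltzmann (softmax) exploration chooses each arm with probability proportional to exp(beta * value), so renaming the constructor parameter from eps to beta is the right fix; beta is an inverse temperature, not an epsilon-greedy rate. A minimal self-contained sketch of that selection rule (illustrative only; BoltzmannSketch and SelectBoltzmann are made-up names, not repository code):

        using System;
        using System.Linq;

        static class BoltzmannSketch {
          // beta = 0 gives a uniform random choice; larger beta approaches greedy.
          public static int SelectBoltzmann(Random random, double[] values, double beta) {
            var weights = values.Select(v => Math.Exp(beta * v)).ToArray();
            var u = random.NextDouble() * weights.Sum();
            double acc = 0.0;
            for (int i = 0; i < weights.Length; i++) {
              acc += weights[i];
              if (u <= acc) return i;
            }
            return weights.Length - 1; // guard against floating-point round-off
          }
        }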
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/DefaultPolicyActionInfo.cs (r11742 → r11747)

    @@ -9,9 +9,16 @@
       // stores information that is relevant for most of the policies
       public class DefaultPolicyActionInfo : IBanditPolicyActionInfo {
    +    private double knownValue;
         public bool Disabled { get { return Tries == -1; } }
         public double SumReward { get; private set; }
         public int Tries { get; private set; }
         public double MaxReward { get; private set; }
    -    public double Value { get { return SumReward / Tries; } }
    +    public double Value {
    +      get {
    +        if (Disabled) return knownValue;
    +        else
    +          return Tries > 0 ? SumReward / Tries : 0.0;
    +      }
    +    }
         public DefaultPolicyActionInfo() {
           MaxReward = double.MinValue;
    @@ -25,7 +32,8 @@
           MaxReward = Math.Max(MaxReward, reward);
         }
    -    public void Disable() {
    +    public void Disable(double reward) {
           this.Tries = -1;
           this.SumReward = 0.0;
    +      this.knownValue = reward;
         }
         public void Reset() {
    @@ -33,4 +41,5 @@
           Tries = 0;
           MaxReward = 0.0;
    +      knownValue = 0.0;
         }
         public void PrintStats() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/MeanAndVariancePolicyActionInfo.cs (r11742 → r11747)

    @@ -11,9 +11,16 @@
         public bool Disabled { get { return disabled; } }
         private OnlineMeanAndVarianceEstimator estimator = new OnlineMeanAndVarianceEstimator();
    +    private double knownValue;
         public int Tries { get { return estimator.N; } }
         public double SumReward { get { return estimator.Sum; } }
         public double AvgReward { get { return estimator.Avg; } }
         public double RewardVariance { get { return estimator.Variance; } }
    -    public double Value { get { return AvgReward; } }
    +    public double Value {
    +      get {
    +        if (disabled) return knownValue;
    +        else
    +          return AvgReward;
    +      }
    +    }
     
         public void UpdateReward(double reward) {
    @@ -22,10 +29,12 @@
         }
     
    -    public void Disable() {
    +    public void Disable(double reward) {
           disabled = true;
    +      this.knownValue = reward;
         }
     
         public void Reset() {
           disabled = false;
    +      knownValue = 0.0;
           estimator.Reset();
         }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ModelPolicyActionInfo.cs (r11744 → r11747)

    @@ -10,6 +10,13 @@
       public class ModelPolicyActionInfo : IBanditPolicyActionInfo {
         private readonly IModel model;
    +    private double knownValue;
         public bool Disabled { get { return Tries == -1; } }
    -    public double Value { get { return model.SampleExpectedReward(new Random()); } }
    +    public double Value {
    +      get {
    +        if (Disabled) return knownValue;
    +        else
    +          return model.SampleExpectedReward(new Random());
    +      }
    +    }
     
         public int Tries { get; private set; }
    @@ -28,10 +35,12 @@
         }
     
    -    public void Disable() {
    +    public void Disable(double reward) {
           this.Tries = -1;
    +      this.knownValue = reward;
         }
     
         public void Reset() {
           Tries = 0;
    +      knownValue = 0.0;
           model.Reset();
         }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs (r11744 → r11747)

    @@ -28,5 +28,13 @@
           public int Tries { get; private set; }
           public int thresholdBin = 1;
    -      public double Value { get { return rewardHistogram[thresholdBin] / (double)Tries; } }
    +      private double knownValue;
    +
    +      public double Value {
    +        get {
    +          if (Disabled) return knownValue;
    +          if(Tries == 0.0) return 0.0;
    +          return rewardHistogram[thresholdBin] / (double)Tries;
    +        }
    +      }
     
           public bool Disabled { get { return Tries == -1; } }
    @@ -38,5 +46,6 @@
           }
     
    -      public void Disable() {
    +      public void Disable(double reward) {
    +        this.knownValue = reward;
             Tries = -1;
           }
    @@ -45,4 +54,5 @@
             Tries = 0;
             thresholdBin = 1;
    +        this.knownValue = 0.0;
             Array.Clear(rewardHistogram, 0, rewardHistogram.Length);
           }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs (r11745 → r11747)

    @@ -5,4 +5,5 @@
     using System.Text;
     using System.Threading.Tasks;
    +using HeuristicLab.Common;
     
     namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
    @@ -11,21 +12,29 @@
         public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
           var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    -      int bestAction = -1;
           double bestQ = double.NegativeInfinity;
           int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
     
    +      var bestActions = new List<int>();
           int aIdx = -1;
           foreach (var aInfo in myActionInfos) {
             aIdx++;
             if (aInfo.Disabled) continue;
    -        if (aInfo.Tries == 0) return aIdx;
    -        var q = aInfo.SumReward / aInfo.Tries + Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
    +        double q;
    +        if (aInfo.Tries == 0) {
    +          q = double.PositiveInfinity;
    +        } else {
    +
    +          q = aInfo.SumReward / aInfo.Tries + 0.5 * Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
    +        }
             if (q > bestQ) {
               bestQ = q;
    -          bestAction = aIdx;
    +          bestActions.Clear();
    +          bestActions.Add(aIdx);
    +        } else if (q == bestQ) {
    +          bestActions.Add(aIdx);
             }
           }
    -      Debug.Assert(bestAction > -1);
    -      return bestAction;
    +      Debug.Assert(bestActions.Any());
    +      return bestActions.SelectRandom(random);
         }
     
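
    Two behavioral changes in this file are worth spelling out: untried arms now score q = +infinity instead of being returned immediately in iteration order, and exact ties in q are broken uniformly at random via SelectRandom (hence the new using HeuristicLab.Common). Note also that the committed bonus is 0.5 * sqrt(2 ln n / n_i), half the textbook UCB1 term. For comparison, a sketch of the standard UCB1 score of Auer et al. (hypothetical names, not repository code):

        using System;

        static class Ucb1Sketch {
          // Standard UCB1: mean reward plus sqrt(2 ln n / n_i). Untried arms
          // score +infinity so they are all sampled before any exploitation.
          // The committed code additionally multiplies the bonus by 0.5.
          public static double Score(double sumReward, int tries, int totalTries) {
            if (tries == 0) return double.PositiveInfinity;
            return sumReward / tries + Math.Sqrt(2.0 * Math.Log(totalTries) / tries);
          }
        }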
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCTPolicy.cs (r11742 → r11747)

    @@ -5,4 +5,6 @@
     using System.Text;
     using System.Threading.Tasks;
    +using HeuristicLab.Common;
    +
     namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
       /* Kocsis et al. Bandit based Monte-Carlo Planning */
    @@ -22,16 +24,26 @@
     
           int aIdx = -1;
    +      var bestActions = new List<int>();
           foreach (var aInfo in myActionInfos) {
             aIdx++;
             if (aInfo.Disabled) continue;
    -        if (aInfo.Tries == 0) return aIdx;
    -        var q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
    +        double q;
    +        if (aInfo.Tries == 0) {
    +          q = double.PositiveInfinity;
    +        } else {
    +          q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
    +        }
             if (q > bestQ) {
    +          bestActions.Clear();
               bestQ = q;
    -          bestAction = aIdx;
    +          bestActions.Add(aIdx);
             }
    +        if (q == bestQ) {
    +          bestActions.Add(aIdx);
    +        }
    +
           }
    -      Debug.Assert(bestAction > -1);
    -      return bestAction;
    +      Debug.Assert(bestActions.Any());
    +      return bestActions.SelectRandom(random);
         }
     
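
    The score here is the UCT rule of Kocsis and Szepesvári, mean reward plus 2 * c * sqrt(ln n / n_i), with the same +infinity treatment of untried arms as in UCB1Policy. One subtle difference from UCB1Policy: the tie check is a separate if (q == bestQ) rather than an else-if, so an arm that just became the strict best is added to bestActions twice and carries double weight in the random tie-break. A sketch of the plain scoring rule (hypothetical names, not repository code):

        using System;

        static class UctSketch {
          // UCT (Kocsis & Szepesvari, "Bandit based Monte-Carlo Planning"):
          // exploitation term plus an exploration bonus scaled by the constant c.
          public static double Score(double sumReward, int tries, int totalTries, double c) {
            if (tries == 0) return double.PositiveInfinity;
            return sumReward / tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / tries);
          }
        }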
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj (r11744 → r11747)

    @@ -48,4 +48,5 @@
         <Compile Include="BanditPolicies\BoltzmannExplorationPolicy.cs" />
         <Compile Include="BanditPolicies\ChernoffIntervalEstimationPolicy.cs" />
    +    <Compile Include="BanditPolicies\ActiveLearningPolicy.cs" />
         <Compile Include="BanditPolicies\DefaultPolicyActionInfo.cs" />
         <Compile Include="BanditPolicies\EpsGreedyPolicy.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IBanditPolicyActionInfo.cs (r11742 → r11747)

    @@ -11,5 +11,5 @@
         int Tries { get; }
         void UpdateReward(double reward);
    -    void Disable();
    +    void Disable(double reward);
         // reset causes the state of the action to be reinitialized (as after constructor-call)
         void Reset();
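
    This signature change is the thread running through most of the edits above: when an action is disabled (e.g. because its subtree has been fully explored during MCTS), the caller now passes in the action's final reward, and each action-info's Value returns that pinned value instead of its running statistics once Disabled is true. A hypothetical usage sketch (assumes using System; and a reference to the types in this changeset):

        var info = new DefaultPolicyActionInfo();
        info.UpdateReward(0.8);
        info.UpdateReward(0.4);
        // Value is the running average while the arm is active:
        Console.WriteLine(info.Value);   // 0.6
        info.Disable(0.8);               // arm exhausted; pin its known final value
        Console.WriteLine(info.Value);   // 0.8 (knownValue), not the average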
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs (r11744 → r11747)

    @@ -9,7 +9,8 @@
     namespace HeuristicLab.Algorithms.Bandits.Models {
       public class GaussianMixtureModel : IModel {
    -    private readonly double[] componentMeans;
    -    private readonly double[] componentVars;
    -    private readonly double[] componentProbs;
    +    private double[] componentMeans;
    +    private double[] componentVars;
    +    private double[] componentProbs;
    +    private readonly List<double> allRewards = new List<double>();
     
         private int numComponents;
    @@ -17,7 +18,6 @@
         public GaussianMixtureModel(int nComponents = 5) {
           this.numComponents = nComponents;
    -      this.componentProbs = new double[nComponents];
    -      this.componentMeans = new double[nComponents];
    -      this.componentVars = new double[nComponents];
    +
    +      Reset();
         }
     
    @@ -29,6 +29,51 @@
     
         public void Update(double reward) {
    -      // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
    -      throw new NotImplementedException();
    +      allRewards.Add(reward);
    +      throw new NotSupportedException("this does not yet work");
    +      if (allRewards.Count < 1000 && allRewards.Count % 10 == 0) {
    +        // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
    +        Reset();
    +        for (int i = 0; i < 20; i++) {
    +          var responsibilities = allRewards.Select(r => CalcResponsibility(r)).ToArray();
    +
    +
    +          var sumWeightedRewards = new double[numComponents];
    +          var sumResponsibilities = new double[numComponents];
    +          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
    +            for (int k = 0; k < numComponents; k++) {
    +              sumWeightedRewards[k] += p.Item2[k] * p.Item1;
    +              sumResponsibilities[k] += p.Item2[k];
    +            }
    +          }
    +          for (int k = 0; k < numComponents; k++) {
    +            componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
    +          }
    +
    +          sumWeightedRewards = new double[numComponents];
    +          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
    +            for (int k = 0; k < numComponents; k++) {
    +              sumWeightedRewards[k] += p.Item2[k] * Math.Pow(p.Item1 - componentMeans[k], 2);
    +            }
    +          }
    +          for (int k = 0; k < numComponents; k++) {
    +            componentVars[k] = sumWeightedRewards[k] / sumResponsibilities[k];
    +            componentProbs[k] = sumResponsibilities[k] / sumResponsibilities.Sum();
    +          }
    +        }
    +      }
    +    }
    +
    +    private double[] CalcResponsibility(double r) {
    +      var res = new double[numComponents];
    +      for (int k = 0; k < numComponents; k++) {
    +        componentVars[k] = Math.Max(componentVars[k], 0.001);
    +        res[k] = componentProbs[k] * alglib.normaldistribution((r - componentMeans[k]) / Math.Sqrt(componentVars[k]));
    +        res[k] = Math.Max(res[k], 0.0001);
    +      }
    +      var sum = res.Sum();
    +      for (int k = 0; k < numComponents; k++) {
    +        res[k] /= sum;
    +      }
    +      return res;
         }
     
    @@ -44,7 +89,10 @@
     
         public void Reset() {
    -      Array.Clear(componentMeans, 0, numComponents);
    -      Array.Clear(componentVars, 0, numComponents);
    -      Array.Clear(componentProbs, 0, numComponents);
    +      var rand = new Random();
    +      this.componentProbs = Enumerable.Range(0, numComponents).Select((_) => rand.NextDouble()).ToArray();
    +      var sum = componentProbs.Sum();
    +      for (int i = 0; i < componentProbs.Length; i++) componentProbs[i] /= sum;
    +      this.componentMeans = Enumerable.Range(0, numComponents).Select((_) => Rand.RandNormal(rand)).ToArray();
    +      this.componentVars = Enumerable.Range(0, numComponents).Select((_) => 0.01).ToArray();
         }
     
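
    The new Update body sketches the soft k-means algorithm from MacKay (Algorithm 22.2): an E-step assigns each observed reward a responsibility per mixture component, then the M-step re-estimates means, variances, and mixing weights from responsibility-weighted sums. Note the code still throws NotSupportedException("this does not yet work") before any of it runs, so the estimator is effectively disabled. For reference, a density-based sketch of the E-step (illustrative names; the committed CalcResponsibility instead calls alglib.normaldistribution, which appears to be the normal CDF rather than the density):

        using System;
        using System.Linq;

        static class SoftKMeansSketch {
          // E-step: responsibility of component k for observation x is
          // prob_k * N(x; mean_k, var_k), normalized over all components.
          public static double[] Responsibilities(double x, double[] probs,
                                                  double[] means, double[] vars) {
            int K = probs.Length;
            var r = new double[K];
            for (int k = 0; k < K; k++) {
              double v = Math.Max(vars[k], 1e-3); // variance floor for stability
              double d = x - means[k];
              double pdf = Math.Exp(-d * d / (2.0 * v)) / Math.Sqrt(2.0 * Math.PI * v);
              r[k] = Math.Max(probs[k] * pdf, 1e-4); // avoid all-zero responsibilities
            }
            double sum = r.Sum();
            for (int k = 0; k < K; k++) r[k] /= sum;
            return r;
          }
        }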