Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
01/12/15 21:23:01 (10 years ago)
Author:
gkronber
Message:

#2283: implemented test problems for MCTS

Location:
branches/HeuristicLab.Problems.GrammaticalOptimization
Files:
3 added
21 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BernoulliPolicyActionInfo.cs

    r11742 r11747  
    99namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
    1010  public class BernoulliPolicyActionInfo : IBanditPolicyActionInfo {
     11    private double knownValue;
    1112    public bool Disabled { get { return NumSuccess == -1; } }
    1213    public int NumSuccess { get; private set; }
    1314    public int NumFailure { get; private set; }
    1415    public int Tries { get { return NumSuccess + NumFailure; } }
    15     public double Value { get { return NumSuccess / (double)(Tries); } }
     16    public double Value {
     17      get {
     18        if (Disabled) return knownValue;
     19        else
     20          return NumSuccess / (double)(Tries);
     21      }
     22    }
    1623    public void UpdateReward(double reward) {
    1724      Debug.Assert(!Disabled);
     
    2229      else NumFailure++;
    2330    }
    24     public void Disable() {
     31    public void Disable(double reward) {
    2532      this.NumSuccess = -1;
    2633      this.NumFailure = -1;
     34      this.knownValue = reward;
    2735    }
    2836    public void Reset() {
    2937      NumSuccess = 0;
    3038      NumFailure = 0;
     39      knownValue = 0.0;
    3140    }
    3241    public void PrintStats() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/BoltzmannExplorationPolicy.cs

    r11742 r11747  
    1313    private readonly Func<DefaultPolicyActionInfo, double> valueFunction;
    1414
    15     public BoltzmannExplorationPolicy(double eps) : this(eps, DefaultPolicyActionInfo.AverageReward) { }
     15    public BoltzmannExplorationPolicy(double beta) : this(beta, DefaultPolicyActionInfo.AverageReward) { }
    1616
    1717    public BoltzmannExplorationPolicy(double beta, Func<DefaultPolicyActionInfo, double> valueFunction) {
     
    2525      // select best
    2626      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    27       Debug.Assert(myActionInfos.Any(a => !a.Disabled));
     27
     28      // try any of the untries actions randomly
     29      // for RoyalSequence it is much better to select the actions in the order of occurrence (all terminal alternatives first)
     30      //if (myActionInfos.Any(aInfo => !aInfo.Disabled && aInfo.Tries == 0)) {
     31      //  return myActionInfos
     32      //  .Select((aInfo, idx) => new { aInfo, idx })
     33      //  .Where(p => !p.aInfo.Disabled)
     34      //  .Where(p => p.aInfo.Tries == 0)
     35      //  .SelectRandom(random).idx;
     36      //}
    2837
    2938      var w = from aInfo in myActionInfos
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/DefaultPolicyActionInfo.cs

    r11742 r11747  
    99  // stores information that is relevant for most of the policies
    1010  public class DefaultPolicyActionInfo : IBanditPolicyActionInfo {
     11    private double knownValue;
    1112    public bool Disabled { get { return Tries == -1; } }
    1213    public double SumReward { get; private set; }
    1314    public int Tries { get; private set; }
    1415    public double MaxReward { get; private set; }
    15     public double Value { get { return SumReward / Tries; } }
     16    public double Value {
     17      get {
     18        if (Disabled) return knownValue;
     19        else
     20          return Tries > 0 ? SumReward / Tries : 0.0;
     21      }
     22    }
    1623    public DefaultPolicyActionInfo() {
    1724      MaxReward = double.MinValue;
     
    2532      MaxReward = Math.Max(MaxReward, reward);
    2633    }
    27     public void Disable() {
     34    public void Disable(double reward) {
    2835      this.Tries = -1;
    2936      this.SumReward = 0.0;
     37      this.knownValue = reward;
    3038    }
    3139    public void Reset() {
     
    3341      Tries = 0;
    3442      MaxReward = 0.0;
     43      knownValue = 0.0;
    3544    }
    3645    public void PrintStats() {
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/MeanAndVariancePolicyActionInfo.cs

    r11742 r11747  
    1111    public bool Disabled { get { return disabled; } }
    1212    private OnlineMeanAndVarianceEstimator estimator = new OnlineMeanAndVarianceEstimator();
     13    private double knownValue;
    1314    public int Tries { get { return estimator.N; } }
    1415    public double SumReward { get { return estimator.Sum; } }
    1516    public double AvgReward { get { return estimator.Avg; } }
    1617    public double RewardVariance { get { return estimator.Variance; } }
    17     public double Value { get { return AvgReward; } }
     18    public double Value {
     19      get {
     20        if (disabled) return knownValue;
     21        else
     22          return AvgReward;
     23      }
     24    }
    1825
    1926    public void UpdateReward(double reward) {
     
    2229    }
    2330
    24     public void Disable() {
     31    public void Disable(double reward) {
    2532      disabled = true;
     33      this.knownValue = reward;
    2634    }
    2735
    2836    public void Reset() {
    2937      disabled = false;
     38      knownValue = 0.0;
    3039      estimator.Reset();
    3140    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ModelPolicyActionInfo.cs

    r11744 r11747  
    1010  public class ModelPolicyActionInfo : IBanditPolicyActionInfo {
    1111    private readonly IModel model;
     12    private double knownValue;
    1213    public bool Disabled { get { return Tries == -1; } }
    13     public double Value { get { return model.SampleExpectedReward(new Random()); } }
     14    public double Value {
     15      get {
     16        if (Disabled) return knownValue;
     17        else
     18          return model.SampleExpectedReward(new Random());
     19      }
     20    }
    1421
    1522    public int Tries { get; private set; }
     
    2835    }
    2936
    30     public void Disable() {
     37    public void Disable(double reward) {
    3138      this.Tries = -1;
     39      this.knownValue = reward;
    3240    }
    3341
    3442    public void Reset() {
    3543      Tries = 0;
     44      knownValue = 0.0;
    3645      model.Reset();
    3746    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/ThresholdAscentPolicy.cs

    r11744 r11747  
    2828      public int Tries { get; private set; }
    2929      public int thresholdBin = 1;
    30       public double Value { get { return rewardHistogram[thresholdBin] / (double)Tries; } }
     30      private double knownValue;
     31
     32      public double Value {
     33        get {
     34          if (Disabled) return knownValue;
     35          if(Tries == 0.0) return 0.0;
     36          return rewardHistogram[thresholdBin] / (double)Tries;
     37        }
     38      }
    3139
    3240      public bool Disabled { get { return Tries == -1; } }
     
    3846      }
    3947
    40       public void Disable() {
     48      public void Disable(double reward) {
     49        this.knownValue = reward;
    4150        Tries = -1;
    4251      }
     
    4554        Tries = 0;
    4655        thresholdBin = 1;
     56        this.knownValue = 0.0;
    4757        Array.Clear(rewardHistogram, 0, rewardHistogram.Length);
    4858      }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCB1Policy.cs

    r11745 r11747  
    55using System.Text;
    66using System.Threading.Tasks;
     7using HeuristicLab.Common;
    78
    89namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
     
    1112    public int SelectAction(Random random, IEnumerable<IBanditPolicyActionInfo> actionInfos) {
    1213      var myActionInfos = actionInfos.OfType<DefaultPolicyActionInfo>();
    13       int bestAction = -1;
    1414      double bestQ = double.NegativeInfinity;
    1515      int totalTries = myActionInfos.Where(a => !a.Disabled).Sum(a => a.Tries);
    1616
     17      var bestActions = new List<int>();
    1718      int aIdx = -1;
    1819      foreach (var aInfo in myActionInfos) {
    1920        aIdx++;
    2021        if (aInfo.Disabled) continue;
    21         if (aInfo.Tries == 0) return aIdx;
    22         var q = aInfo.SumReward / aInfo.Tries + Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
     22        double q;
     23        if (aInfo.Tries == 0) {
     24          q = double.PositiveInfinity;
     25        } else {
     26
     27          q = aInfo.SumReward / aInfo.Tries + 0.5 * Math.Sqrt((2 * Math.Log(totalTries)) / aInfo.Tries);
     28        }
    2329        if (q > bestQ) {
    2430          bestQ = q;
    25           bestAction = aIdx;
     31          bestActions.Clear();
     32          bestActions.Add(aIdx);
     33        } else if (q == bestQ) {
     34          bestActions.Add(aIdx);
    2635        }
    2736      }
    28       Debug.Assert(bestAction > -1);
    29       return bestAction;
     37      Debug.Assert(bestActions.Any());
     38      return bestActions.SelectRandom(random);
    3039    }
    3140
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/BanditPolicies/UCTPolicy.cs

    r11742 r11747  
    55using System.Text;
    66using System.Threading.Tasks;
     7using HeuristicLab.Common;
     8
    79namespace HeuristicLab.Algorithms.Bandits.BanditPolicies {
    810  /* Kocsis et al. Bandit based Monte-Carlo Planning */
     
    2224
    2325      int aIdx = -1;
     26      var bestActions = new List<int>();
    2427      foreach (var aInfo in myActionInfos) {
    2528        aIdx++;
    2629        if (aInfo.Disabled) continue;
    27         if (aInfo.Tries == 0) return aIdx;
    28         var q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
     30        double q;
     31        if (aInfo.Tries == 0) {
     32          q = double.PositiveInfinity;
     33        } else {
     34          q = aInfo.SumReward / aInfo.Tries + 2.0 * c * Math.Sqrt(Math.Log(totalTries) / aInfo.Tries);
     35        }
    2936        if (q > bestQ) {
     37          bestActions.Clear();
    3038          bestQ = q;
    31           bestAction = aIdx;
     39          bestActions.Add(aIdx);
    3240        }
     41        if (q == bestQ) {
     42          bestActions.Add(aIdx);
     43        }
     44
    3345      }
    34       Debug.Assert(bestAction > -1);
    35       return bestAction;
     46      Debug.Assert(bestActions.Any());
     47      return bestActions.SelectRandom(random);
    3648    }
    3749
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/HeuristicLab.Algorithms.Bandits.csproj

    r11744 r11747  
    4848    <Compile Include="BanditPolicies\BoltzmannExplorationPolicy.cs" />
    4949    <Compile Include="BanditPolicies\ChernoffIntervalEstimationPolicy.cs" />
     50    <Compile Include="BanditPolicies\ActiveLearningPolicy.cs" />
    5051    <Compile Include="BanditPolicies\DefaultPolicyActionInfo.cs" />
    5152    <Compile Include="BanditPolicies\EpsGreedyPolicy.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/IBanditPolicyActionInfo.cs

    r11742 r11747  
    1111    int Tries { get; }
    1212    void UpdateReward(double reward);
    13     void Disable();
     13    void Disable(double reward);
    1414    // reset causes the state of the action to be reinitialized (as after constructor-call)
    1515    void Reset();
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.Bandits/Models/GaussianMixtureModel.cs

    r11744 r11747  
    99namespace HeuristicLab.Algorithms.Bandits.Models {
    1010  public class GaussianMixtureModel : IModel {
    11     private readonly double[] componentMeans;
    12     private readonly double[] componentVars;
    13     private readonly double[] componentProbs;
     11    private double[] componentMeans;
     12    private double[] componentVars;
     13    private double[] componentProbs;
     14    private readonly List<double> allRewards = new List<double>();
    1415
    1516    private int numComponents;
     
    1718    public GaussianMixtureModel(int nComponents = 5) {
    1819      this.numComponents = nComponents;
    19       this.componentProbs = new double[nComponents];
    20       this.componentMeans = new double[nComponents];
    21       this.componentVars = new double[nComponents];
     20
     21      Reset();
    2222    }
    2323
     
    2929
    3030    public void Update(double reward) {
    31       // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
    32       throw new NotImplementedException();
     31      allRewards.Add(reward);
     32      throw new NotSupportedException("this does not yet work");
     33      if (allRewards.Count < 1000 && allRewards.Count % 10 == 0) {
     34        // see http://www.cs.toronto.edu/~mackay/itprnn/ps/302.320.pdf Algorithm 22.2 soft k-means
     35        Reset();
     36        for (int i = 0; i < 20; i++) {
     37          var responsibilities = allRewards.Select(r => CalcResponsibility(r)).ToArray();
     38
     39
     40          var sumWeightedRewards = new double[numComponents];
     41          var sumResponsibilities = new double[numComponents];
     42          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
     43            for (int k = 0; k < numComponents; k++) {
     44              sumWeightedRewards[k] += p.Item2[k] * p.Item1;
     45              sumResponsibilities[k] += p.Item2[k];
     46            }
     47          }
     48          for (int k = 0; k < numComponents; k++) {
     49            componentMeans[k] = sumWeightedRewards[k] / sumResponsibilities[k];
     50          }
     51
     52          sumWeightedRewards = new double[numComponents];
     53          foreach (var p in allRewards.Zip(responsibilities, Tuple.Create)) {
     54            for (int k = 0; k < numComponents; k++) {
     55              sumWeightedRewards[k] += p.Item2[k] * Math.Pow(p.Item1 - componentMeans[k], 2);
     56            }
     57          }
     58          for (int k = 0; k < numComponents; k++) {
     59            componentVars[k] = sumWeightedRewards[k] / sumResponsibilities[k];
     60            componentProbs[k] = sumResponsibilities[k] / sumResponsibilities.Sum();
     61          }
     62        }
     63      }
     64    }
     65
     66    private double[] CalcResponsibility(double r) {
     67      var res = new double[numComponents];
     68      for (int k = 0; k < numComponents; k++) {
     69        componentVars[k] = Math.Max(componentVars[k], 0.001);
     70        res[k] = componentProbs[k] * alglib.normaldistribution((r - componentMeans[k]) / Math.Sqrt(componentVars[k]));
     71        res[k] = Math.Max(res[k], 0.0001);
     72      }
     73      var sum = res.Sum();
     74      for (int k = 0; k < numComponents; k++) {
     75        res[k] /= sum;
     76      }
     77      return res;
    3378    }
    3479
     
    4489
    4590    public void Reset() {
    46       Array.Clear(componentMeans, 0, numComponents);
    47       Array.Clear(componentVars, 0, numComponents);
    48       Array.Clear(componentProbs, 0, numComponents);
     91      var rand = new Random();
     92      this.componentProbs = Enumerable.Range(0, numComponents).Select((_) => rand.NextDouble()).ToArray();
     93      var sum = componentProbs.Sum();
     94      for (int i = 0; i < componentProbs.Length; i++) componentProbs[i] /= sum;
     95      this.componentMeans = Enumerable.Range(0, numComponents).Select((_) => Rand.RandNormal(rand)).ToArray();
     96      this.componentVars = Enumerable.Range(0, numComponents).Select((_) => 0.01).ToArray();
    4997    }
    5098
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization.csproj

    r11744 r11747  
    4545    <Compile Include="AlternativesSampler.cs" />
    4646    <Compile Include="AlternativesContextSampler.cs" />
     47    <Compile Include="MctsQLearningSampler.cs" />
    4748    <Compile Include="TemporalDifferenceTreeSearchSampler.cs" />
    4849    <Compile Include="ExhaustiveRandomFirstSearch.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsContextualSampler.cs

    r11745 r11747  
    1515      public int randomTries;
    1616      public int tries;
     17      public List<TreeNode> parents;
    1718      public TreeNode[] children;
    1819      public bool done = false;
     
    2122        this.ident = id;
    2223        this.alt = alt;
     24        this.parents = new List<TreeNode>();
    2325      }
    2426
     
    2830    }
    2931
     32    private Dictionary<string, TreeNode> treeNodes;
     33    private TreeNode GetTreeNode(string id, ReadonlySequence alt) {
     34      TreeNode n;
     35      var canonicalId = problem.CanonicalRepresentation(id);
     36      if (!treeNodes.TryGetValue(canonicalId, out n)) {
     37        n = new TreeNode(canonicalId, alt);
     38        tries.TryGetValue(canonicalId, out n.tries);
     39        treeNodes[canonicalId] = n;
     40      }
     41      return n;
     42    }
    3043
    3144    public event Action<string, double> FoundNewBestSolution;
     
    5164      this.v = new Dictionary<string, double>(1000000);
    5265      this.tries = new Dictionary<string, int>(1000000);
     66      treeNodes = new Dictionary<string, TreeNode>();
    5367    }
    5468
     
    5771      InitPolicies(problem.Grammar);
    5872      for (int i = 0; !rootNode.done && i < maxIterations; i++) {
    59         var sentence = SampleSentence(problem.Grammar).ToString();
    60         var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
    61         Debug.Assert(quality >= 0 && quality <= 1.0);
    62         DistributeReward(quality);
    63 
    64         RaiseSolutionEvaluated(sentence, quality);
    65 
    66         if (quality > bestQuality) {
    67           bestQuality = quality;
    68           RaiseFoundNewBestSolution(sentence, quality);
     73        bool success;
     74        var sentence = SampleSentence(problem.Grammar, out success).ToString();
     75        if (success) {
     76          var quality = problem.Evaluate(sentence) / problem.BestKnownQuality(maxLen);
     77          Debug.Assert(quality >= 0 && quality <= 1.0);
     78          DistributeReward(quality);
     79
     80          RaiseSolutionEvaluated(sentence, quality);
     81
     82          if (quality > bestQuality) {
     83            bestQuality = quality;
     84            RaiseFoundNewBestSolution(sentence, quality);
     85          }
    6986        }
    7087      }
     
    7895      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, V(n), bestQuality);
    7996      while (n.children != null) {
    80         Console.WriteLine("{0}", n.ident);
    81         double maxVForRow = n.children.Select(ch => V(ch)).Max();
     97        Console.WriteLine("{0,-30}", n.ident);
     98        double maxVForRow = n.children.Select(ch => Math.Min(1.0, Math.Max(0.0, V(ch)))).Max();
    8299        if (maxVForRow == 0) maxVForRow = 1.0;
    83100
    84101        for (int i = 0; i < n.children.Length; i++) {
    85102          var ch = n.children[i];
    86           Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     103          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
    87104          Console.Write("{0,5}", ch.alt);
    88105        }
     
    90107        for (int i = 0; i < n.children.Length; i++) {
    91108          var ch = n.children[i];
    92           Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
    93           Console.Write("{0,5:F2}", V(ch) * 10);
     109          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
     110          Console.Write("{0,5:F2}", Math.Min(1.0, V(ch)) * 10);
    94111        }
    95112        Console.WriteLine();
    96113        for (int i = 0; i < n.children.Length; i++) {
    97114          var ch = n.children[i];
    98           Console.ForegroundColor = ConsoleEx.ColorForValue(V(ch) / maxVForRow);
     115          Console.ForegroundColor = ConsoleEx.ColorForValue(Math.Min(1.0, V(ch)) / maxVForRow);
    99116          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
    100117        }
     
    102119        Console.WriteLine();
    103120        //n.policy.PrintStats();
    104         n = n.children.Where(ch => !ch.done).OrderByDescending(c => V(c)).First();
     121        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.tries).First();
    105122      }
    106123    }
     
    112129      this.tries.Clear();
    113130
    114       rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
     131      rootNode = GetTreeNode(grammar.SentenceSymbol.ToString(), new ReadonlySequence("$"));
    115132      treeDepth = 0;
    116133      treeSize = 0;
    117134    }
    118135
    119     private Sequence SampleSentence(IGrammar grammar) {
     136    private Sequence SampleSentence(IGrammar grammar, out bool success) {
    120137      updateChain.Clear();
    121138      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
    122139      var startPhrase = new Sequence(grammar.SentenceSymbol);
    123       return CompleteSentence(grammar, startPhrase);
    124     }
    125 
    126     private Sequence CompleteSentence(IGrammar g, Sequence phrase) {
     140      return CompleteSentence(grammar, startPhrase, out success);
     141    }
     142
     143    private Sequence CompleteSentence(IGrammar g, Sequence phrase, out bool success) {
    127144      if (phrase.Length > maxLen) throw new ArgumentException();
    128145      if (g.MinPhraseLength(phrase) > maxLen) throw new ArgumentException();
     
    136153          n.randomTries++;
    137154          treeDepth = Math.Max(treeDepth, curDepth);
     155          success = true;
    138156          return g.CompleteSentenceRandomly(random, phrase, maxLen);
    139157        } else {
     
    153171              newPhrase.ReplaceAt(newPhrase.FirstNonTerminalIndex, 1, alt);
    154172              if (!newPhrase.IsTerminal) newPhrase = newPhrase.Subsequence(0, newPhrase.FirstNonTerminalIndex + 1);
    155               n.children[i++] = new TreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
     173              var treeNode = GetTreeNode(newPhrase.ToString(), new ReadonlySequence(alt));
     174              treeNode.parents.Add(n);
     175              n.children[i++] = treeNode;
    156176            }
    157177            treeSize += n.children.Length;
     178            UpdateDone(n);
     179
     180            // it could happend that we already finished all variations starting from the branch
     181            // stop
     182            if (n.done) {
     183              success = false;
     184              return phrase;
     185            }
    158186          }
     187          //int selectedAltIdx = SelectRandom(random, n.children);
     188
    159189          // => select using eps-greedy
    160190          int selectedAltIdx = SelectEpsGreedy(random, n.children);
     
    167197
    168198          curDepth++;
     199
    169200
    170201          // prepare for next iteration
    171202          parent = n;
    172203          n = n.children[selectedAltIdx];
     204          //UpdateTD(parent, n, 0.0);
    173205        }
    174206      } // while
     
    181213
    182214      treeDepth = Math.Max(treeDepth, curDepth);
     215      success = true;
    183216      return phrase;
    184217    }
    185218
     219
     220    //private void UpdateTD(TreeNode parent, TreeNode child, double reward) {
     221    //  double alpha = 1.0;
     222    //  var vParent = V(parent);
     223    //  var vChild = V(child);
     224    //  if (double.IsInfinity(vParent)) vParent = 0.0;
     225    //  if (double.IsInfinity(vChild)) vChild = 0.0;
     226    //  UpdateV(parent, (alpha * (reward + vChild - vParent)));
     227    //}
     228
    186229    private void DistributeReward(double reward) {
     230
    187231      // iterate in reverse order (bottom up)
    188       updateChain.Reverse();
    189 
     232      //updateChain.Reverse();
     233      UpdateDone(updateChain.Last().Item1);
     234      //UpdateTD(updateChain.Last().Item2, updateChain.Last().Item1, reward);
     235      //return;
     236
     237      BackPropReward(updateChain.Last().Item1, reward);
     238      /*
    190239      foreach (var e in updateChain) {
    191240        var node = e.Item1;
    192         var parent = e.Item2;
     241        //var parent = e.Item2;
    193242        node.tries++;
    194         if (node.children != null && node.children.All(c => c.done)) {
    195           node.done = true;
    196         }
     243        //if (node.children != null && node.children.All(c => c.done)) {
     244        //  node.done = true;
     245        //}
    197246        UpdateV(node, reward);
    198247
    199248        // the reward for the parent is either the just recieved reward or the value of the best action so far
    200         double value = 0.0;
    201         if (parent != null) {
    202           var doneChilds = parent.children.Where(ch => ch.done);
    203           if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
    204         }
     249        //double value = 0.0;
     250        //if (parent != null) {
     251        //  var doneChilds = parent.children.Where(ch => ch.done);
     252        //  if (doneChilds.Any()) value = doneChilds.Select(ch => V(ch)).Max();
     253        //}
    205254        //if (value > reward) reward = value;
    206       }
    207     }
     255      }*/
     256    }
     257
     258    private void BackPropReward(TreeNode n, double reward) {
     259      n.tries++;
     260      UpdateV(n, reward);
     261      foreach (var p in n.parents) BackPropReward(p, reward);
     262    }
     263
     264    private void UpdateDone(TreeNode n) {
     265      if (!n.done && n.children != null && n.children.All(c => c.done)) n.done = true;
     266      if (n.done) foreach (var p in n.parents) UpdateDone(p);
     267    }
     268
    208269
    209270    private Dictionary<string, double> v;
     
    219280        tries.Add(canonicalStr, 1);
    220281      } else {
    221         v[canonicalStr] = stateV + 0.005 * (reward - stateV);
    222         //v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
     282        //v[canonicalStr] = stateV + 0.005 * (reward - stateV);
     283        v[canonicalStr] = stateV + (1.0 / tries[canonicalStr]) * (reward - stateV);
    223284        tries[canonicalStr]++;
    224285      }
     
    229290      //var canonicalStr = n.ident;
    230291      double stateV;
    231 
     292      if (!tries.ContainsKey(canonicalStr)) return double.PositiveInfinity;
    232293      if (!v.TryGetValue(canonicalStr, out  stateV)) {
    233294        return 0.0;
     
    237298    }
    238299
     300    private int SelectRandom(Random random, TreeNode[] children) {
     301      return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
     302    }
     303
    239304    private int SelectEpsGreedy(Random random, TreeNode[] children) {
    240305      if (random.NextDouble() < 0.2) {
    241 
    242         return children.Select((ch, i) => Tuple.Create(ch, i)).Where(p => !p.Item1.done).SelectRandom(random).Item2;
     306        return SelectRandom(random, children);
    243307      } else {
    244308        var bestQ = double.NegativeInfinity;
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/MctsSampler.cs

    r11745 r11747  
    55using System.Text;
    66using HeuristicLab.Algorithms.Bandits;
     7using HeuristicLab.Common;
    78using HeuristicLab.Problems.GrammaticalOptimization;
    89
     
    1314      public int randomTries;
    1415      public IBanditPolicyActionInfo actionInfo;
     16      public TreeNode parent;
    1517      public TreeNode[] children;
    1618      public bool done = false;
    1719
    18       public TreeNode(string id) {
     20      public TreeNode(string id, TreeNode parent) {
    1921        this.ident = id;
     22        this.parent = parent;
    2023      }
    2124
     
    3538    private readonly IBanditPolicy policy;
    3639
    37     private List<TreeNode> updateChain;
     40    private TreeNode lastNode; // the bottom node in one episode
    3841    private TreeNode rootNode;
    3942
     
    7578      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.actionInfo.Tries, n.actionInfo.Value, bestQuality);
    7679      while (n.children != null) {
     80        Console.WriteLine("{0,-30}", n.ident);
     81        double maxVForRow = n.children.Select(ch => ch.actionInfo.Value).Max();
     82        if (maxVForRow == 0) maxVForRow = 1.0;
     83
     84        for (int i = 0; i < n.children.Length; i++) {
     85          var ch = n.children[i];
     86          SetColorForChild(ch, maxVForRow);
     87          Console.Write("{0,5}", ch.ident);
     88        }
    7789        Console.WriteLine();
    78         Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
    79         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.actionInfo.Value * 10))));
    80         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.actionInfo.Tries.ToString()))));
     90        for (int i = 0; i < n.children.Length; i++) {
     91          var ch = n.children[i];
     92          SetColorForChild(ch, maxVForRow);
     93          Console.Write("{0,5:F2}", ch.actionInfo.Value * 10);
     94        }
     95        Console.WriteLine();
     96        for (int i = 0; i < n.children.Length; i++) {
     97          var ch = n.children[i];
     98          SetColorForChild(ch, maxVForRow);
     99          Console.Write("{0,5}", ch.done ? "X" : ch.actionInfo.Tries.ToString());
     100        }
     101        Console.ForegroundColor = ConsoleColor.White;
     102        Console.WriteLine();
    81103        //n.policy.PrintStats();
    82         n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First();
    83       }
     104        //n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.actionInfo.Value).First();
     105        n = n.children.Where(ch=>!ch.done).OrderByDescending(c => c.actionInfo.Value).First();
     106      }
     107      Console.WriteLine("-----------------------");
     108    }
     109
     110    private void SetColorForChild(TreeNode ch, double maxVForRow) {
     111      //if (ch.done) Console.ForegroundColor = ConsoleColor.White;
     112      //else
     113      Console.ForegroundColor = ConsoleEx.ColorForValue(ch.actionInfo.Value / maxVForRow);
    84114    }
    85115
    86116    private void InitPolicies(IGrammar grammar) {
    87       this.updateChain = new List<TreeNode>();
    88 
    89       rootNode = new TreeNode(grammar.SentenceSymbol.ToString());
     117
     118
     119      rootNode = new TreeNode(grammar.SentenceSymbol.ToString(), null);
    90120      rootNode.actionInfo = policy.CreateActionInfo();
    91121      treeDepth = 0;
     
    94124
    95125    private Sequence SampleSentence(IGrammar grammar) {
    96       updateChain.Clear();
     126      lastNode = null;
    97127      var startPhrase = new Sequence(grammar.SentenceSymbol);
     128      //var startPhrase = new Sequence("a*b+c*d+e*f+E");
     129
    98130      return CompleteSentence(grammar, startPhrase);
    99131    }
     
    105137      var curDepth = 0;
    106138      while (!phrase.IsTerminal) {
    107         updateChain.Add(n);
    108139
    109140        if (n.randomTries < randomTries) {
    110141          n.randomTries++;
    111142          treeDepth = Math.Max(treeDepth, curDepth);
     143          lastNode = n;
    112144          return g.CompleteSentenceRandomly(random, phrase, maxLen);
    113145        } else {
     
    120152
    121153          if (n.randomTries == randomTries && n.children == null) {
    122             n.children = alts.Select(alt => new TreeNode(alt.ToString())).ToArray(); // create a new node for each alternative
     154            n.children = alts.Select(alt => new TreeNode(alt.ToString(), n)).ToArray(); // create a new node for each alternative
    123155            foreach (var ch in n.children) ch.actionInfo = policy.CreateActionInfo();
    124156            treeSize += n.children.Length;
     
    138170      } // while
    139171
    140       updateChain.Add(n);
     172      lastNode = n;
    141173
    142174
     
    150182    private void DistributeReward(double reward) {
    151183      // iterate in reverse order (bottom up)
    152       updateChain.Reverse();
    153 
    154       foreach (var e in updateChain) {
    155         var node = e;
    156         if (node.done) node.actionInfo.Disable();
     184
     185      var node = lastNode;
     186      while (node != null) {
     187        if (node.done) node.actionInfo.Disable(reward);
    157188        if (node.children != null && node.children.All(c => c.done)) {
    158189          node.done = true;
    159           node.actionInfo.Disable();
     190          var bestActionValue = node.children.Select(c => c.actionInfo.Value).Max();
     191          node.actionInfo.Disable(bestActionValue);
    160192        }
    161193        if (!node.done) {
    162194          node.actionInfo.UpdateReward(reward);
    163195        }
     196        node = node.parent;
    164197      }
    165198    }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Algorithms.GrammaticalOptimization/TemporalDifferenceTreeSearchSampler.cs

    r11744 r11747  
    3636    private readonly Random random;
    3737    private readonly int randomTries;
    38     private readonly IBanditPolicy policy;
    3938
    4039    private List<TreeNode> updateChain;
     
    4645
    4746
    48     public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries, IBanditPolicy policy) {
     47    public TemporalDifferenceTreeSearchSampler(IProblem problem, int maxLen, Random random, int randomTries) {
    4948      this.maxLen = maxLen;
    5049      this.problem = problem;
    5150      this.random = random;
    5251      this.randomTries = randomTries;
    53       this.policy = policy;
    5452    }
    5553
     
    7876      Console.WriteLine("depth: {0,5} size: {1,10} root tries {2,10}, rootQ {3:F3}, bestQ {4:F3}", treeDepth, treeSize, n.tries, n.q, bestQuality);
    7977      while (n.children != null) {
     78        Console.WriteLine("{0,-30}", n.ident);
     79        double maxVForRow = n.children.Select(ch => ch.q).Max();
     80        if (maxVForRow == 0) maxVForRow = 1.0;
     81
     82        for (int i = 0; i < n.children.Length; i++) {
     83          var ch = n.children[i];
     84          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     85          Console.Write("{0,5}", ch.ident);
     86        }
    8087        Console.WriteLine();
    81         Console.WriteLine("{0,5}->{1,-50}", n.ident, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.ident))));
    82         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4:F2}", ch.q * 10))));
    83         Console.WriteLine("{0,5}  {1,-50}", string.Empty, string.Join(" ", n.children.Select(ch => string.Format("{0,4}", ch.done ? "X" : ch.tries.ToString()))));
     88        for (int i = 0; i < n.children.Length; i++) {
     89          var ch = n.children[i];
     90          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     91          Console.Write("{0,5:F2}", ch.q * 10);
     92        }
     93        Console.WriteLine();
     94        for (int i = 0; i < n.children.Length; i++) {
     95          var ch = n.children[i];
     96          Console.ForegroundColor = ConsoleEx.ColorForValue(ch.q / maxVForRow);
     97          Console.Write("{0,5}", ch.done ? "X" : ch.tries.ToString());
     98        }
     99        Console.ForegroundColor = ConsoleColor.White;
     100        Console.WriteLine();
    84101        //n.policy.PrintStats();
    85102        n = n.children.Where(ch => !ch.done).OrderByDescending(c => c.q).First();
    86103      }
    87       //Console.ReadLine();
    88104    }
    89105
     
    127143          }
    128144          // => select using bandit policy
    129           int selectedAltIdx = SelectAction(random, n.children);
     145          int selectedAltIdx = SelectEpsGreedy(random, n.children);
    130146          Sequence selectedAlt = alts.ElementAt(selectedAltIdx);
    131147
     
    152168
    153169    // eps-greedy
    154     private int SelectAction(Random random, TreeNode[] children) {
     170    private int SelectEpsGreedy(Random random, TreeNode[] children) {
    155171      if (random.NextDouble() < 0.1) {
    156172
     
    158174      } else {
    159175        var bestQ = double.NegativeInfinity;
    160         var bestChildIdx = -1;
     176        var bestChildIdx = new List<int>();
    161177        for (int i = 0; i < children.Length; i++) {
    162178          if (children[i].done) continue;
    163           if (children[i].tries == 0) return i;
    164           if (children[i].q > bestQ) {
    165             bestQ = children[i].q;
    166             bestChildIdx = i;
     179          // if (children[i].tries == 0) return i;
     180          var q = children[i].q;
     181          if (q > bestQ) {
     182            bestQ = q;
     183            bestChildIdx.Clear();
     184            bestChildIdx.Add(i);
     185          } else if (q == bestQ) {
     186            bestChildIdx.Add(i);
    167187          }
    168188        }
    169         Debug.Assert(bestChildIdx > -1);
    170         return bestChildIdx;
     189        Debug.Assert(bestChildIdx.Any());
     190        return bestChildIdx.SelectRandom(random);
    171191      }
    172192    }
    173193
    174194    private void DistributeReward(double reward) {
    175       const double alpha = 0.1;
    176       const double gamma = 1;
    177       // iterate in reverse order (bottom up)
    178195      updateChain.Reverse();
    179       var nextQ = 0.0;
    180       foreach (var e in updateChain) {
    181         var node = e;
    182         node.tries++;
     196      foreach (var node in updateChain) {
    183197        if (node.children != null && node.children.All(c => c.done)) {
    184198          node.done = true;
    185199        }
    186         // reward is recieved only for the last action
    187         if (e == updateChain.First()) {
    188           node.q = node.q + alpha * (reward + gamma * nextQ - node.q);
    189           nextQ = node.q;
    190         } else {
    191           node.q = node.q + alpha * (0 + gamma * nextQ - node.q);
    192           nextQ = node.q;
    193         }
    194       }
     200      }
     201      updateChain.Reverse();
     202
     203      //const double alpha = 0.1;
     204      const double gamma = 1;
     205      double alpha;
     206      foreach (var p in updateChain.Zip(updateChain.Skip(1), Tuple.Create)) {
     207        var parent = p.Item1;
     208        var child = p.Item2;
     209        parent.tries++;
     210        alpha = 1.0 / parent.tries;
     211        //alpha = 0.01;
     212        parent.q = parent.q + alpha * (0 + gamma * child.q - parent.q);
     213      }
     214      // reward is recieved only for the last action
     215      var n = updateChain.Last();
     216      n.tries++;
     217      alpha = 1.0 / n.tries;
     218      //alpha = 0.1;
     219      n.q = n.q + alpha * reward;
    195220    }
    196221
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.SymbReg/SymbolicRegressionProblem.cs

    r11742 r11747  
    7575    // right now only + and * is supported
    7676    public string CanonicalRepresentation(string terminalPhrase) {
    77       return terminalPhrase;
    78       //var terms = terminalPhrase.Split('+');
    79       //return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
    80       //  .OrderBy(term => term));
     77      //return terminalPhrase;
     78      var terms = terminalPhrase.Split('+');
     79      return string.Join("+", terms.Select(term => string.Join("", term.Replace("*", "").OrderBy(ch => ch)))
     80        .OrderBy(term => term));
    8181    }
    8282  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.Test/TestInstances.cs

    r11730 r11747  
    256256      Assert.AreEqual(0.116199534934045, p.Evaluate("c*f*j"), 1.0E-7);
    257257
     258      Assert.AreEqual(0.824522210419616, p.Evaluate("a*b+c*d+e*f"), 1E-7);
     259
    258260
    259261      Assert.AreEqual(1.0, p.Evaluate("a*b+c*d+e*f+a*g*i+c*f*j"), 1.0E-7);
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization.csproj

    r11732 r11747  
    4343  </ItemGroup>
    4444  <ItemGroup>
     45    <Compile Include="RoyalPhraseSequenceProblem.cs" />
     46    <Compile Include="RoyalSequenceProblem.cs" />
    4547    <Compile Include="ExpressionInterpreter.cs" />
    4648    <Compile Include="Grammar.cs" />
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SantaFeAntProblem.cs

    r11742 r11747  
    9999
    100100    public string CanonicalRepresentation(string terminalPhrase) {
    101       return terminalPhrase.Replace("rl", "").Replace("lr", "");
     101      //return terminalPhrase;
     102      string oldPhrase;
     103      do {
     104        oldPhrase = terminalPhrase;
     105        terminalPhrase.Replace("ll", "rr").Replace("rl", "lr");
     106      } while (terminalPhrase != oldPhrase);
     107      return terminalPhrase;
    102108    }
    103109  }
  • branches/HeuristicLab.Problems.GrammaticalOptimization/HeuristicLab.Problems.GrammaticalOptimization/SymbolicRegressionPoly10Problem.cs

    r11745 r11747  
    1616    private const string grammarString = @"
    1717    G(E):
    18     E -> a | b | c | d | e | f | g | h | i | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | i+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | i*E | j*E
     18    E -> a | b | c | d | e | f | g | h | i | j | a+E | b+E | c+E | d+E | e+E | f+E | g+E | h+E | i+E | j+E | a*E | b*E | c*E | d*E | e*E | f*E | g*E | h*E | i*E | j*E 
    1919    ";
    2020
  • branches/HeuristicLab.Problems.GrammaticalOptimization/Main/Program.cs

    r11745 r11747  
    8888          Tuple.Create((IProblem)new SymbolicRegressionPoly10Problem(), 23),
    8989        })
    90         foreach (var randomTries in new int[] { 1, 10, /* 5, 100 /*, 500, 1000 */}) {
     90        foreach (var randomTries in new int[] { 0, 1, 10, /* 5, 100 /*, 500, 1000 */}) {
    9191          foreach (var policy in policies) {
    9292            var myRandomTries = randomTries;
     
    137137
    138138    private static void RunDemo() {
     139      // TODO: unify MCTS, TD and ContextMCTS Solvers (stateInfos)
    139140      // TODO: test with eps-greedy using max instead of average as value (seems to work well for symb-reg! explore further!)
    140141      // TODO: separate value function from policy
     
    165166      var random = new Random();
    166167
    167       var problem = new SymbolicRegressionPoly10Problem();   // good results e.g. 10 randomtries and EpsGreedyPolicy(0.2, (aInfo)=>aInfo.MaxReward)
     168      var phraseLen = 1;
     169      var sentenceLen = 25;
     170      var numPhrases = sentenceLen / phraseLen;
     171      var problem = new RoyalPhraseSequenceProblem(random, 10, numPhrases, phraseLen: 1, k: 1, correctReward: 1, incorrectReward: 0);
     172
     173      //var problem = new SymbolicRegressionPoly10Problem();   // good results e.g. 10 randomtries and EpsGreedyPolicy(0.2, (aInfo)=>aInfo.MaxReward)
    168174      // Ant
    169175      // good results e.g. with       var alg = new MctsSampler(problem, 17, random, 1, (rand, numActions) => new ThresholdAscentPolicy(numActions, 500, 0.01));
     
    175181      //var problem = new RoyalPairProblem();
    176182      //var problem = new EvenParityProblem();
    177       //var alg = new MctsSampler(problem, 23, random, 0, new GaussianThompsonSamplingPolicy(true));
    178       var alg = new MctsContextualSampler(problem, 23, random, 0);
    179       //var alg = new TemporalDifferenceTreeSearchSampler(problem, 17, random, 10, new EpsGreedyPolicy(0.1));
    180       //var alg = new ExhaustiveBreadthFirstSearch(problem, 17);
     183      // symbreg length = 11 q = 0.824522210419616
     184      var alg = new MctsSampler(problem, sentenceLen, random, 0, new BoltzmannExplorationPolicy(200));
     185      //var alg = new MctsQLearningSampler(problem, sentenceLen, random, 0, null);
     186      //var alg = new MctsQLearningSampler(problem, 30, random, 0, new EpsGreedyPolicy(0.2));
     187      //var alg = new MctsContextualSampler(problem, 23, random, 0); // must visit each canonical solution only once
     188      //var alg = new TemporalDifferenceTreeSearchSampler(problem, 30, random, 1);
     189      //var alg = new ExhaustiveBreadthFirstSearch(problem, 7);
    181190      //var alg = new AlternativesContextSampler(problem, random, 17, 4, (rand, numActions) => new RandomPolicy(rand, numActions));
    182191      //var alg = new ExhaustiveDepthFirstSearch(problem, 17);
    183192      // var alg = new AlternativesSampler(problem, 17);
    184193      // var alg = new RandomSearch(problem, random, 17);
    185       // var alg = new ExhaustiveRandomFirstSearch(problem, random, 17);
     194      //var alg = new ExhaustiveRandomFirstSearch(problem, random, 17);
    186195
    187196      alg.FoundNewBestSolution += (sentence, quality) => {
     
    199208          alg.PrintStats();
    200209        }
     210        //Console.WriteLine(sentence);
    201211
    202212        if (iterations % 10000 == 0) {
    203           //Console.WriteLine("{0,10} {1,10:F5} {2,10:F5} {3}", iterations, bestQuality, quality, sentence);
    204           //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
    205213          //Console.WriteLine("{0,4} {1,7} {2}", alg.treeDepth, alg.treeSize, globalStatistics);
    206214        }
Note: See TracChangeset for help on using the changeset viewer.