Ignore:
Timestamp:
02/17/20 10:18:53 (20 months ago)
Author:
bburlacu
Message:

#3039: Refactor code and fix bug in sampling random child symbols from the grammar.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.4/Creators/BalancedTreeCreator.cs

    r17347 r17437  
    7676    }
    7777
    78     private class SymbolCacheEntry {
    79       public int MinSubtreeCount;
    80       public int MaxSubtreeCount;
    81       public int[] MaxChildArity;
    82     }
    83 
    84     private class SymbolCache {
    85       public SymbolCache(ISymbolicExpressionGrammarBase grammar) {
    86         Grammar = grammar;
    87       }
    88 
    89       public ISymbolicExpressionTreeNode SampleNode(IRandom random, ISymbol parent, int childIndex, int minArity, int maxArity) {
    90         var symbols = new List<ISymbol>();
    91         var weights = new List<double>();
    92         foreach (var child in AllowedSymbols.Where(x => !(x is StartSymbol || x is Defun))) {
    93           var t = Tuple.Create(parent, child);
    94           if (!allowedCache.TryGetValue(t, out bool[] allowed)) { continue; }
    95           if (!allowed[childIndex]) { continue; }
    96 
    97           if (symbolCache.TryGetValue(child, out SymbolCacheEntry cacheItem)) {
    98             if (cacheItem.MinSubtreeCount < minArity) { continue; }
    99             if (cacheItem.MaxSubtreeCount > maxArity) { continue; }
    100           }
    101 
    102           symbols.Add(child);
    103           weights.Add(child.InitialFrequency);
    104         }
    105         if (symbols.Count == 0) {
    106           throw new ArgumentException("SampleNode: parent symbol " + parent.Name
    107             + " does not have any allowed child symbols with min arity " + minArity
    108             + " and max arity " + maxArity + ". Please ensure the grammar is properly configured.");
    109         }
    110         var symbol = symbols.SampleProportional(random, 1, weights).First();
    111         var node = symbol.CreateTreeNode();
    112         if (node.HasLocalParameters) {
    113           node.ResetLocalParameters(random);
    114         }
    115         return node;
    116       }
    117 
    118       public ISymbolicExpressionGrammarBase Grammar {
    119         get { return grammar; }
    120         set {
    121           grammar = value;
    122           RebuildCache();
    123         }
    124       }
    125 
    126       public IList<ISymbol> AllowedSymbols { get; private set; }
    127 
    128       public SymbolCacheEntry this[ISymbol symbol] {
    129         get { return symbolCache[symbol]; }
    130       }
    131 
    132       public bool[] this[ISymbol parent, ISymbol child] {
    133         get { return allowedCache[Tuple.Create(parent, child)]; }
    134       }
    135 
    136       public bool HasUnarySymbols { get; private set; }
    137 
    138       private void RebuildCache() {
    139         AllowedSymbols = Grammar.AllowedSymbols.Where(x => x.InitialFrequency > 0 && !(x is ProgramRootSymbol)).ToList();
    140 
    141         allowedCache = new Dictionary<Tuple<ISymbol, ISymbol>, bool[]>();
    142         symbolCache = new Dictionary<ISymbol, SymbolCacheEntry>();
    143 
    144         SymbolCacheEntry TryAddItem(ISymbol symbol) {
    145           if (!symbolCache.TryGetValue(symbol, out SymbolCacheEntry cacheItem)) {
    146             cacheItem = new SymbolCacheEntry {
    147               MinSubtreeCount = Grammar.GetMinimumSubtreeCount(symbol),
    148               MaxSubtreeCount = Grammar.GetMaximumSubtreeCount(symbol)
    149             };
    150             symbolCache[symbol] = cacheItem;
    151           }
    152           return cacheItem;
    153         }
    154 
    155         foreach (var parent in AllowedSymbols) {
    156           var parentCacheEntry = TryAddItem(parent);
    157           var maxChildArity = new int[parentCacheEntry.MaxSubtreeCount];
    158 
    159           if (!(parent is StartSymbol || parent is Defun)) {
    160             HasUnarySymbols |= parentCacheEntry.MaxSubtreeCount == 1;
    161           }
    162 
    163           foreach (var child in AllowedSymbols) {
    164             var childCacheEntry = TryAddItem(child);
    165             var allowed = new bool[parentCacheEntry.MaxSubtreeCount];
    166 
    167             for (int childIndex = 0; childIndex < parentCacheEntry.MaxSubtreeCount; ++childIndex) {
    168               allowed[childIndex] = Grammar.IsAllowedChildSymbol(parent, child, childIndex);
    169               maxChildArity[childIndex] = Math.Max(maxChildArity[childIndex], allowed[childIndex] ? childCacheEntry.MaxSubtreeCount : 0);
    170             }
    171             allowedCache[Tuple.Create(parent, child)] = allowed;
    172           }
    173           parentCacheEntry.MaxChildArity = maxChildArity;
    174         }
    175       }
    176 
    177       private ISymbolicExpressionGrammarBase grammar;
    178       private Dictionary<Tuple<ISymbol, ISymbol>, bool[]> allowedCache;
    179       private Dictionary<ISymbol, SymbolCacheEntry> symbolCache;
    180     }
    181 
    18278    public static ISymbolicExpressionTree CreateExpressionTree(IRandom random, ISymbolicExpressionGrammar grammar, int targetLength, int maxDepth, double irregularityBias = 1) {
    18379      // even lengths cannot be achieved without symbols of odd arity
    18480      // therefore we randomly pick a neighbouring odd length value
    18581      var tree = MakeStump(random, grammar); // create a stump consisting of just a ProgramRootSymbol and a StartSymbol
    186       CreateExpression(random, tree.Root.GetSubtree(0), targetLength - 2, maxDepth - 2, irregularityBias); // -2 because the stump has length 2 and depth 2
     82      CreateExpression(random, tree.Root.GetSubtree(0), targetLength - tree.Length, maxDepth - 2, irregularityBias); // -2 because the stump has length 2 and depth 2
    18783      return tree;
     84    }
     85
     86    private static ISymbolicExpressionTreeNode SampleNode(IRandom random, ISymbolicExpressionTreeGrammar grammar, IEnumerable<ISymbol> allowedSymbols, int minChildArity, int maxChildArity) {
     87      var candidates = new List<ISymbol>();
     88      var weights = new List<double>();
     89
     90      foreach (var s in allowedSymbols) {
     91        var minSubtreeCount = grammar.GetMinimumSubtreeCount(s);
     92        var maxSubtreeCount = grammar.GetMaximumSubtreeCount(s);
     93
     94        if (maxChildArity < minSubtreeCount || minChildArity > maxSubtreeCount) { continue; }
     95
     96        candidates.Add(s);
     97        weights.Add(s.InitialFrequency);
     98      }
     99      var symbol = candidates.SampleProportional(random, 1, weights).First();
     100      var node = symbol.CreateTreeNode();
     101      if (node.HasLocalParameters) {
     102        node.ResetLocalParameters(random);
     103      }
     104      return node;
    188105    }
    189106
    190107    public static void CreateExpression(IRandom random, ISymbolicExpressionTreeNode root, int targetLength, int maxDepth, double irregularityBias = 1) {
    191108      var grammar = root.Grammar;
    192       var symbolCache = new SymbolCache(grammar);
    193       var entry = symbolCache[root.Symbol];
    194       var arity = random.Next(entry.MinSubtreeCount, entry.MaxSubtreeCount + 1);
     109      var minSubtreeCount = grammar.GetMinimumSubtreeCount(root.Symbol);
     110      var maxSubtreeCount = grammar.GetMinimumSubtreeCount(root.Symbol);
     111      var arity = random.Next(minSubtreeCount, maxSubtreeCount + 1);
     112      int openSlots = arity;
     113
     114      var allowedSymbols = grammar.AllowedSymbols.Where(x => !(x is ProgramRootSymbol || x is GroupSymbol || x is Defun || x is StartSymbol)).ToList();
     115      bool hasUnarySymbols = allowedSymbols.Any(x => grammar.GetMinimumSubtreeCount(x) <= 1 && grammar.GetMaximumSubtreeCount(x) >= 1);
     116
     117      if (!hasUnarySymbols && targetLength % 2 == 0) {
     118        // without functions of arity 1 some target lengths cannot be reached
     119        targetLength = random.NextDouble() < 0.5 ? targetLength - 1 : targetLength + 1;
     120      }
     121
    195122      var tuples = new List<NodeInfo>(targetLength) { new NodeInfo { Node = root, Depth = 0, Arity = arity } };
    196       int openSlots = arity;
    197 
     123
     124      // we use tuples.Count instead of targetLength in the if condition
     125      // because depth limits may prevent reaching the target length
    198126      for (int i = 0; i < tuples.Count; ++i) {
    199127        var t = tuples[i];
    200128        var node = t.Node;
    201         var parentEntry = symbolCache[node.Symbol];
    202129
    203130        for (int childIndex = 0; childIndex < t.Arity; ++childIndex) {
    204131          // min and max arity here refer to the required arity limits for the child node
    205           int maxChildArity = t.Depth == maxDepth - 1 ? 0 : Math.Min(parentEntry.MaxChildArity[childIndex], targetLength - openSlots);
    206           int minChildArity = Math.Min((openSlots - tuples.Count > 1 && random.NextDouble() < irregularityBias) ? 0 : 1, maxChildArity);
    207           var child = symbolCache.SampleNode(random, node.Symbol, childIndex, minChildArity, maxChildArity);
    208           var childEntry = symbolCache[child.Symbol];
    209           var childArity = random.Next(childEntry.MinSubtreeCount, childEntry.MaxSubtreeCount + 1);
     132          int minChildArity = 0;
     133          int maxChildArity = 0;
     134
     135          var allowedChildSymbols = allowedSymbols.Where(x => grammar.IsAllowedChildSymbol(node.Symbol, x, childIndex)).ToList();
     136
     137          // if we are reaching max depth we have to fill the slot with a leaf node (max arity will be zero)
     138          // otherwise, find the maximum value from the grammar which does not exceed the length limit
     139          if (t.Depth < maxDepth - 1 && openSlots < targetLength) {
     140
     141            // we don't want to allow sampling a leaf symbol if it prevents us from reaching the target length
     142            // this should be allowed only when we have enough open expansion points (more than one)
     143            // the random check against the irregularity bias helps to increase shape variability when the conditions are met
     144            int minAllowedArity = allowedChildSymbols.Min(x => grammar.GetMaximumSubtreeCount(x));
     145            if (minAllowedArity == 0 && (openSlots - tuples.Count <= 1 || random.NextDouble() > irregularityBias)) {
     146              minAllowedArity = 1;
     147            }
     148
     149            // finally adjust min and max arity according to the expansion limits
     150            int maxAllowedArity = allowedChildSymbols.Max(x => grammar.GetMaximumSubtreeCount(x));
     151            maxChildArity = Math.Min(maxAllowedArity, targetLength - openSlots);
     152            minChildArity = Math.Min(minAllowedArity, maxChildArity);
     153          }
     154         
     155          // sample a random child with the arity limits
     156          var child = SampleNode(random, grammar, allowedChildSymbols, minChildArity, maxChildArity);
     157
     158          // get actual child arity limits
     159          minChildArity = Math.Max(minChildArity, grammar.GetMinimumSubtreeCount(child.Symbol));
     160          maxChildArity = Math.Min(maxChildArity, grammar.GetMaximumSubtreeCount(child.Symbol));
     161          minChildArity = Math.Min(minChildArity, maxChildArity);
     162
     163          // pick a random arity for the new child node
     164          var childArity = random.Next(minChildArity, maxChildArity + 1);
    210165          var childDepth = t.Depth + 1;
    211166          node.AddSubtree(child);
     
    244199      return tree;
    245200    }
     201
     202    public override void CreateExpression(IRandom random, ISymbolicExpressionTreeNode seedNode, int maxTreeLength, int maxTreeDepth) {
     203      CreateExpression(random, seedNode, maxTreeLength, maxTreeDepth, IrregularityBias);
     204    }
    246205    #endregion
    247206  }
Note: See TracChangeset for help on using the changeset viewer.