source: trunk/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.4/Creators/BalancedTreeCreator.cs @ 17345

Last change on this file since 17345 was 17345, checked in by bburlacu, 2 years ago

#3039: Small bugfix.

File size: 9.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.PluginInfrastructure;
29using HeuristicLab.Random;
30
namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
  /// <summary>
  /// Creates symbolic expression trees by growing them breadth-first (level by level) towards a randomly
  /// chosen target length, while respecting the grammar's allowed child symbols and the maximum tree depth.
  /// </summary>
  [NonDiscoverableType]
  [StorableType("AA3649C4-18CF-480B-AA41-F5D6F148B494")]
  [Item("BalancedTreeCreator", "An operator that produces trees with a specified distribution")]
  public class BalancedTreeCreator : SymbolicExpressionTreeCreator {
    [StorableConstructor]
    protected BalancedTreeCreator(StorableConstructorFlag _) : base(_) { }

    protected BalancedTreeCreator(BalancedTreeCreator original, Cloner cloner) : base(original, cloner) { }

    public BalancedTreeCreator() { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new BalancedTreeCreator(this, cloner);
    }

    public override ISymbolicExpressionTree CreateTree(IRandom random, ISymbolicExpressionGrammar grammar, int maxLength, int maxDepth) {
      return Create(random, grammar, maxLength, maxDepth);
    }

    /// <summary>
    /// Creates a random tree. The target length (counting the program root and start nodes) is drawn
    /// with <c>random.Next(3, maxLength)</c> and then handed to <see cref="CreateExpressionTree(IRandom, ISymbolicExpressionGrammar, int, int)"/>.
    /// </summary>
    /// <param name="random">Random number generator; must not be null.</param>
    /// <param name="grammar">Grammar that defines the allowed symbols and child relations.</param>
    /// <param name="maxLength">Upper bound passed to <c>random.Next</c> for the target length.</param>
    /// <param name="maxDepth">Maximum depth; nodes are only expanded while their depth is below <c>maxDepth - 1</c>.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="random"/> is null.</exception>
    public static ISymbolicExpressionTree Create(IRandom random, ISymbolicExpressionGrammar grammar, int maxLength, int maxDepth) {
      if (random == null) throw new ArgumentNullException(nameof(random));
      int targetLength = random.Next(3, maxLength); // because we have 2 extra nodes for the root and start symbols, and the end is exclusive
      return CreateExpressionTree(random, grammar, targetLength, maxDepth);
    }

    /// <summary>Per-symbol arity information cached from the grammar.</summary>
    private class SymbolCacheEntry {
      public int MinSubtreeCount;
      public int MaxSubtreeCount;
      // for each child index of this symbol: the largest MaxSubtreeCount among symbols allowed at that index
      public int[] MaxChildArity;
    }

    /// <summary>
    /// Caches the grammar queries (allowed child symbols per index, subtree count limits) needed while
    /// growing a tree, so they are computed once per grammar instead of once per node.
    /// </summary>
    private class SymbolCache {
      public SymbolCache(ISymbolicExpressionGrammar grammar) {
        Grammar = grammar;
      }

      /// <summary>
      /// Samples a new node for slot <paramref name="childIndex"/> of <paramref name="parent"/>, choosing among the
      /// allowed symbols proportionally to their <c>InitialFrequency</c>. Candidate symbols must have a subtree
      /// count range within [<paramref name="minArity"/>, <paramref name="maxArity"/>]. If no symbol satisfies the
      /// lower bound, the lower bound is relaxed to 0 so the slot can still be closed (e.g. with a terminal).
      /// </summary>
      /// <exception cref="ArgumentException">If no allowed symbol fits even after relaxing the lower bound.</exception>
      public ISymbolicExpressionTreeNode SampleNode(IRandom random, ISymbol parent, int childIndex, int minArity, int maxArity) {
        var symbols = new List<ISymbol>();
        var weights = new List<double>();
        CollectCandidates(parent, childIndex, minArity, maxArity, symbols, weights);
        if (symbols.Count == 0 && minArity > 0) {
          // Grammars mixing e.g. binary and ternary functions (without unary ones) can leave a budget of exactly
          // one extra slot that no function fits; previously this aborted the whole tree creation. Closing the
          // slot with a lower-arity symbol keeps the tree within the target length instead.
          CollectCandidates(parent, childIndex, 0, maxArity, symbols, weights);
        }
        if (symbols.Count == 0) {
          throw new ArgumentException("SampleNode: parent symbol " + parent.Name
            + " does not have any allowed child symbols with min arity " + minArity
            + " and max arity " + maxArity + ". Please ensure the grammar is properly configured.");
        }
        var symbol = symbols.SampleProportional(random, 1, weights).First();
        var node = symbol.CreateTreeNode();
        if (node.HasLocalParameters) {
          node.ResetLocalParameters(random);
        }
        return node;
      }

      // Appends every symbol (except StartSymbol/Defun) that is allowed at the given child index of the parent
      // and whose subtree count range lies within [minArity, maxArity], together with its sampling weight.
      private void CollectCandidates(ISymbol parent, int childIndex, int minArity, int maxArity, List<ISymbol> symbols, List<double> weights) {
        foreach (var child in AllowedSymbols) {
          if (child is StartSymbol || child is Defun) { continue; }
          if (!allowedCache.TryGetValue(Tuple.Create(parent, child), out bool[] allowed)) { continue; }
          if (!allowed[childIndex]) { continue; }

          if (symbolCache.TryGetValue(child, out SymbolCacheEntry cacheItem)) {
            if (cacheItem.MinSubtreeCount < minArity) { continue; }
            if (cacheItem.MaxSubtreeCount > maxArity) { continue; }
          }

          symbols.Add(child);
          weights.Add(child.InitialFrequency);
        }
      }

      /// <summary>The grammar backing this cache; assigning it rebuilds all cached data.</summary>
      public ISymbolicExpressionGrammar Grammar {
        get { return grammar; }
        set {
          grammar = value;
          RebuildCache();
        }
      }

      /// <summary>Grammar symbols with positive initial frequency, excluding the program root symbol.</summary>
      public IList<ISymbol> AllowedSymbols { get; private set; }

      public SymbolCacheEntry this[ISymbol symbol] {
        get { return symbolCache[symbol]; }
      }

      public bool[] this[ISymbol parent, ISymbol child] {
        get { return allowedCache[Tuple.Create(parent, child)]; }
      }

      /// <summary>True if some symbol other than StartSymbol/Defun has a maximum subtree count of exactly 1.</summary>
      public bool HasUnarySymbols { get; private set; }

      private void RebuildCache() {
        AllowedSymbols = Grammar.AllowedSymbols.Where(x => x.InitialFrequency > 0 && !(x is ProgramRootSymbol)).ToList();

        allowedCache = new Dictionary<Tuple<ISymbol, ISymbol>, bool[]>();
        symbolCache = new Dictionary<ISymbol, SymbolCacheEntry>();
        // reset, otherwise a previously assigned grammar's unary symbols would leak into this one
        HasUnarySymbols = false;

        SymbolCacheEntry TryAddItem(ISymbol symbol) {
          if (!symbolCache.TryGetValue(symbol, out SymbolCacheEntry cacheItem)) {
            cacheItem = new SymbolCacheEntry {
              MinSubtreeCount = Grammar.GetMinimumSubtreeCount(symbol),
              MaxSubtreeCount = Grammar.GetMaximumSubtreeCount(symbol)
            };
            symbolCache[symbol] = cacheItem;
          }
          return cacheItem;
        }

        foreach (var parent in AllowedSymbols) {
          var parentCacheEntry = TryAddItem(parent);
          var maxChildArity = new int[parentCacheEntry.MaxSubtreeCount];

          if (!(parent is StartSymbol || parent is Defun)) {
            HasUnarySymbols |= parentCacheEntry.MaxSubtreeCount == 1;
          }

          foreach (var child in AllowedSymbols) {
            var childCacheEntry = TryAddItem(child);
            var allowed = new bool[parentCacheEntry.MaxSubtreeCount];

            for (int childIndex = 0; childIndex < parentCacheEntry.MaxSubtreeCount; ++childIndex) {
              allowed[childIndex] = Grammar.IsAllowedChildSymbol(parent, child, childIndex);
              maxChildArity[childIndex] = Math.Max(maxChildArity[childIndex], allowed[childIndex] ? childCacheEntry.MaxSubtreeCount : 0);
            }
            allowedCache[Tuple.Create(parent, child)] = allowed;
          }
          parentCacheEntry.MaxChildArity = maxChildArity;
        }
      }

      private ISymbolicExpressionGrammar grammar;
      private Dictionary<Tuple<ISymbol, ISymbol>, bool[]> allowedCache;
      private Dictionary<ISymbol, SymbolCacheEntry> symbolCache;
    }

    /// <summary>
    /// Creates a tree whose length (including the program root and start nodes) is at most
    /// <paramref name="targetLength"/> after the adjustments described below.
    /// </summary>
    /// <param name="random">Random number generator; must not be null.</param>
    /// <param name="grammar">Grammar that defines the allowed symbols; must not be null.</param>
    /// <param name="targetLength">Desired tree length. Values below 3 (root + start + one node) are raised to 3.</param>
    /// <param name="maxDepth">Maximum depth; nodes are only expanded while their depth is below <c>maxDepth - 1</c>.</param>
    /// <exception cref="ArgumentNullException">If <paramref name="random"/> or <paramref name="grammar"/> is null.</exception>
    public static ISymbolicExpressionTree CreateExpressionTree(IRandom random, ISymbolicExpressionGrammar grammar, int targetLength, int maxDepth) {
      if (random == null) throw new ArgumentNullException(nameof(random));
      if (grammar == null) throw new ArgumentNullException(nameof(grammar));

      // the smallest constructible tree is root + start + one node; shorter targets used to crash
      // (negative list capacity / negative remaining length)
      targetLength = Math.Max(3, targetLength);

      var symbolCache = new SymbolCache(grammar);
      // when every function symbol has even arity, only odd total lengths are reachable; lacking unary
      // symbols we therefore randomly pick a neighbouring odd length (never below 3, since even values are >= 4)
      if (!symbolCache.HasUnarySymbols && targetLength % 2 == 0) {
        targetLength += random.NextDouble() < 0.5 ? -1 : +1;
      }
      return CreateExpressionTree(random, symbolCache, targetLength, maxDepth);
    }

    // Grows the tree breadth-first: every entry in 'tuples' is a node whose child slots are filled in order.
    // Each slot becomes exactly one node, so the final tree length is 2 + createdSlots <= targetLength.
    private static ISymbolicExpressionTree CreateExpressionTree(IRandom random, SymbolCache symbolCache, int targetLength, int maxDepth) {
      var tree = MakeStump(random, symbolCache.Grammar);
      var tuples = new List<NodeInfo>(targetLength) {
        new NodeInfo { Node = tree.Root, Depth = 0, Arity = 1 },
        new NodeInfo { Node = tree.Root.GetSubtree(0), Depth = 1, Arity = 1 }
      };
      int remainingLength = targetLength - 2; // -2 because we already have a root and start node
      int createdSlots = 1; // child slots created so far (filled or pending); the start node has arity 1

      for (int i = 1; i < tuples.Count; ++i) {
        var t = tuples[i];
        var node = t.Node;
        var parentEntry = symbolCache[node.Symbol];

        for (int childIndex = 0; childIndex < t.Arity; ++childIndex) {
          // min and max arity here refer to the required arity limits for the child node;
          // '>=' (instead of '==') also bounds the depth when maxDepth <= 1
          int maxChildArity = t.Depth >= maxDepth - 1
            ? 0
            : Math.Max(0, Math.Min(parentEntry.MaxChildArity[childIndex], remainingLength - createdSlots));
          int minChildArity = Math.Min(1, maxChildArity);
          var child = symbolCache.SampleNode(random, node.Symbol, childIndex, minChildArity, maxChildArity);
          var childEntry = symbolCache[child.Symbol];
          var childArity = random.Next(childEntry.MinSubtreeCount, childEntry.MaxSubtreeCount + 1);
          node.AddSubtree(child);
          tuples.Add(new NodeInfo { Node = child, Depth = t.Depth + 1, Arity = childArity });
          createdSlots += childArity;
        }
      }
      return tree;
    }

    protected override ISymbolicExpressionTree Create(IRandom random) {
      var maxLength = MaximumSymbolicExpressionTreeLengthParameter.ActualValue.Value;
      var maxDepth = MaximumSymbolicExpressionTreeDepthParameter.ActualValue.Value;
      var grammar = ClonedSymbolicExpressionTreeGrammarParameter.ActualValue;
      return Create(random, grammar, maxLength, maxDepth);
    }

    #region helpers
    // A node awaiting children, its depth (root = 0) and the number of children it will receive.
    private class NodeInfo {
      public ISymbolicExpressionTreeNode Node;
      public int Depth;
      public int Arity;
    }

    // Builds the fixed top of every tree: the program root node with a single start node below it.
    private static ISymbolicExpressionTree MakeStump(IRandom random, ISymbolicExpressionGrammar grammar) {
      SymbolicExpressionTree tree = new SymbolicExpressionTree();
      var rootNode = (SymbolicExpressionTreeTopLevelNode)grammar.ProgramRootSymbol.CreateTreeNode();
      if (rootNode.HasLocalParameters) rootNode.ResetLocalParameters(random);
      rootNode.SetGrammar(grammar.CreateExpressionTreeGrammar());

      var startNode = (SymbolicExpressionTreeTopLevelNode)grammar.StartSymbol.CreateTreeNode();
      if (startNode.HasLocalParameters) startNode.ResetLocalParameters(random);
      startNode.SetGrammar(grammar.CreateExpressionTreeGrammar());

      rootNode.AddSubtree(startNode);
      tree.Root = rootNode;
      return tree;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.