source: trunk/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.4/Creators/BalancedTreeCreator.cs @ 17437

Last change on this file since 17437 was 17437, checked in by bburlacu, 20 months ago

#3039: Refactor code and fix bug in sampling random child symbols from the grammar.

File size: 10.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30using HeuristicLab.PluginInfrastructure;
31using HeuristicLab.Random;
32
33namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
34  [NonDiscoverableType]
35  [StorableType("AA3649C4-18CF-480B-AA41-F5D6F148B494")]
36  [Item("BalancedTreeCreator", "An operator that produces trees with a specified distribution")]
37  public class BalancedTreeCreator : SymbolicExpressionTreeCreator {
38    private const string IrregularityBiasParameterName = "IrregularityBias";
39
40    public IFixedValueParameter<PercentValue> IrregularityBiasParameter {
41      get { return (IFixedValueParameter<PercentValue>)Parameters[IrregularityBiasParameterName]; }
42    }
43
44    public double IrregularityBias {
45      get { return IrregularityBiasParameter.Value.Value; }
46      set { IrregularityBiasParameter.Value.Value = value; }
47    }
48
49    [StorableConstructor]
50    protected BalancedTreeCreator(StorableConstructorFlag _) : base(_) { }
51
52    [StorableHook(HookType.AfterDeserialization)]
53    private void AfterDeserialization() {
54      if (!Parameters.ContainsKey(IrregularityBiasParameterName)) {
55        Parameters.Add(new FixedValueParameter<PercentValue>(IrregularityBiasParameterName, new PercentValue(0.0)));
56      }
57    }
58
59    protected BalancedTreeCreator(BalancedTreeCreator original, Cloner cloner) : base(original, cloner) { }
60
61    public BalancedTreeCreator() {
62      Parameters.Add(new FixedValueParameter<PercentValue>(IrregularityBiasParameterName, new PercentValue(0.0)));
63    }
64
65    public override IDeepCloneable Clone(Cloner cloner) {
66      return new BalancedTreeCreator(this, cloner);
67    }
68
69    public override ISymbolicExpressionTree CreateTree(IRandom random, ISymbolicExpressionGrammar grammar, int maxLength, int maxDepth) {
70      return Create(random, grammar, maxLength, maxDepth, IrregularityBias);
71    }
72
73    public static ISymbolicExpressionTree Create(IRandom random, ISymbolicExpressionGrammar grammar, int maxLength, int maxDepth, double irregularityBias = 0) {
74      int targetLength = random.Next(3, maxLength); // because we have 2 extra nodes for the root and start symbols, and the end is exclusive
75      return CreateExpressionTree(random, grammar, targetLength, maxDepth, irregularityBias);
76    }
77
78    public static ISymbolicExpressionTree CreateExpressionTree(IRandom random, ISymbolicExpressionGrammar grammar, int targetLength, int maxDepth, double irregularityBias = 1) {
79      // even lengths cannot be achieved without symbols of odd arity
80      // therefore we randomly pick a neighbouring odd length value
81      var tree = MakeStump(random, grammar); // create a stump consisting of just a ProgramRootSymbol and a StartSymbol
82      CreateExpression(random, tree.Root.GetSubtree(0), targetLength - tree.Length, maxDepth - 2, irregularityBias); // -2 because the stump has length 2 and depth 2
83      return tree;
84    }
85
86    private static ISymbolicExpressionTreeNode SampleNode(IRandom random, ISymbolicExpressionTreeGrammar grammar, IEnumerable<ISymbol> allowedSymbols, int minChildArity, int maxChildArity) {
87      var candidates = new List<ISymbol>();
88      var weights = new List<double>();
89
90      foreach (var s in allowedSymbols) {
91        var minSubtreeCount = grammar.GetMinimumSubtreeCount(s);
92        var maxSubtreeCount = grammar.GetMaximumSubtreeCount(s);
93
94        if (maxChildArity < minSubtreeCount || minChildArity > maxSubtreeCount) { continue; }
95
96        candidates.Add(s);
97        weights.Add(s.InitialFrequency);
98      }
99      var symbol = candidates.SampleProportional(random, 1, weights).First();
100      var node = symbol.CreateTreeNode();
101      if (node.HasLocalParameters) {
102        node.ResetLocalParameters(random);
103      }
104      return node;
105    }
106
107    public static void CreateExpression(IRandom random, ISymbolicExpressionTreeNode root, int targetLength, int maxDepth, double irregularityBias = 1) {
108      var grammar = root.Grammar;
109      var minSubtreeCount = grammar.GetMinimumSubtreeCount(root.Symbol);
110      var maxSubtreeCount = grammar.GetMinimumSubtreeCount(root.Symbol);
111      var arity = random.Next(minSubtreeCount, maxSubtreeCount + 1);
112      int openSlots = arity;
113
114      var allowedSymbols = grammar.AllowedSymbols.Where(x => !(x is ProgramRootSymbol || x is GroupSymbol || x is Defun || x is StartSymbol)).ToList();
115      bool hasUnarySymbols = allowedSymbols.Any(x => grammar.GetMinimumSubtreeCount(x) <= 1 && grammar.GetMaximumSubtreeCount(x) >= 1);
116
117      if (!hasUnarySymbols && targetLength % 2 == 0) {
118        // without functions of arity 1 some target lengths cannot be reached
119        targetLength = random.NextDouble() < 0.5 ? targetLength - 1 : targetLength + 1;
120      }
121
122      var tuples = new List<NodeInfo>(targetLength) { new NodeInfo { Node = root, Depth = 0, Arity = arity } };
123
124      // we use tuples.Count instead of targetLength in the if condition
125      // because depth limits may prevent reaching the target length
126      for (int i = 0; i < tuples.Count; ++i) {
127        var t = tuples[i];
128        var node = t.Node;
129
130        for (int childIndex = 0; childIndex < t.Arity; ++childIndex) {
131          // min and max arity here refer to the required arity limits for the child node
132          int minChildArity = 0;
133          int maxChildArity = 0;
134
135          var allowedChildSymbols = allowedSymbols.Where(x => grammar.IsAllowedChildSymbol(node.Symbol, x, childIndex)).ToList();
136
137          // if we are reaching max depth we have to fill the slot with a leaf node (max arity will be zero)
138          // otherwise, find the maximum value from the grammar which does not exceed the length limit
139          if (t.Depth < maxDepth - 1 && openSlots < targetLength) {
140
141            // we don't want to allow sampling a leaf symbol if it prevents us from reaching the target length
142            // this should be allowed only when we have enough open expansion points (more than one)
143            // the random check against the irregularity bias helps to increase shape variability when the conditions are met
144            int minAllowedArity = allowedChildSymbols.Min(x => grammar.GetMaximumSubtreeCount(x));
145            if (minAllowedArity == 0 && (openSlots - tuples.Count <= 1 || random.NextDouble() > irregularityBias)) {
146              minAllowedArity = 1;
147            }
148
149            // finally adjust min and max arity according to the expansion limits
150            int maxAllowedArity = allowedChildSymbols.Max(x => grammar.GetMaximumSubtreeCount(x));
151            maxChildArity = Math.Min(maxAllowedArity, targetLength - openSlots);
152            minChildArity = Math.Min(minAllowedArity, maxChildArity);
153          }
154         
155          // sample a random child with the arity limits
156          var child = SampleNode(random, grammar, allowedChildSymbols, minChildArity, maxChildArity);
157
158          // get actual child arity limits
159          minChildArity = Math.Max(minChildArity, grammar.GetMinimumSubtreeCount(child.Symbol));
160          maxChildArity = Math.Min(maxChildArity, grammar.GetMaximumSubtreeCount(child.Symbol));
161          minChildArity = Math.Min(minChildArity, maxChildArity);
162
163          // pick a random arity for the new child node
164          var childArity = random.Next(minChildArity, maxChildArity + 1);
165          var childDepth = t.Depth + 1;
166          node.AddSubtree(child);
167          tuples.Add(new NodeInfo { Node = child, Depth = childDepth, Arity = childArity });
168          openSlots += childArity;
169        }
170      }
171    }
172
173    protected override ISymbolicExpressionTree Create(IRandom random) {
174      var maxLength = MaximumSymbolicExpressionTreeLengthParameter.ActualValue.Value;
175      var maxDepth = MaximumSymbolicExpressionTreeDepthParameter.ActualValue.Value;
176      var grammar = ClonedSymbolicExpressionTreeGrammarParameter.ActualValue;
177      return Create(random, grammar, maxLength, maxDepth);
178    }
179
180    #region helpers
181    private class NodeInfo {
182      public ISymbolicExpressionTreeNode Node;
183      public int Depth;
184      public int Arity;
185    }
186
187    private static ISymbolicExpressionTree MakeStump(IRandom random, ISymbolicExpressionGrammar grammar) {
188      SymbolicExpressionTree tree = new SymbolicExpressionTree();
189      var rootNode = (SymbolicExpressionTreeTopLevelNode)grammar.ProgramRootSymbol.CreateTreeNode();
190      if (rootNode.HasLocalParameters) rootNode.ResetLocalParameters(random);
191      rootNode.SetGrammar(grammar.CreateExpressionTreeGrammar());
192
193      var startNode = (SymbolicExpressionTreeTopLevelNode)grammar.StartSymbol.CreateTreeNode();
194      if (startNode.HasLocalParameters) startNode.ResetLocalParameters(random);
195      startNode.SetGrammar(grammar.CreateExpressionTreeGrammar());
196
197      rootNode.AddSubtree(startNode);
198      tree.Root = rootNode;
199      return tree;
200    }
201
202    public override void CreateExpression(IRandom random, ISymbolicExpressionTreeNode seedNode, int maxTreeLength, int maxTreeDepth) {
203      CreateExpression(random, seedNode, maxTreeLength, maxTreeDepth, IrregularityBias);
204    }
205    #endregion
206  }
207}
Note: See TracBrowser for help on using the repository browser.