Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.GP/3.3/TreeGardener.cs @ 2594

Last change on this file since 2594 was 2566, checked in by gkronber, 15 years ago

Added new predefined function libraries for symbolic regression algorithms. Changed CEDMA dispatcher to choose a function library randomly. #813 (GP structure-identification algorithms that use only a simple function library)

File size: 20.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Core;
25using System.Linq;
26using System.Collections;
27using HeuristicLab.GP.Interfaces;
28
29namespace HeuristicLab.GP {
30  public class TreeGardener {
31    private IRandom random;
32    private FunctionLibrary funLibrary;
33    private List<IFunction> functions;
34
35    private List<IFunction> terminals;
36    public IList<IFunction> Terminals {
37      get { return terminals; }
38    }
39
40    private List<IFunction> allFunctions;
41    public IList<IFunction> AllFunctions {
42      get { return allFunctions; }
43    }
44
45    #region constructors
46    public TreeGardener(IRandom random, FunctionLibrary funLibrary) {
47      this.random = random;
48      this.funLibrary = funLibrary;
49      this.allFunctions = new List<IFunction>();
50      terminals = new List<IFunction>();
51      functions = new List<IFunction>();
52      // init functions and terminals based on constraints
53      foreach (IFunction fun in funLibrary.Functions) {
54        if (fun.MaxSubTrees == 0) {
55          terminals.Add(fun);
56          allFunctions.Add(fun);
57        } else {
58          functions.Add(fun);
59          allFunctions.Add(fun);
60        }
61      }
62    }
63    #endregion
64
65    #region random initialization
66    /// <summary>
67    /// Creates a random balanced tree with a maximal size and height. When the max-height or max-size are 1 it will return a random terminal.
68    /// In other cases it will return either a terminal (tree of size 1) or any other tree with a function in it's root (at least height 2).
69    /// </summary>
70    /// <param name="maxTreeSize">Maximal size of the tree (number of nodes).</param>
71    /// <param name="maxTreeHeight">Maximal height of the tree.</param>
72    /// <returns></returns>
73    public IFunctionTree CreateBalancedRandomTree(int maxTreeSize, int maxTreeHeight) {
74      IFunction rootFunction = GetRandomRoot(maxTreeSize, maxTreeHeight);
75      IFunctionTree tree = MakeBalancedTree(rootFunction, maxTreeHeight - 1);
76      return tree;
77    }
78
79    /// <summary>
80    /// Creates a random (unbalanced) tree with a maximal size and height. When the max-height or max-size are 1 it will return a random terminal.
81    /// In other cases it will return either a terminal (tree of size 1) or any other tree with a function in it's root (at least height 2).
82    /// </summary>
83    /// <param name="maxTreeSize">Maximal size of the tree (number of nodes).</param>
84    /// <param name="maxTreeHeight">Maximal height of the tree.</param>
85    /// <returns></returns>
86    public IFunctionTree CreateUnbalancedRandomTree(int maxTreeSize, int maxTreeHeight) {
87      IFunction rootFunction = GetRandomRoot(maxTreeSize, maxTreeHeight);
88      IFunctionTree tree = MakeUnbalancedTree(rootFunction, maxTreeHeight - 1);
89      return tree;
90    }
91
92    public IFunctionTree PTC2(int size, int maxDepth) {
93      return PTC2(GetRandomRoot(size, maxDepth), size, maxDepth);
94    }
95
96    private IFunctionTree PTC2(IFunction rootFunction, int size, int maxDepth) {
97      IFunctionTree root = rootFunction.GetTreeNode();
98      if (size <= 1 || maxDepth <= 1) return root;
99      List<object[]> list = new List<object[]>();
100      int currentSize = 1;
101      int totalListMinSize = 0;
102      int minArity = root.Function.MinSubTrees;
103      int maxArity = root.Function.MaxSubTrees;
104      if (maxArity >= size) {
105        maxArity = size;
106      }
107      int actualArity = random.Next(minArity, maxArity + 1);
108      totalListMinSize += root.Function.MinTreeSize - 1;
109      for (int i = 0; i < actualArity; i++) {
110        // insert a dummy sub-tree and add the pending extension to the list
111        root.AddSubTree(null);
112        list.Add(new object[] { root, i, 2 });
113      }
114      if (IsRecursiveExpansionPossible(root.Function)) {
115        while (list.Count > 0 && totalListMinSize + currentSize < size) {
116          int randomIndex = random.Next(list.Count);
117          object[] nextExtension = list[randomIndex];
118          list.RemoveAt(randomIndex);
119          IFunctionTree parent = (IFunctionTree)nextExtension[0];
120          int a = (int)nextExtension[1];
121          int d = (int)nextExtension[2];
122          if (d == maxDepth) {
123            parent.RemoveSubTree(a);
124            IFunctionTree branch = CreateRandomTree(GetAllowedSubFunctions(parent.Function, a), 1, 1);
125            parent.InsertSubTree(a, branch); // insert a smallest possible tree
126            currentSize += branch.GetSize();
127            totalListMinSize -= branch.GetSize();
128          } else {
129            IFunction selectedFunction = TreeGardener.RandomSelect(random, GetAllowedSubFunctions(parent.Function, a).Where(
130              f => IsRecursiveExpansionPossible(f) && f.MinTreeHeight + (d - 1) <= maxDepth).ToArray());
131            IFunctionTree newTree = selectedFunction.GetTreeNode();
132            parent.RemoveSubTree(a);
133            parent.InsertSubTree(a, newTree);
134            currentSize++;
135            totalListMinSize--;
136
137            minArity = selectedFunction.MinSubTrees;
138            maxArity = selectedFunction.MaxSubTrees;
139            if (maxArity >= size) {
140              maxArity = size;
141            }
142            actualArity = random.Next(minArity, maxArity + 1);
143            for (int i = 0; i < actualArity; i++) {
144              // insert a dummy sub-tree and add the pending extension to the list
145              newTree.AddSubTree(null);
146              list.Add(new object[] { newTree, i, d + 1 });
147            }
148            totalListMinSize += newTree.Function.MinTreeSize - 1;
149          }
150        }
151      }
152      while (list.Count > 0) {
153        int randomIndex = random.Next(list.Count);
154        object[] nextExtension = list[randomIndex];
155        list.RemoveAt(randomIndex);
156        IFunctionTree parent = (IFunctionTree)nextExtension[0];
157        int a = (int)nextExtension[1];
158        int d = (int)nextExtension[2];
159        parent.RemoveSubTree(a);
160        parent.InsertSubTree(a,
161          CreateRandomTree(GetAllowedSubFunctions(parent.Function, a), 1, 1)); // append a tree with minimal possible height
162      }
163      return root;
164    }
165
166    private bool IsRecursiveExpansionPossible(IFunction parent) {
167      return FindCycle(parent, new Stack<IFunction>());
168    }
169
170    private Dictionary<IFunction, bool> inCycle = new Dictionary<IFunction, bool>();
171    private bool FindCycle(IFunction parent, Stack<IFunction> parentChain) {
172      if (inCycle.ContainsKey(parent)) {
173        return inCycle[parent];
174      } else if (IsTerminal(parent)) {
175        inCycle[parent] = false;
176        return false;
177      } else if (parentChain.Contains(parent)) {
178        inCycle[parent] = true;
179        return true;
180      } else {
181        parentChain.Push(parent);
182        bool result = false;
183        // all slot indexes
184        for (int i = 0; i < parent.MaxSubTrees; i++) {
185          foreach (IFunction subFunction in GetAllowedSubFunctions(parent, i)) {
186            result |= FindCycle(subFunction, parentChain);
187          }
188        }
189
190        parentChain.Pop();
191        inCycle[parent] = result;
192        return result;
193      }
194    }
195
196    /// <summary>
197    /// selects a random function from allowedFunctions and creates a random (unbalanced) tree with maximal size and height.
198    /// </summary>
199    /// <param name="allowedFunctions">Set of allowed functions.</param>
200    /// <param name="maxTreeSize">Maximal size of the tree (number of nodes).</param>
201    /// <param name="maxTreeHeight">Maximal height of the tree.</param>
202    /// <returns>New random unbalanced tree</returns>
203    public IFunctionTree CreateRandomTree(ICollection<IFunction> allowedFunctions, int maxTreeSize, int maxTreeHeight) {
204      // get the minimal needed height based on allowed functions and extend the max-height if necessary
205      int minTreeHeight = allowedFunctions.Select(f => f.MinTreeHeight).Min();
206      if (minTreeHeight > maxTreeHeight)
207        maxTreeHeight = minTreeHeight;
208      // get the minimal needed size based on allowed functions and extend the max-size if necessary
209      int minTreeSize = allowedFunctions.Select(f => f.MinTreeSize).Min();
210      if (minTreeSize > maxTreeSize)
211        maxTreeSize = minTreeSize;
212
213      // select a random value for the size and height
214      int treeHeight = random.Next(minTreeHeight, maxTreeHeight + 1);
215      int treeSize = random.Next(minTreeSize, maxTreeSize + 1);
216
217      // filter the set of allowed functions and select only from those that fit into the given maximal size and height limits
218      IFunction[] possibleFunctions = allowedFunctions.Where(f => f.MinTreeHeight <= treeHeight &&
219        f.MinTreeSize <= treeSize).ToArray();
220      IFunction selectedFunction = RandomSelect(possibleFunctions);
221
222      // build the tree
223      IFunctionTree root;
224      root = PTC2(selectedFunction, maxTreeSize, maxTreeHeight);
225      return root;
226    }
227    #endregion
228
229    #region tree information gathering
230    public IFunctionTree GetRandomParentNode(IFunctionTree tree) {
231      List<IFunctionTree> parentNodes = new List<IFunctionTree>();
232
233      // add null for the parent of the root node
234      parentNodes.Add(null);
235
236      TreeForEach(tree, delegate(IFunctionTree possibleParentNode) {
237        if (possibleParentNode.SubTrees.Count > 0) {
238          parentNodes.Add(possibleParentNode);
239        }
240      });
241
242      return parentNodes[random.Next(parentNodes.Count)];
243    }
244
245    public static ICollection<IFunctionTree> GetAllSubTrees(IFunctionTree root) {
246      List<IFunctionTree> allTrees = new List<IFunctionTree>();
247      TreeForEach(root, t => { allTrees.Add(t); });
248      return allTrees;
249    }
250
251    /// <summary>
252    /// returns the height level of branch in the tree
253    /// if the branch == tree => 1
254    /// if branch is in the sub-trees of tree => 2
255    /// ...
256    /// if branch is not found => -1
257    /// </summary>
258    /// <param name="tree">root of the function tree to process</param>
259    /// <param name="branch">branch that is searched in the tree</param>
260    /// <returns></returns>
261    public int GetBranchLevel(IFunctionTree tree, IFunctionTree branch) {
262      return GetBranchLevelHelper(tree, branch, 1);
263    }
264
265    // 'tail-recursive' helper
266    private int GetBranchLevelHelper(IFunctionTree tree, IFunctionTree branch, int level) {
267      if (branch == tree) return level;
268
269      foreach (IFunctionTree subTree in tree.SubTrees) {
270        int result = GetBranchLevelHelper(subTree, branch, level + 1);
271        if (result != -1) return result;
272      }
273
274      return -1;
275    }
276
277    public bool IsValidTree(IFunctionTree tree) {
278      for (int i = 0; i < tree.SubTrees.Count; i++) {
279        if (!tree.Function.GetAllowedSubFunctions(i).Contains(tree.SubTrees[i].Function)) return false;
280      }
281
282      if (tree.SubTrees.Count < tree.Function.MinSubTrees || tree.SubTrees.Count > tree.Function.MaxSubTrees)
283        return false;
284      foreach (IFunctionTree subTree in tree.SubTrees) {
285        if (!IsValidTree(subTree)) return false;
286      }
287      return true;
288    }
289
290    // returns a random branch from the specified level in the tree
291    public IFunctionTree GetRandomBranch(IFunctionTree tree, int level) {
292      if (level == 0) return tree;
293      List<IFunctionTree> branches = new List<IFunctionTree>();
294      GetBranchesAtLevel(tree, level, branches);
295      return branches[random.Next(branches.Count)];
296    }
297    #endregion
298
299    #region function information (arity, allowed childs and parents)
300    internal ICollection<IFunction> GetPossibleParents(List<IFunction> list) {
301      List<IFunction> result = new List<IFunction>();
302      foreach (IFunction f in functions) {
303        if (IsPossibleParent(f, list)) {
304          result.Add(f);
305        }
306      }
307      return result;
308    }
309
310    private bool IsPossibleParent(IFunction f, List<IFunction> children) {
311      int minArity = f.MinSubTrees;
312      int maxArity = f.MaxSubTrees;
313      // note: we can't assume that the operators in the children list have different types!
314
315      // when the maxArity of this function is smaller than the list of operators that
316      // should be included as sub-operators then it can't be a parent
317      if (maxArity < children.Count()) {
318        return false;
319      }
320      int nSlots = Math.Max(minArity, children.Count);
321
322      List<HashSet<IFunction>> slotSets = new List<HashSet<IFunction>>();
323
324      // we iterate through all slots for sub-trees and calculate the set of
325      // allowed functions for this slot.
326      // we only count those slots that can hold at least one of the children that we should combine
327      for (int slot = 0; slot < nSlots; slot++) {
328        HashSet<IFunction> functionSet = new HashSet<IFunction>(f.GetAllowedSubFunctions(slot));
329        if (functionSet.Count() > 0) {
330          slotSets.Add(functionSet);
331        }
332      }
333
334      // ok at the end of this operation we know how many slots of the parent can actually
335      // hold one of our children.
336      // if the number of slots is smaller than the number of children we can be sure that
337      // we can never combine all children as sub-trees of the function and thus the function
338      // can't be a parent.
339      if (slotSets.Count() < children.Count()) {
340        return false;
341      }
342
343      // finally we sort the sets by size and beginning from the first set select one
344      // function for the slot and thus remove it as possible sub-tree from the remaining sets.
345      // when we can successfully assign all available children to a slot the function is a valid parent
346      // when only a subset of all children can be assigned to slots the function is no valid parent
347      slotSets.Sort((p, q) => p.Count() - q.Count());
348
349      int assignments = 0;
350      for (int i = 0; i < slotSets.Count() - 1; i++) {
351        if (slotSets[i].Count > 0) {
352          IFunction selected = slotSets[i].ElementAt(0);
353          assignments++;
354          for (int j = i + 1; j < slotSets.Count(); j++) {
355            slotSets[j].Remove(selected);
356          }
357        }
358      }
359
360      // sanity check
361      if (assignments > children.Count) throw new InvalidProgramException();
362      return assignments == children.Count - 1;
363    }
364    public IList<IFunction> GetAllowedParents(IFunction child, int childIndex) {
365      List<IFunction> parents = new List<IFunction>();
366      foreach (IFunction function in functions) {
367        ICollection<IFunction> allowedSubFunctions = GetAllowedSubFunctions(function, childIndex);
368        if (allowedSubFunctions.Contains(child)) {
369          parents.Add(function);
370        }
371      }
372      return parents;
373    }
374    public static bool IsTerminal(IFunction f) {
375      return f.MinSubTrees == 0 && f.MaxSubTrees == 0;
376    }
377    public ICollection<IFunction> GetAllowedSubFunctions(IFunction f, int index) {
378      if (f == null) {
379        return allFunctions;
380      } else {
381        return f.GetAllowedSubFunctions(index);
382      }
383    }
384    #endregion
385
386    #region private utility methods
387    public IFunction GetRandomRoot(int maxTreeSize, int maxTreeHeight) {
388      if (maxTreeHeight == 1 || maxTreeSize == 1) {
389        IFunction selectedTerminal = RandomSelect(terminals);
390        return selectedTerminal;
391      } else {
392        int minExpandableTreeSize = (from f in functions
393                                     where IsRecursiveExpansionPossible(f)
394                                     select f.MinTreeSize).Min();
395        int minExpandableTreeHeight = (from f in functions
396                                       where IsRecursiveExpansionPossible(f)
397                                       select f.MinTreeHeight).Min();
398        IFunction[] possibleFunctions;
399        if (maxTreeSize < minExpandableTreeSize || maxTreeHeight < minExpandableTreeHeight) {
400          possibleFunctions = functions.Where(f => f.MinTreeHeight <= maxTreeHeight &&
401            f.MinTreeSize <= maxTreeSize).ToArray();
402        } else {
403          possibleFunctions = functions.Where(f => f.MinTreeHeight <= maxTreeHeight &&
404            f.MinTreeSize <= maxTreeSize && IsRecursiveExpansionPossible(f)).ToArray();
405        }
406        return RandomSelect(possibleFunctions);
407      }
408    }
409
410
411    private IFunctionTree MakeUnbalancedTree(IFunction parent, int maxTreeHeight) {
412      if (maxTreeHeight == 0) return parent.GetTreeNode();
413      int minArity = parent.MinSubTrees;
414      int maxArity = parent.MaxSubTrees;
415      int actualArity = random.Next(minArity, maxArity + 1);
416      if (actualArity > 0) {
417        IFunctionTree parentTree = parent.GetTreeNode();
418        for (int i = 0; i < actualArity; i++) {
419          IFunction[] possibleFunctions = GetAllowedSubFunctions(parent, i).Where(f => f.MinTreeHeight <= maxTreeHeight).ToArray();
420          IFunction selectedFunction = RandomSelect(possibleFunctions);
421          IFunctionTree newSubTree = MakeUnbalancedTree(selectedFunction, maxTreeHeight - 1);
422          parentTree.InsertSubTree(i, newSubTree);
423        }
424        return parentTree;
425      }
426      return parent.GetTreeNode();
427    }
428
429
430    // NOTE: this method doesn't build fully balanced trees because we have constraints on the
431    // types of possible sub-functions which can indirectly impose a limit for the depth of a given sub-tree
432    private IFunctionTree MakeBalancedTree(IFunction parent, int maxTreeHeight) {
433      if (maxTreeHeight == 0) return parent.GetTreeNode();
434      int minArity = parent.MinSubTrees;
435      int maxArity = parent.MaxSubTrees;
436      int actualArity = random.Next(minArity, maxArity + 1);
437      if (actualArity > 0) {
438        IFunctionTree parentTree = parent.GetTreeNode();
439        for (int i = 0; i < actualArity; i++) {
440          // first try to find a function that fits into the maxHeight limit
441          IFunction[] possibleFunctions = GetAllowedSubFunctions(parent, i).Where(f => f.MinTreeHeight <= maxTreeHeight &&
442            !IsTerminal(f)).ToArray();
443          // no possible function found => extend function set to terminals
444          if (possibleFunctions.Length == 0) {
445            possibleFunctions = GetAllowedSubFunctions(parent, i).Where(f => IsTerminal(f)).ToArray();
446            IFunction selectedTerminal = RandomSelect(possibleFunctions);
447            IFunctionTree newTree = selectedTerminal.GetTreeNode();
448            parentTree.InsertSubTree(i, newTree);
449          } else {
450            IFunction selectedFunction = RandomSelect(possibleFunctions);
451            IFunctionTree newTree = MakeBalancedTree(selectedFunction, maxTreeHeight - 1);
452            parentTree.InsertSubTree(i, newTree);
453          }
454        }
455        return parentTree;
456      }
457      return parent.GetTreeNode();
458    }
459
460    private static void TreeForEach(IFunctionTree tree, Action<IFunctionTree> action) {
461      action(tree);
462      foreach (IFunctionTree subTree in tree.SubTrees) {
463        TreeForEach(subTree, action);
464      }
465    }
466
467    private static void GetBranchesAtLevel(IFunctionTree tree, int level, List<IFunctionTree> result) {
468      if (level == 1) result.AddRange(tree.SubTrees);
469      foreach (IFunctionTree subTree in tree.SubTrees) {
470        if (subTree.GetHeight() >= level - 1)
471          GetBranchesAtLevel(subTree, level - 1, result);
472      }
473    }
474
475    private IFunction RandomSelect(IList<IFunction> functionSet) {
476      return RandomSelect(random, functionSet);
477    }
478
479    public static IFunction RandomSelect(IRandom random, IList<IFunction> functionSet) {
480      double[] accumulatedTickets = new double[functionSet.Count];
481      double ticketAccumulator = 0;
482      int i = 0;
483      // precalculate the slot-sizes
484      foreach (IFunction function in functionSet) {
485        ticketAccumulator += function.Tickets;
486        accumulatedTickets[i] = ticketAccumulator;
487        i++;
488      }
489      // throw ball
490      double r = random.NextDouble() * ticketAccumulator;
491      // find the slot that has been hit
492      for (i = 0; i < accumulatedTickets.Length; i++) {
493        if (r < accumulatedTickets[i]) return functionSet[i];
494      }
495      // sanity check
496      throw new InvalidProgramException(); // should never happen
497    }
498
499    #endregion
500
501  }
502}
Note: See TracBrowser for help on using the repository browser.