Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/13/10 20:44:31 (14 years ago)
Author:
gkronber
Message:

Fixed bugs related to dynamic symbol constraints with ADFs. #290 (Implement ADFs)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.3/DefaultSymbolicExpressionGrammar.cs

    r3294 r3338  
    3434  public class DefaultSymbolicExpressionGrammar : Item, ISymbolicExpressionGrammar {
    3535    [Storable]
    36     private int minFunctionDefinitions;
    37     [Storable]
    38     private int maxFunctionDefinitions;
    39     [Storable]
    40     private int minFunctionArguments;
    41     [Storable]
    42     private int maxFunctionArguments;
    43 
    44     [Storable]
    4536    private Dictionary<string, int> minSubTreeCount;
    4637    [Storable]
    4738    private Dictionary<string, int> maxSubTreeCount;
    4839    [Storable]
    49     private Dictionary<string, List<HashSet<string>>> allowedFunctions;
     40    private Dictionary<string, List<HashSet<string>>> allowedChildSymbols;
    5041    [Storable]
    5142    private HashSet<Symbol> allSymbols;
    5243
    53     public DefaultSymbolicExpressionGrammar(int minFunctionDefinitions, int maxFunctionDefinitions, int minFunctionArguments, int maxFunctionArguments)
     44    public DefaultSymbolicExpressionGrammar()
    5445      : base() {
    55       this.minFunctionDefinitions = minFunctionDefinitions;
    56       this.maxFunctionDefinitions = maxFunctionDefinitions;
    57       this.minFunctionArguments = minFunctionArguments;
    58       this.maxFunctionArguments = maxFunctionArguments;
     46      Reset();
     47    }
     48
     49    private void Initialize() {
     50      startSymbol = new StartSymbol();
     51      AddSymbol(startSymbol);
     52      SetMinSubtreeCount(startSymbol, 1);
     53      SetMaxSubtreeCount(startSymbol, 1);
     54    }
     55
     56    #region ISymbolicExpressionGrammar Members
     57
     58    private Symbol startSymbol;
     59    public Symbol StartSymbol {
     60      get { return startSymbol; }
     61      set { startSymbol = value; }
     62    }
     63
     64    protected void Reset() {
    5965      minSubTreeCount = new Dictionary<string, int>();
    6066      maxSubTreeCount = new Dictionary<string, int>();
    61       allowedFunctions = new Dictionary<string, List<HashSet<string>>>();
     67      allowedChildSymbols = new Dictionary<string, List<HashSet<string>>>();
    6268      allSymbols = new HashSet<Symbol>();
    63       cachedMinExpressionLength = new Dictionary<Symbol, int>();
    64       cachedMaxExpressionLength = new Dictionary<Symbol, int>();
    65       cachedMinExpressionDepth = new Dictionary<Symbol, int>();
     69      cachedMinExpressionLength = new Dictionary<string, int>();
     70      cachedMaxExpressionLength = new Dictionary<string, int>();
     71      cachedMinExpressionDepth = new Dictionary<string, int>();
    6672      Initialize();
    6773    }
    6874
    69     private void Initialize() {
    70       programRootSymbol = new ProgramRootSymbol();
    71       var defunSymbol = new Defun();
    72       startSymbol = new StartSymbol();
    73       var invokeFunctionSymbol = new InvokeFunction();
    74 
    75       SetMinSubTreeCount(programRootSymbol, minFunctionDefinitions + 1);
    76       SetMaxSubTreeCount(programRootSymbol, maxFunctionDefinitions + 1);
    77       SetMinSubTreeCount(startSymbol, 1);
    78       SetMaxSubTreeCount(startSymbol, 1);
    79       SetMinSubTreeCount(defunSymbol, 1);
    80       SetMaxSubTreeCount(defunSymbol, 1);
    81       SetMinSubTreeCount(invokeFunctionSymbol, minFunctionArguments);
    82       SetMaxSubTreeCount(invokeFunctionSymbol, maxFunctionArguments);
    83       AddAllowedSymbols(programRootSymbol, 0, startSymbol);
    84       for (int argumentIndex = 1; argumentIndex < maxFunctionDefinitions + 1; argumentIndex++) {
    85         AddAllowedSymbols(programRootSymbol, argumentIndex, defunSymbol);
    86       }
    87     }
    88 
    89     public void AddAllowedSymbols(Symbol parent, int argumentIndex, Symbol allowedChild) {
    90       allSymbols.Add(parent); allSymbols.Add(allowedChild);
    91       if (!allowedFunctions.ContainsKey(parent.Name)) {
    92         allowedFunctions[parent.Name] = new List<HashSet<string>>();
    93       }
    94       while (allowedFunctions[parent.Name].Count <= argumentIndex)
    95         allowedFunctions[parent.Name].Add(new HashSet<string>());
    96       allowedFunctions[parent.Name][argumentIndex].Add(allowedChild.Name);
    97       ClearCaches();
    98     }
    99 
    100     public void SetMaxSubTreeCount(Symbol parent, int nSubTrees) {
    101       maxSubTreeCount[parent.Name] = nSubTrees;
    102       ClearCaches();
    103     }
    104 
    105     public void SetMinSubTreeCount(Symbol parent, int nSubTrees) {
    106       minSubTreeCount[parent.Name] = nSubTrees;
    107       ClearCaches();
    108     }
     75    public void AddSymbol(Symbol symbol) {
     76      if (allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Symbol " + symbol + " is already defined.");
     77      allSymbols.Add(symbol);
     78      allowedChildSymbols[symbol.Name] = new List<HashSet<string>>();
     79      ClearCaches();
     80    }
     81
     82    public void RemoveSymbol(Symbol symbol) {
     83      allSymbols.RemoveWhere(s => s.Name == symbol.Name);
     84      minSubTreeCount.Remove(symbol.Name);
     85      maxSubTreeCount.Remove(symbol.Name);
     86      allowedChildSymbols.Remove(symbol.Name);
     87      ClearCaches();
     88    }
     89
     90    public IEnumerable<Symbol> Symbols {
     91      get { return allSymbols.AsEnumerable(); }
     92    }
     93
     94    public void SetAllowedChild(Symbol parent, Symbol child, int argumentIndex) {
     95      if (!allSymbols.Any(s => s.Name == parent.Name)) throw new ArgumentException("Unknown symbol: " + parent, "parent");
     96      if (!allSymbols.Any(s => s.Name == child.Name)) throw new ArgumentException("Unknown symbol: " + child, "child");
     97      if (argumentIndex >= GetMaxSubtreeCount(parent)) throw new ArgumentException("Symbol " + parent + " can have only " + GetMaxSubtreeCount(parent) + " subtrees.");
     98      allowedChildSymbols[parent.Name][argumentIndex].Add(child.Name);
     99      ClearCaches();
     100    }
     101
     102    public bool IsAllowedChild(Symbol parent, Symbol child, int argumentIndex) {
     103      if (!allSymbols.Any(s => s.Name == parent.Name)) throw new ArgumentException("Unknown symbol: " + parent, "parent");
     104      if (!allSymbols.Any(s => s.Name == child.Name)) throw new ArgumentException("Unknown symbol: " + child, "child");
     105      if (argumentIndex >= GetMaxSubtreeCount(parent)) throw new ArgumentException("Symbol " + parent + " can have only " + GetMaxSubtreeCount(parent) + " subtrees.");
     106      if (allowedChildSymbols.ContainsKey(parent.Name)) return allowedChildSymbols[parent.Name][argumentIndex].Contains(child.Name);
     107      return false;
     108    }
     109
     110    private Dictionary<string, int> cachedMinExpressionLength;
     111    public int GetMinExpressionLength(Symbol symbol) {
     112      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     113      if (!cachedMinExpressionLength.ContainsKey(symbol.Name)) {
     114        cachedMinExpressionLength[symbol.Name] = int.MaxValue; // prevent infinite recursion
     115        long sumOfMinExpressionLengths = 1 + (from argIndex in Enumerable.Range(0, GetMinSubtreeCount(symbol))
     116                                              let minForSlot = (long)(from s in allSymbols
     117                                                                      where IsAllowedChild(symbol, s, argIndex)
     118                                                                      select GetMinExpressionLength(s)).DefaultIfEmpty(0).Min()
     119                                              select minForSlot).DefaultIfEmpty(0).Sum();
     120
     121        cachedMinExpressionLength[symbol.Name] = (int)Math.Min(sumOfMinExpressionLengths, int.MaxValue);
     122      }
     123      return cachedMinExpressionLength[symbol.Name];
     124    }
     125
     126    private Dictionary<string, int> cachedMaxExpressionLength;
     127    public int GetMaxExpressionLength(Symbol symbol) {
     128      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     129      if (!cachedMaxExpressionLength.ContainsKey(symbol.Name)) {
     130        cachedMaxExpressionLength[symbol.Name] = int.MaxValue; // prevent infinite recursion
     131        long sumOfMaxTrees = 1 + (from argIndex in Enumerable.Range(0, GetMaxSubtreeCount(symbol))
     132                                  let maxForSlot = (long)(from s in allSymbols
     133                                                          where IsAllowedChild(symbol, s, argIndex)
     134                                                          select GetMaxExpressionLength(s)).DefaultIfEmpty(0).Max()
     135                                  select maxForSlot).DefaultIfEmpty(0).Sum();
     136        long limit = int.MaxValue;
     137        cachedMaxExpressionLength[symbol.Name] = (int)Math.Min(sumOfMaxTrees, limit);
     138      }
     139      return cachedMaxExpressionLength[symbol.Name];
     140    }
     141
     142    private Dictionary<string, int> cachedMinExpressionDepth;
     143    public int GetMinExpressionDepth(Symbol symbol) {
     144      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     145      if (!cachedMinExpressionDepth.ContainsKey(symbol.Name)) {
     146        cachedMinExpressionDepth[symbol.Name] = int.MaxValue; // prevent infinite recursion
     147        cachedMinExpressionDepth[symbol.Name] = 1 + (from argIndex in Enumerable.Range(0, GetMinSubtreeCount(symbol))
     148                                                     let minForSlot = (from s in allSymbols
     149                                                                       where IsAllowedChild(symbol, s, argIndex)
     150                                                                       select GetMinExpressionDepth(s)).DefaultIfEmpty(0).Min()
     151                                                     select minForSlot).DefaultIfEmpty(0).Max();
     152      }
     153      return cachedMinExpressionDepth[symbol.Name];
     154    }
     155
     156    public void SetMaxSubtreeCount(Symbol symbol, int nSubTrees) {
     157      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     158      maxSubTreeCount[symbol.Name] = nSubTrees;
     159      while (allowedChildSymbols[symbol.Name].Count <= nSubTrees)
     160        allowedChildSymbols[symbol.Name].Add(new HashSet<string>());
     161      while (allowedChildSymbols[symbol.Name].Count > nSubTrees) {
     162        allowedChildSymbols[symbol.Name].RemoveAt(allowedChildSymbols[symbol.Name].Count - 1);
     163      }
     164      ClearCaches();
     165    }
     166
     167    public void SetMinSubtreeCount(Symbol symbol, int nSubTrees) {
     168      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     169      minSubTreeCount[symbol.Name] = nSubTrees;
     170      ClearCaches();
     171    }
     172
     173    public int GetMinSubtreeCount(Symbol symbol) {
     174      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     175      return minSubTreeCount[symbol.Name];
     176    }
     177
     178    public int GetMaxSubtreeCount(Symbol symbol) {
     179      if (!allSymbols.Any(s => s.Name == symbol.Name)) throw new ArgumentException("Unknown symbol: " + symbol);
     180      return maxSubTreeCount[symbol.Name];
     181    }
     182
     183    #endregion
    109184
    110185    private void ClearCaches() {
     
    114189    }
    115190
    116     private void symbol_ToStringChanged(object sender, EventArgs e) {
    117       OnToStringChanged();
    118     }
    119 
    120     #region ISymbolicExpressionGrammar Members
    121 
    122     private Symbol programRootSymbol;
    123     public Symbol ProgramRootSymbol {
    124       get { return programRootSymbol; }
    125     }
    126 
    127     private Symbol startSymbol;
    128     public Symbol StartSymbol {
    129       get { return startSymbol; }
    130     }
    131 
    132     public IEnumerable<Symbol> GetAllowedSymbols(Symbol parent, int argumentIndex) {
    133       return from name in allowedFunctions[parent.Name][argumentIndex]
    134              from sym in allSymbols
    135              where name == sym.Name
    136              select sym;
    137     }
    138 
    139 
    140     private Dictionary<Symbol, int> cachedMinExpressionLength;
    141     public int GetMinExpressionLength(Symbol start) {
    142       if (!cachedMinExpressionLength.ContainsKey(start)) {
    143         cachedMinExpressionLength[start] = int.MaxValue; // prevent infinite recursion
    144         cachedMinExpressionLength[start] = 1 + (from argIndex in Enumerable.Range(0, GetMinSubTreeCount(start))
    145                                                 let minForSlot = (from symbol in GetAllowedSymbols(start, argIndex)
    146                                                                   select GetMinExpressionLength(symbol)).DefaultIfEmpty(0).Min()
    147                                                 select minForSlot).DefaultIfEmpty(0).Sum();
    148       }
    149       return cachedMinExpressionLength[start];
    150     }
    151 
    152     private Dictionary<Symbol, int> cachedMaxExpressionLength;
    153     public int GetMaxExpressionLength(Symbol start) {
    154       if (!cachedMaxExpressionLength.ContainsKey(start)) {
    155         cachedMaxExpressionLength[start] = int.MaxValue; // prevent infinite recursion
    156         long sumOfMaxTrees = 1 + (from argIndex in Enumerable.Range(0, GetMaxSubTreeCount(start))
    157                                   let maxForSlot = (long)(from symbol in GetAllowedSymbols(start, argIndex)
    158                                                           select GetMaxExpressionLength(symbol)).DefaultIfEmpty(0).Max()
    159                                   select maxForSlot).DefaultIfEmpty(0).Sum();
    160         long limit = int.MaxValue;
    161         cachedMaxExpressionLength[start] = (int)Math.Min(sumOfMaxTrees, limit);
    162       }
    163       return cachedMaxExpressionLength[start];
    164     }
    165 
    166     private Dictionary<Symbol, int> cachedMinExpressionDepth;
    167     public int GetMinExpressionDepth(Symbol start) {
    168       if (!cachedMinExpressionDepth.ContainsKey(start)) {
    169         cachedMinExpressionDepth[start] = int.MaxValue; // prevent infinite recursion
    170         cachedMinExpressionDepth[start] = 1 + (from argIndex in Enumerable.Range(0, GetMinSubTreeCount(start))
    171                                                let minForSlot = (from symbol in GetAllowedSymbols(start, argIndex)
    172                                                                  select GetMinExpressionDepth(symbol)).DefaultIfEmpty(0).Min()
    173                                                select minForSlot).DefaultIfEmpty(0).Max();
    174       }
    175       return cachedMinExpressionDepth[start];
    176     }
    177 
    178     public int GetMinSubTreeCount(Symbol start) {
    179       return minSubTreeCount[start.Name];
    180     }
    181 
    182     public int GetMaxSubTreeCount(Symbol start) {
    183       return maxSubTreeCount[start.Name];
    184     }
    185 
    186     public bool IsValidExpression(SymbolicExpressionTree expression) {
    187       if (expression.Root.Symbol != ProgramRootSymbol) return false;
    188       // check dynamic symbols
    189       foreach (var branch in expression.Root.SubTrees) {
    190         foreach (var dynamicNode in branch.DynamicSymbols) {
    191           if (!dynamicNode.StartsWith("ARG")) {
    192             if (FindDefinitionOfDynamicFunction(expression.Root, dynamicNode) == null) return false;
    193           }
    194         }
    195       }
    196       return IsValidExpression(expression.Root);
    197     }
    198 
    199     #endregion
    200     private bool IsValidExpression(SymbolicExpressionTreeNode root) {
    201       if (root.SubTrees.Count < GetMinSubTreeCount(root.Symbol)) return false;
    202       if (root.SubTrees.Count > GetMaxSubTreeCount(root.Symbol)) return false;
    203       if (root.Symbol is Defun || root.Symbol is StartSymbol) {
    204         // check references to dynamic symbols
    205         if (!CheckDynamicSymbolsInBranch(root, root.SubTrees[0])) return false;
    206       }
    207       for (int i = 0; i < root.SubTrees.Count; i++) {
    208         if (!GetAllowedSymbols(root.Symbol, i).Contains(root.SubTrees[i].Symbol)) return false;
    209         if (!IsValidExpression(root.SubTrees[i])) return false;
    210       }
    211       return true;
    212     }
    213 
    214     private SymbolicExpressionTreeNode FindDefinitionOfDynamicFunction(SymbolicExpressionTreeNode root, string dynamicNode) {
    215       return (from node in root.SubTrees.OfType<DefunTreeNode>()
    216               where node.Name == dynamicNode
    217               select node).FirstOrDefault();
    218     }
    219 
    220     private bool CheckDynamicSymbolsInBranch(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node) {
    221       var argNode = node as ArgumentTreeNode;
    222       var invokeNode = node as InvokeFunctionTreeNode;
    223       if (argNode != null) {
    224         if (!root.DynamicSymbols.Contains("ARG" + argNode.ArgumentIndex)) return false;
    225       } else if (invokeNode != null) {
    226         if (!root.DynamicSymbols.Contains(invokeNode.InvokedFunctionName)) return false;
    227         if (root.GetDynamicSymbolArgumentCount(invokeNode.InvokedFunctionName) != invokeNode.SubTrees.Count()) return false;
    228       }
    229       foreach (var subtree in node.SubTrees) {
    230         if (!CheckDynamicSymbolsInBranch(root, subtree)) return false;
    231       }
    232       return true;
    233     }
    234 
     191    //private void symbol_ToStringChanged(object sender, EventArgs e) {
     192    //  OnToStringChanged();
     193    //}
     194
     195    //private bool IsValidExpression(SymbolicExpressionTreeNode root) {
     196    //  if (root.SubTrees.Count < root.GetMinSubtreeCount()) return false;
     197    //  else if (root.SubTrees.Count > root.GetMaxSubtreeCount()) return false;
     198    //  else for (int i = 0; i < root.SubTrees.Count; i++) {
     199    //      if (!root.GetAllowedSymbols(i).Select(x => x.Name).Contains(root.SubTrees[i].Symbol.Name)) return false;
     200    //      if (!IsValidExpression(root.SubTrees[i])) return false;
     201    //    }
     202    //  return true;
     203    //}
     204
     205    public override IDeepCloneable Clone(Cloner cloner) {
     206      DefaultSymbolicExpressionGrammar clone = (DefaultSymbolicExpressionGrammar)base.Clone(cloner);
     207      clone.maxSubTreeCount = new Dictionary<string, int>(maxSubTreeCount);
     208      clone.minSubTreeCount = new Dictionary<string, int>(minSubTreeCount);
     209      clone.startSymbol = startSymbol;
     210      clone.allowedChildSymbols = new Dictionary<string, List<HashSet<string>>>(allowedChildSymbols);
     211      clone.allSymbols = new HashSet<Symbol>(allSymbols);
     212      return clone;
     213    }
    235214  }
    236215}
Note: See TracChangeset for help on using the changeset viewer.