Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.4/Grammars/SymbolicExpressionGrammarBase.cs @ 16671

Last change on this file since 16671 was 16565, checked in by gkronber, 6 years ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 20.3 KB
RevLine 
[5686]1#region License Information
2/* HeuristicLab
[16565]3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5686]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
[16565]27using HEAL.Attic;
[5686]28
29namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
30  /// <summary>
31  /// The default symbolic expression grammar stores symbols and syntactic constraints for symbols.
32  /// Symbols are treated as equivalent if they have the same name.
33  /// Syntactic constraints limit the number of allowed sub trees for a node with a symbol and which symbols are allowed
34  /// in the sub-trees of a symbol (can be specified for each sub-tree index separately).
35  /// </summary>
[16565]36  [StorableType("E76C087C-4E10-488A-86D0-295A4265DA53")]
[5686]37  public abstract class SymbolicExpressionGrammarBase : NamedItem, ISymbolicExpressionGrammarBase {
[6803]38
[5686]39    #region properties for separation between implementation and persistence
/// <summary>
/// Persistence view of the symbol table. Reading yields an array snapshot of all registered
/// symbols; writing registers every supplied symbol under its name.
/// </summary>
[Storable(Name = "Symbols")]
private IEnumerable<ISymbol> StorableSymbols {
  get {
    ISymbol[] snapshot = symbols.Values.ToArray();
    return snapshot;
  }
  set {
    foreach (ISymbol restored in value) {
      symbols.Add(restored.Name, restored);
    }
  }
}
[5686]45
/// <summary>
/// Persistence view of the subtree-count table. Reading resolves each stored symbol name to its
/// symbol via <see cref="GetSymbol"/>; writing stores each pair keyed by the symbol's name.
/// </summary>
[Storable(Name = "SymbolSubtreeCount")]
private IEnumerable<KeyValuePair<ISymbol, Tuple<int, int>>> StorableSymbolSubtreeCount {
  get {
    var pairs = new List<KeyValuePair<ISymbol, Tuple<int, int>>>();
    foreach (var entry in symbolSubtreeCount) {
      pairs.Add(new KeyValuePair<ISymbol, Tuple<int, int>>(GetSymbol(entry.Key), entry.Value));
    }
    return pairs.ToArray();
  }
  set {
    foreach (var restored in value) {
      symbolSubtreeCount.Add(restored.Key.Name, restored.Value);
    }
  }
}
[5686]51
/// <summary>
/// Persistence view of the index-independent allowed-children table. Reading maps parent and
/// child names back to symbols; writing stores the child names under the parent's name.
/// </summary>
[Storable(Name = "AllowedChildSymbols")]
private IEnumerable<KeyValuePair<ISymbol, IEnumerable<ISymbol>>> StorableAllowedChildSymbols {
  get {
    var pairs = new List<KeyValuePair<ISymbol, IEnumerable<ISymbol>>>();
    foreach (var entry in allowedChildSymbols) {
      ISymbol parent = GetSymbol(entry.Key);
      IEnumerable<ISymbol> children = entry.Value.Select(name => GetSymbol(name)).ToArray();
      pairs.Add(new KeyValuePair<ISymbol, IEnumerable<ISymbol>>(parent, children));
    }
    return pairs.ToArray();
  }
  set {
    foreach (var restored in value) {
      List<string> childNames = restored.Value.Select(child => child.Name).ToList();
      allowedChildSymbols.Add(restored.Key.Name, childNames);
    }
  }
}
57
/// <summary>
/// Persistence view of the per-argument-index allowed-children table. Reading maps the stored
/// (parent name, index) keys and child names back to symbols; writing converts them to names again.
/// </summary>
[Storable(Name = "AllowedChildSymbolsPerIndex")]
private IEnumerable<KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>> StorableAllowedChildSymbolsPerIndex {
  get {
    var pairs = new List<KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>>();
    foreach (var entry in allowedChildSymbolsPerIndex) {
      Tuple<ISymbol, int> slot = Tuple.Create(GetSymbol(entry.Key.Item1), entry.Key.Item2);
      IEnumerable<ISymbol> children = entry.Value.Select(name => GetSymbol(name)).ToArray();
      pairs.Add(new KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>(slot, children));
    }
    return pairs.ToArray();
  }
  set {
    foreach (var restored in value) {
      Tuple<string, int> slot = Tuple.Create(restored.Key.Item1.Name, restored.Key.Item2);
      List<string> childNames = restored.Value.Select(child => child.Name).ToList();
      allowedChildSymbolsPerIndex.Add(slot, childNames);
    }
  }
}
[5686]66    #endregion
67
// While true, OnChanged() returns without raising the Changed event.
private bool suppressEvents;
// Symbol name -> symbol instance.
protected readonly Dictionary<string, ISymbol> symbols;
// Symbol name -> (minimum subtree count, maximum subtree count).
protected readonly Dictionary<string, Tuple<int, int>> symbolSubtreeCount;
// Parent symbol name -> names of the child symbols allowed at every argument index.
protected readonly Dictionary<string, List<string>> allowedChildSymbols;
// (Parent symbol name, argument index) -> names of the child symbols allowed at that index only.
protected readonly Dictionary<Tuple<string, int>, List<string>> allowedChildSymbolsPerIndex;
[5686]73
/// <summary>Always false: this property reports that the name cannot be changed.</summary>
public override bool CanChangeName {
  get {
    return false;
  }
}
/// <summary>Always false: this property reports that the description cannot be changed.</summary>
public override bool CanChangeDescription {
  get {
    return false;
  }
}
80
/// <summary>
/// Storable (deserialization) constructor. Creates empty tables, which the storable property
/// setters of this class fill by adding entries.
/// </summary>
[StorableConstructor]
protected SymbolicExpressionGrammarBase(StorableConstructorFlag _) : base(_) {
  suppressEvents = false;

  allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
  allowedChildSymbols = new Dictionary<string, List<string>>();
  symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
  symbols = new Dictionary<string, ISymbol>();
}
[6233]91
/// <summary>
/// Cloning constructor. Symbols are cloned through <paramref name="cloner"/>; the name-based
/// tables are copied, with a fresh list per entry so the clone never shares lists with the original.
/// </summary>
/// <param name="original">The grammar to copy from.</param>
/// <param name="cloner">The cloner used to clone the symbol instances.</param>
protected SymbolicExpressionGrammarBase(SymbolicExpressionGrammarBase original, Cloner cloner)
  : base(original, cloner) {
  suppressEvents = false;

  symbols = new Dictionary<string, ISymbol>();
  foreach (var entry in original.symbols) {
    symbols.Add(entry.Key, cloner.Clone(entry.Value));
  }

  symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>(original.symbolSubtreeCount);

  allowedChildSymbols = original.allowedChildSymbols
    .ToDictionary(entry => entry.Key, entry => new List<string>(entry.Value));

  allowedChildSymbolsPerIndex = original.allowedChildSymbolsPerIndex
    .ToDictionary(entry => entry.Key, entry => new List<string>(entry.Value));
}
108
/// <summary>
/// Creates an empty grammar with the given name and description.
/// </summary>
/// <param name="name">Passed to the base constructor.</param>
/// <param name="description">Passed to the base constructor.</param>
protected SymbolicExpressionGrammarBase(string name, string description)
  : base(name, description) {
  suppressEvents = false;

  allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
  allowedChildSymbols = new Dictionary<string, List<string>>();
  symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
  symbols = new Dictionary<string, ISymbol>();
}
118
119    #region protected grammar manipulation methods
/// <summary>
/// Registers <paramref name="symbol"/> and every symbol returned by its <c>Flatten()</c> call.
/// Each registered symbol gets the subtree-count range
/// (MinimumArity, Min(MinimumArity + 1, MaximumArity)).
/// </summary>
/// <param name="symbol">The symbol to add.</param>
/// <exception cref="ArgumentException">
/// Thrown when the symbol, or any symbol it flattens to, is already defined.
/// The grammar is left unchanged in that case.
/// </exception>
public virtual void AddSymbol(ISymbol symbol) {
  if (ContainsSymbol(symbol)) throw new ArgumentException("Symbol " + symbol + " is already defined.");

  // Validate the complete flattened set before touching any table; otherwise a name clash on a
  // nested symbol would throw from Dictionary.Add after earlier symbols were already registered.
  var flattened = symbol.Flatten().ToList();
  var pendingNames = new HashSet<string>();
  foreach (var s in flattened) {
    bool alreadyKnown = symbols.ContainsKey(s.Name) || symbolSubtreeCount.ContainsKey(s.Name);
    if (alreadyKnown || !pendingNames.Add(s.Name))
      throw new ArgumentException("Symbol " + s + " is already defined.");
  }

  foreach (var s in flattened) {
    symbols.Add(s.Name, s);
    int maxSubTreeCount = Math.Min(s.MinimumArity + 1, s.MaximumArity);
    symbolSubtreeCount.Add(s.Name, Tuple.Create(s.MinimumArity, maxSubTreeCount));
  }
  ClearCaches();
}
129
/// <summary>
/// Removes <paramref name="symbol"/> and every symbol returned by its <c>Flatten()</c> call from
/// all tables of the grammar, including the allowed-children lists of the remaining symbols, and
/// removes <paramref name="symbol"/> from the collection of every remaining group symbol.
/// </summary>
/// <param name="symbol">The symbol to remove.</param>
public virtual void RemoveSymbol(ISymbol symbol) {
  foreach (var s in symbol.Flatten()) {
    symbols.Remove(s.Name);
    allowedChildSymbols.Remove(s.Name);
    // The subtree count must still be registered here, so it is removed only after this loop.
    for (int i = 0; i < GetMaximumSubtreeCount(s); i++)
      allowedChildSymbolsPerIndex.Remove(Tuple.Create(s.Name, i));
    symbolSubtreeCount.Remove(s.Name);

    foreach (var parent in Symbols) {
      List<string> allowedChilds;
      if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
        allowedChilds.Remove(s.Name);

      for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
        if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
          allowedChilds.Remove(s.Name);
      }
    }

    // Suppress Changed while the group collections are edited. The previous flag value is restored
    // in a finally block so an exception cannot leave event notification disabled, and so a
    // suppression that was already active in a caller is not cancelled.
    bool previousSuppressEvents = suppressEvents;
    suppressEvents = true;
    try {
      foreach (var groupSymbol in Symbols.OfType<GroupSymbol>())
        groupSymbol.SymbolsCollection.Remove(symbol);
    } finally {
      suppressEvents = previousSuppressEvents;
    }
  }
  ClearCaches();
}
155
/// <summary>
/// Looks up a registered symbol by name.
/// </summary>
/// <param name="symbolName">The name of the symbol.</param>
/// <returns>The symbol registered under that name, or null when there is none.</returns>
public virtual ISymbol GetSymbol(string symbolName) {
  ISymbol found;
  return symbols.TryGetValue(symbolName, out found) ? found : null;
}
161
/// <summary>
/// Allows <paramref name="child"/> at every argument index of each non-group symbol that
/// <paramref name="parent"/> flattens to. Caches are cleared and Changed is raised only when
/// at least one table entry was actually added.
/// </summary>
/// <param name="parent">The parent symbol (flattened; group symbols themselves are skipped).</param>
/// <param name="child">The child symbol to allow.</param>
public virtual void AddAllowedChildSymbol(ISymbol parent, ISymbol child) {
  bool anyAdded = false;

  foreach (ISymbol concreteParent in parent.Flatten()) {
    if (concreteParent is GroupSymbol) continue;
    if (AddAllowedChildSymbolToDictionaries(concreteParent, child)) anyAdded = true;
  }

  if (!anyAdded) return;
  ClearCaches();
  OnChanged();
}
173
/// <summary>
/// Adds <paramref name="child"/> to the index-independent allowed-children list of
/// <paramref name="parent"/>. Before adding, the child is removed from every per-index list of
/// the parent through <see cref="RemoveAllowedChildSymbol(ISymbol, ISymbol, int)"/>.
/// </summary>
/// <param name="parent">The parent symbol; its list is created when missing.</param>
/// <param name="child">The child symbol to allow.</param>
/// <returns>False when the child was already in the index-independent list; otherwise true.</returns>
private bool AddAllowedChildSymbolToDictionaries(ISymbol parent, ISymbol child) {
  List<string> childSymbols;
  if (!allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
    childSymbols = new List<string>();
    allowedChildSymbols.Add(parent.Name, childSymbols);
  }
  if (childSymbols.Contains(child.Name)) return false;

  // GetMaximumSubtreeCount throws for a parent without a registered subtree count, and the
  // removal call is virtual. Restoring the flag in a finally block keeps such an exception from
  // leaving Changed notifications disabled, and keeps an outer suppression intact.
  bool previousSuppressEvents = suppressEvents;
  suppressEvents = true;
  try {
    for (int argumentIndex = 0; argumentIndex < GetMaximumSubtreeCount(parent); argumentIndex++)
      RemoveAllowedChildSymbol(parent, child, argumentIndex);
  } finally {
    suppressEvents = previousSuppressEvents;
  }

  childSymbols.Add(child.Name);
  return true;
}
190
/// <summary>
/// Allows <paramref name="child"/> at one argument index of each non-group symbol that
/// <paramref name="parent"/> flattens to. Caches are cleared and Changed is raised only when
/// at least one table entry was actually added.
/// </summary>
/// <param name="parent">The parent symbol (flattened; group symbols themselves are skipped).</param>
/// <param name="child">The child symbol to allow.</param>
/// <param name="argumentIndex">The argument index at which the child is allowed.</param>
public virtual void AddAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
  bool anyAdded = false;

  foreach (ISymbol concreteParent in parent.Flatten()) {
    if (concreteParent is GroupSymbol) continue;
    if (AddAllowedChildSymbolToDictionaries(concreteParent, child, argumentIndex)) anyAdded = true;
  }

  if (!anyAdded) return;
  ClearCaches();
  OnChanged();
}
202
203
/// <summary>
/// Adds <paramref name="child"/> to the allowed-children list of <paramref name="parent"/> for a
/// single argument index. Nothing is added when the child is already allowed at every index or
/// already allowed at this index.
/// </summary>
/// <param name="parent">The parent symbol; missing lists for it are created.</param>
/// <param name="child">The child symbol to allow.</param>
/// <param name="argumentIndex">The argument index at which the child is allowed.</param>
/// <returns>True when the child was added to the per-index list; otherwise false.</returns>
private bool AddAllowedChildSymbolToDictionaries(ISymbol parent, ISymbol child, int argumentIndex) {
  // The index-independent list is created on demand even though only the per-index list is extended.
  List<string> allowedAtAnyIndex;
  if (!allowedChildSymbols.TryGetValue(parent.Name, out allowedAtAnyIndex)) {
    allowedAtAnyIndex = new List<string>();
    allowedChildSymbols.Add(parent.Name, allowedAtAnyIndex);
  }
  if (allowedAtAnyIndex.Contains(child.Name)) return false;

  Tuple<string, int> slot = Tuple.Create(parent.Name, argumentIndex);
  List<string> allowedAtIndex;
  if (!allowedChildSymbolsPerIndex.TryGetValue(slot, out allowedAtIndex)) {
    allowedAtIndex = new List<string>();
    allowedChildSymbolsPerIndex.Add(slot, allowedAtIndex);
  }
  if (allowedAtIndex.Contains(child.Name)) return false;

  allowedAtIndex.Add(child.Name);
  return true;
}
224
/// <summary>
/// Removes <paramref name="child"/> from the index-independent allowed-children list of
/// <paramref name="parent"/> and from each of the parent's per-index lists. Caches are cleared
/// and Changed is raised only when something was actually removed.
/// </summary>
/// <param name="parent">The parent symbol whose allowed children are edited.</param>
/// <param name="child">The child symbol that is no longer allowed.</param>
public virtual void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child) {
  bool changed = false;
  List<string> childSymbols;
  // The table is keyed by the PARENT's name. Looking it up by the child's name, as the code did
  // before, never removed the child from the parent's list.
  if (allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
    changed |= childSymbols.Remove(child.Name);
  }

  for (int argumentIndex = 0; argumentIndex < GetMaximumSubtreeCount(parent); argumentIndex++) {
    var key = Tuple.Create(parent.Name, argumentIndex);
    if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols))
      changed |= childSymbols.Remove(child.Name);
  }

  if (changed) {
    ClearCaches();
    OnChanged();
  }
}
243
/// <summary>
/// Disallows <paramref name="child"/> at one argument index of <paramref name="parent"/>.
/// When the child was allowed at every index, it is removed from the index-independent list and
/// re-added individually for every other index. It is also removed from the per-index list of
/// <paramref name="argumentIndex"/>. Caches are cleared and Changed is raised only when
/// something was actually changed.
/// </summary>
/// <param name="parent">The parent symbol whose allowed children are edited.</param>
/// <param name="child">The child symbol to disallow.</param>
/// <param name="argumentIndex">The argument index at which the child is no longer allowed.</param>
public virtual void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
  bool changed = false;
  List<string> childSymbols;

  // Remember the current flag instead of resetting it to false. This method is also called from
  // a region that has already suppressed events, and an unconditional reset cancelled that
  // suppression, so Changed fired while the tables were in an intermediate state. The finally
  // block also keeps an exception from leaving notifications disabled.
  bool previousSuppressEvents = suppressEvents;
  suppressEvents = true;
  try {
    if (allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
      if (childSymbols.Remove(child.Name)) {
        for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
          if (i != argumentIndex) AddAllowedChildSymbol(parent, child, i);
        }
        changed = true;
      }
    }
  } finally {
    suppressEvents = previousSuppressEvents;
  }

  var key = Tuple.Create(parent.Name, argumentIndex);
  if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols))
    changed |= childSymbols.Remove(child.Name);

  if (changed) {
    ClearCaches();
    OnChanged();
  }
}
268
/// <summary>
/// Sets the allowed subtree-count range for every non-group symbol that
/// <paramref name="symbol"/> flattens to, then clears the caches and raises Changed.
/// </summary>
/// <param name="symbol">The symbol (flattened; group symbols themselves are skipped).</param>
/// <param name="minimumSubtreeCount">New minimum; must not be below any affected symbol's MinimumArity.</param>
/// <param name="maximumSubtreeCount">New maximum; must not exceed any affected symbol's MaximumArity.</param>
/// <exception cref="ArgumentException">Thrown when either bound violates an affected symbol's arity.</exception>
public virtual void SetSubtreeCount(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
  // Lazy query: it is re-evaluated by each of the three loops below.
  IEnumerable<ISymbol> concreteSymbols = symbol.Flatten().Where(candidate => !(candidate is GroupSymbol));

  foreach (ISymbol candidate in concreteSymbols) {
    if (candidate.MinimumArity > minimumSubtreeCount)
      throw new ArgumentException("Invalid minimum subtree count " + minimumSubtreeCount + " for " + symbol);
  }
  foreach (ISymbol candidate in concreteSymbols) {
    if (candidate.MaximumArity < maximumSubtreeCount)
      throw new ArgumentException("Invalid maximum subtree count " + maximumSubtreeCount + " for " + symbol);
  }

  foreach (ISymbol target in concreteSymbols) {
    SetSubTreeCountInDictionaries(target, minimumSubtreeCount, maximumSubtreeCount);
  }

  ClearCaches();
  OnChanged();
}
280
/// <summary>
/// Stores the new subtree-count range for <paramref name="symbol"/>. Per-index allowed-children
/// entries from the new maximum up to the previous maximum are dropped first.
/// </summary>
/// <param name="symbol">The symbol whose range is set.</param>
/// <param name="minimumSubtreeCount">The new minimum subtree count.</param>
/// <param name="maximumSubtreeCount">The new maximum subtree count.</param>
private void SetSubTreeCountInDictionaries(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
  int obsoleteIndex = maximumSubtreeCount;
  while (obsoleteIndex < GetMaximumSubtreeCount(symbol)) {
    allowedChildSymbolsPerIndex.Remove(Tuple.Create(symbol.Name, obsoleteIndex));
    obsoleteIndex++;
  }

  symbolSubtreeCount[symbol.Name] = new Tuple<int, int>(minimumSubtreeCount, maximumSubtreeCount);
}
289    #endregion
290
/// <summary>All symbols registered in the grammar (the value collection of the symbol table).</summary>
public virtual IEnumerable<ISymbol> Symbols {
  get {
    Dictionary<string, ISymbol>.ValueCollection registered = symbols.Values;
    return registered;
  }
}
/// <summary>The registered symbols whose Enabled flag is set (evaluated lazily).</summary>
public virtual IEnumerable<ISymbol> AllowedSymbols {
  get {
    return from candidate in Symbols
           where candidate.Enabled
           select candidate;
  }
}
/// <summary>
/// Checks whether a symbol with the same name as <paramref name="symbol"/> is registered.
/// </summary>
/// <param name="symbol">The symbol whose name is looked up.</param>
/// <returns>True when the symbol table has an entry for the symbol's name.</returns>
public virtual bool ContainsSymbol(ISymbol symbol) {
  string name = symbol.Name;
  return symbols.ContainsKey(name);
}
300
// Cache for IsAllowedChildSymbol(parent, child), keyed by (parent name, child name).
private readonly Dictionary<Tuple<string, string>, bool> cachedIsAllowedChildSymbol = new Dictionary<Tuple<string, string>, bool>();

/// <summary>
/// Checks whether <paramref name="child"/> is allowed below <paramref name="parent"/> at every
/// argument index. A child matches when its name equals the name of any symbol obtained by
/// flattening one of the parent's allowed child entries. Results are cached per name pair.
/// </summary>
/// <param name="parent">The parent symbol.</param>
/// <param name="child">The candidate child symbol; disabled symbols are never allowed.</param>
/// <returns>True when the child is allowed; otherwise false.</returns>
public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child) {
  if (allowedChildSymbols.Count == 0 || !child.Enabled) return false;

  Tuple<string, string> cacheKey = Tuple.Create(parent.Name, child.Name);
  bool cached;
  if (cachedIsAllowedChildSymbol.TryGetValue(cacheKey, out cached)) return cached;

  // Compute and store the answer under the lock so only one thread fills the entry.
  lock (cachedIsAllowedChildSymbol) {
    // Another thread may have stored the answer while this one waited for the lock.
    if (cachedIsAllowedChildSymbol.TryGetValue(cacheKey, out cached)) return cached;

    bool allowed = false;
    List<string> candidateNames;
    if (allowedChildSymbols.TryGetValue(parent.Name, out candidateNames)) {
      for (int i = 0; i < candidateNames.Count && !allowed; i++) {
        ISymbol candidate = GetSymbol(candidateNames[i]);
        foreach (ISymbol member in candidate.Flatten()) {
          if (member.Name == child.Name) {
            allowed = true;
            break;
          }
        }
      }
    }

    cachedIsAllowedChildSymbol.Add(cacheKey, allowed);
    return allowed;
  }
}
330
// Cache for the per-index lookup part of IsAllowedChildSymbol(parent, child, argumentIndex),
// keyed by (parent name, child name, argument index).
private readonly Dictionary<Tuple<string, string, int>, bool> cachedIsAllowedChildSymbolIndex = new Dictionary<Tuple<string, string, int>, bool>();

/// <summary>
/// Checks whether <paramref name="child"/> is allowed below <paramref name="parent"/> at
/// <paramref name="argumentIndex"/>. A child allowed at every index is allowed here as well;
/// otherwise the per-index list of the parent is consulted and the result is cached.
/// </summary>
/// <param name="parent">The parent symbol.</param>
/// <param name="child">The candidate child symbol; disabled symbols are never allowed.</param>
/// <param name="argumentIndex">The argument index to check.</param>
/// <returns>True when the child is allowed at that index; otherwise false.</returns>
public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
  if (!child.Enabled) return false;
  if (IsAllowedChildSymbol(parent, child)) return true;
  if (allowedChildSymbolsPerIndex.Count == 0) return false;

  Tuple<string, string, int> cacheKey = Tuple.Create(parent.Name, child.Name, argumentIndex);
  bool cached;
  if (cachedIsAllowedChildSymbolIndex.TryGetValue(cacheKey, out cached)) return cached;

  // Compute and store the answer under the lock so only one thread fills the entry.
  lock (cachedIsAllowedChildSymbolIndex) {
    // Another thread may have stored the answer while this one waited for the lock.
    if (cachedIsAllowedChildSymbolIndex.TryGetValue(cacheKey, out cached)) return cached;

    bool allowed = false;
    List<string> candidateNames;
    if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, argumentIndex), out candidateNames)) {
      for (int i = 0; i < candidateNames.Count && !allowed; i++) {
        ISymbol candidate = GetSymbol(candidateNames[i]);
        foreach (ISymbol member in candidate.Flatten()) {
          if (member.Name == child.Name) {
            allowed = true;
            break;
          }
        }
      }
    }

    cachedIsAllowedChildSymbolIndex.Add(cacheKey, allowed);
    return allowed;
  }
}
361
/// <summary>
/// Lazily enumerates the enabled symbols that are allowed below <paramref name="parent"/> at
/// every argument index.
/// </summary>
/// <param name="parent">The parent symbol.</param>
/// <returns>The allowed child symbols.</returns>
public IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent) {
  foreach (ISymbol candidate in AllowedSymbols) {
    if (!IsAllowedChildSymbol(parent, candidate)) continue;
    yield return candidate;
  }
}
367
/// <summary>
/// Lazily enumerates the enabled symbols that are allowed below <paramref name="parent"/> at
/// <paramref name="argumentIndex"/>.
/// </summary>
/// <param name="parent">The parent symbol.</param>
/// <param name="argumentIndex">The argument index to check.</param>
/// <returns>The child symbols allowed at that index.</returns>
public IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent, int argumentIndex) {
  foreach (ISymbol candidate in AllowedSymbols) {
    if (!IsAllowedChildSymbol(parent, candidate, argumentIndex)) continue;
    yield return candidate;
  }
}
373
/// <summary>
/// Returns the minimum subtree count stored for <paramref name="symbol"/>.
/// The dictionary indexer throws when no entry exists for the symbol's name.
/// </summary>
/// <param name="symbol">The symbol to look up by name.</param>
public virtual int GetMinimumSubtreeCount(ISymbol symbol) {
  Tuple<int, int> range = symbolSubtreeCount[symbol.Name];
  return range.Item1;
}
/// <summary>
/// Returns the maximum subtree count stored for <paramref name="symbol"/>.
/// The dictionary indexer throws when no entry exists for the symbol's name.
/// </summary>
/// <param name="symbol">The symbol to look up by name.</param>
public virtual int GetMaximumSubtreeCount(ISymbol symbol) {
  Tuple<int, int> range = symbolSubtreeCount[symbol.Name];
  return range.Item2;
}
380
/// <summary>
/// Empties all six result caches (allowed-child lookups and expression length/depth values).
/// </summary>
protected void ClearCaches() {
  cachedIsAllowedChildSymbol.Clear();
  cachedIsAllowedChildSymbolIndex.Clear();

  cachedMinExpressionDepth.Clear();
  cachedMaxExpressionDepth.Clear();
  cachedMinExpressionLength.Clear();
  cachedMaxExpressionLength.Clear();
}
390
// Cache for GetMinimumExpressionLength, keyed by symbol name.
private readonly Dictionary<string, int> cachedMinExpressionLength = new Dictionary<string, int>();

/// <summary>
/// Returns the minimum expression length for <paramref name="symbol"/>. On a cache miss the
/// cache is filled by <c>GrammarUtils.CalculateMinimumExpressionLengths</c> under a lock.
/// </summary>
/// <param name="symbol">The symbol to look up by name.</param>
public int GetMinimumExpressionLength(ISymbol symbol) {
  int length;
  if (!cachedMinExpressionLength.TryGetValue(symbol.Name, out length)) {
    // Fill the cache under the lock so only one thread runs the calculation.
    lock (cachedMinExpressionLength) {
      // Another thread may have filled the cache while this one waited for the lock.
      if (!cachedMinExpressionLength.TryGetValue(symbol.Name, out length)) {
        GrammarUtils.CalculateMinimumExpressionLengths(this, cachedMinExpressionLength);
        length = cachedMinExpressionLength[symbol.Name];
      }
    }
  }
  return length;
}
406
[5686]407
// Cache for GetMaximumExpressionLength, keyed by (symbol name, maximum depth).
private readonly Dictionary<Tuple<string, int>, int> cachedMaxExpressionLength = new Dictionary<Tuple<string, int>, int>();

/// <summary>
/// Returns the maximum expression length for <paramref name="symbol"/> under the depth limit
/// <paramref name="maxDepth"/>: one for the symbol itself plus, for each argument index, the
/// largest value among the allowed children that have a positive initial frequency and a
/// minimum expression depth below <paramref name="maxDepth"/>. The result is capped at
/// int.MaxValue and cached.
/// </summary>
/// <param name="symbol">The root symbol of the expression.</param>
/// <param name="maxDepth">The depth limit; children are evaluated with maxDepth - 1.</param>
public int GetMaximumExpressionLength(ISymbol symbol, int maxDepth) {
  int cached;
  Tuple<string, int> cacheKey = Tuple.Create(symbol.Name, maxDepth);
  if (cachedMaxExpressionLength.TryGetValue(cacheKey, out cached)) return cached;

  // Compute and store the value under the lock so only one thread fills the entry.
  lock (cachedMaxExpressionLength) {
    // Another thread may have stored the value while this one waited for the lock.
    if (cachedMaxExpressionLength.TryGetValue(cacheKey, out cached)) return cached;

    // Placeholder entry: a recursive request for the same key returns it instead of recursing forever.
    cachedMaxExpressionLength[cacheKey] = int.MaxValue;

    long childrenTotal = Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
      .Select(argIndex => (long)GetAllowedChildSymbols(symbol, argIndex)
        .Where(s => s.InitialFrequency > 0.0)
        .Where(s => GetMinimumExpressionDepth(s) < maxDepth)
        .Select(s => GetMaximumExpressionLength(s, maxDepth - 1))
        .DefaultIfEmpty(0)
        .Max())
      .DefaultIfEmpty(0)
      .Sum();
    long sumOfMaxTrees = 1 + childrenTotal;

    int capped = (int)Math.Min(sumOfMaxTrees, int.MaxValue);
    cachedMaxExpressionLength[cacheKey] = capped;
    return capped;
  }
}
429
// Cache for GetMinimumExpressionDepth, keyed by symbol name.
private readonly Dictionary<string, int> cachedMinExpressionDepth = new Dictionary<string, int>();

/// <summary>
/// Returns the minimum expression depth for <paramref name="symbol"/>. On a cache miss the
/// cache is filled by <c>GrammarUtils.CalculateMinimumExpressionDepth</c> under a lock.
/// </summary>
/// <param name="symbol">The symbol to look up by name.</param>
public int GetMinimumExpressionDepth(ISymbol symbol) {
  int depth;
  if (!cachedMinExpressionDepth.TryGetValue(symbol.Name, out depth)) {
    // Fill the cache under the lock so only one thread runs the calculation.
    lock (cachedMinExpressionDepth) {
      // Another thread may have filled the cache while this one waited for the lock.
      if (!cachedMinExpressionDepth.TryGetValue(symbol.Name, out depth)) {
        GrammarUtils.CalculateMinimumExpressionDepth(this, cachedMinExpressionDepth);
        depth = cachedMinExpressionDepth[symbol.Name];
      }
    }
  }
  return depth;
}
[6803]445
// Cache for GetMaximumExpressionDepth, keyed by symbol name.
private readonly Dictionary<string, int> cachedMaxExpressionDepth = new Dictionary<string, int>();

/// <summary>
/// Returns the maximum expression depth for <paramref name="symbol"/>: one for the symbol itself
/// plus the largest maximum depth among the allowed children (positive initial frequency only)
/// over all argument indices. The result is capped at int.MaxValue and cached.
/// </summary>
/// <param name="symbol">The root symbol of the expression.</param>
public int GetMaximumExpressionDepth(ISymbol symbol) {
  int cached;
  if (cachedMaxExpressionDepth.TryGetValue(symbol.Name, out cached)) return cached;

  // Compute and store the value under the lock so only one thread fills the entry.
  lock (cachedMaxExpressionDepth) {
    // Another thread may have stored the value while this one waited for the lock.
    if (cachedMaxExpressionDepth.TryGetValue(symbol.Name, out cached)) return cached;

    // Placeholder entry: a recursive request for the same symbol returns it instead of recursing forever.
    cachedMaxExpressionDepth[symbol.Name] = int.MaxValue;

    long deepestChild = Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
      .Select(argIndex => (long)GetAllowedChildSymbols(symbol, argIndex)
        .Where(s => s.InitialFrequency > 0.0)
        .Select(s => GetMaximumExpressionDepth(s))
        .DefaultIfEmpty(0)
        .Max())
      .DefaultIfEmpty(0)
      .Max();
    long maxDepth = 1 + deepestChild;

    int capped = (int)Math.Min(maxDepth, int.MaxValue);
    cachedMaxExpressionDepth[symbol.Name] = capped;
    return capped;
  }
}
465
/// <summary>Raised by <see cref="OnChanged"/> unless events are currently suppressed.</summary>
public event EventHandler Changed;

/// <summary>
/// Raises <see cref="Changed"/> with <see cref="EventArgs.Empty"/>. Does nothing while
/// suppressEvents is set or when no handler is subscribed.
/// </summary>
protected virtual void OnChanged() {
  if (!suppressEvents) {
    // Copy the delegate to a local before the null check and invocation.
    EventHandler subscribers = Changed;
    if (subscribers != null) {
      subscribers(this, EventArgs.Empty);
    }
  }
}
[5686]472  }
473}
Note: See TracBrowser for help on using the repository browser.