Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2990_VariableImpactBasedFeatureSelection/HeuristicLab.Encodings.SymbolicExpressionTreeEncoding/3.4/Grammars/SymbolicExpressionGrammarBase.cs @ 17607

Last change on this file since 17607 was 16565, checked in by gkronber, 6 years ago

#2520: merged changes from PersistenceOverhaul branch (r16451:16564) into trunk

File size: 20.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HEAL.Attic;
28
29namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
30  /// <summary>
31  /// The default symbolic expression grammar stores symbols and syntactic constraints for symbols.
32  /// Symbols are treated as equivalent if they have the same name.
33  /// Syntactic constraints limit the number of allowed sub trees for a node with a symbol and which symbols are allowed
34  /// in the sub-trees of a symbol (can be specified for each sub-tree index separately).
35  /// </summary>
36  [StorableType("E76C087C-4E10-488A-86D0-295A4265DA53")]
37  public abstract class SymbolicExpressionGrammarBase : NamedItem, ISymbolicExpressionGrammarBase {
38
39    #region properties for separation between implementation and persistence
40    [Storable(Name = "Symbols")]
41    private IEnumerable<ISymbol> StorableSymbols {
42      get { return symbols.Values.ToArray(); }
43      set { foreach (var s in value) symbols.Add(s.Name, s); }
44    }
45
46    [Storable(Name = "SymbolSubtreeCount")]
47    private IEnumerable<KeyValuePair<ISymbol, Tuple<int, int>>> StorableSymbolSubtreeCount {
48      get { return symbolSubtreeCount.Select(x => new KeyValuePair<ISymbol, Tuple<int, int>>(GetSymbol(x.Key), x.Value)).ToArray(); }
49      set { foreach (var pair in value) symbolSubtreeCount.Add(pair.Key.Name, pair.Value); }
50    }
51
52    [Storable(Name = "AllowedChildSymbols")]
53    private IEnumerable<KeyValuePair<ISymbol, IEnumerable<ISymbol>>> StorableAllowedChildSymbols {
54      get { return allowedChildSymbols.Select(x => new KeyValuePair<ISymbol, IEnumerable<ISymbol>>(GetSymbol(x.Key), x.Value.Select(GetSymbol).ToArray())).ToArray(); }
55      set { foreach (var pair in value) allowedChildSymbols.Add(pair.Key.Name, pair.Value.Select(y => y.Name).ToList()); }
56    }
57
58    [Storable(Name = "AllowedChildSymbolsPerIndex")]
59    private IEnumerable<KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>> StorableAllowedChildSymbolsPerIndex {
60      get { return allowedChildSymbolsPerIndex.Select(x => new KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>(Tuple.Create(GetSymbol(x.Key.Item1), x.Key.Item2), x.Value.Select(GetSymbol).ToArray())).ToArray(); }
61      set {
62        foreach (var pair in value)
63          allowedChildSymbolsPerIndex.Add(Tuple.Create(pair.Key.Item1.Name, pair.Key.Item2), pair.Value.Select(y => y.Name).ToList());
64      }
65    }
66    #endregion
67
68    private bool suppressEvents;
69    protected readonly Dictionary<string, ISymbol> symbols;
70    protected readonly Dictionary<string, Tuple<int, int>> symbolSubtreeCount;
71    protected readonly Dictionary<string, List<string>> allowedChildSymbols;
72    protected readonly Dictionary<Tuple<string, int>, List<string>> allowedChildSymbolsPerIndex;
73
74    public override bool CanChangeName {
75      get { return false; }
76    }
77    public override bool CanChangeDescription {
78      get { return false; }
79    }
80
81    [StorableConstructor]
82    protected SymbolicExpressionGrammarBase(StorableConstructorFlag _) : base(_) {
83
84      symbols = new Dictionary<string, ISymbol>();
85      symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
86      allowedChildSymbols = new Dictionary<string, List<string>>();
87      allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
88
89      suppressEvents = false;
90    }
91
92    protected SymbolicExpressionGrammarBase(SymbolicExpressionGrammarBase original, Cloner cloner)
93      : base(original, cloner) {
94
95      symbols = original.symbols.ToDictionary(x => x.Key, y => cloner.Clone(y.Value));
96      symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>(original.symbolSubtreeCount);
97
98      allowedChildSymbols = new Dictionary<string, List<string>>();
99      foreach (var element in original.allowedChildSymbols)
100        allowedChildSymbols.Add(element.Key, new List<string>(element.Value));
101
102      allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
103      foreach (var element in original.allowedChildSymbolsPerIndex)
104        allowedChildSymbolsPerIndex.Add(element.Key, new List<string>(element.Value));
105
106      suppressEvents = false;
107    }
108
109    protected SymbolicExpressionGrammarBase(string name, string description)
110      : base(name, description) {
111      symbols = new Dictionary<string, ISymbol>();
112      symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
113      allowedChildSymbols = new Dictionary<string, List<string>>();
114      allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
115
116      suppressEvents = false;
117    }
118
119    #region protected grammar manipulation methods
120    public virtual void AddSymbol(ISymbol symbol) {
121      if (ContainsSymbol(symbol)) throw new ArgumentException("Symbol " + symbol + " is already defined.");
122      foreach (var s in symbol.Flatten()) {
123        symbols.Add(s.Name, s);
124        int maxSubTreeCount = Math.Min(s.MinimumArity + 1, s.MaximumArity);
125        symbolSubtreeCount.Add(s.Name, Tuple.Create(s.MinimumArity, maxSubTreeCount));
126      }
127      ClearCaches();
128    }
129
130    public virtual void RemoveSymbol(ISymbol symbol) {
131      foreach (var s in symbol.Flatten()) {
132        symbols.Remove(s.Name);
133        allowedChildSymbols.Remove(s.Name);
134        for (int i = 0; i < GetMaximumSubtreeCount(s); i++)
135          allowedChildSymbolsPerIndex.Remove(Tuple.Create(s.Name, i));
136        symbolSubtreeCount.Remove(s.Name);
137
138        foreach (var parent in Symbols) {
139          List<string> allowedChilds;
140          if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
141            allowedChilds.Remove(s.Name);
142
143          for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
144            if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
145              allowedChilds.Remove(s.Name);
146          }
147        }
148        suppressEvents = true;
149        foreach (var groupSymbol in Symbols.OfType<GroupSymbol>())
150          groupSymbol.SymbolsCollection.Remove(symbol);
151        suppressEvents = false;
152      }
153      ClearCaches();
154    }
155
156    public virtual ISymbol GetSymbol(string symbolName) {
157      ISymbol symbol;
158      if (symbols.TryGetValue(symbolName, out symbol)) return symbol;
159      return null;
160    }
161
162    public virtual void AddAllowedChildSymbol(ISymbol parent, ISymbol child) {
163      bool changed = false;
164
165      foreach (ISymbol p in parent.Flatten().Where(p => !(p is GroupSymbol)))
166        changed |= AddAllowedChildSymbolToDictionaries(p, child);
167
168      if (changed) {
169        ClearCaches();
170        OnChanged();
171      }
172    }
173
174    private bool AddAllowedChildSymbolToDictionaries(ISymbol parent, ISymbol child) {
175      List<string> childSymbols;
176      if (!allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
177        childSymbols = new List<string>();
178        allowedChildSymbols.Add(parent.Name, childSymbols);
179      }
180      if (childSymbols.Contains(child.Name)) return false;
181
182      suppressEvents = true;
183      for (int argumentIndex = 0; argumentIndex < GetMaximumSubtreeCount(parent); argumentIndex++)
184        RemoveAllowedChildSymbol(parent, child, argumentIndex);
185      suppressEvents = false;
186
187      childSymbols.Add(child.Name);
188      return true;
189    }
190
191    public virtual void AddAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
192      bool changed = false;
193
194      foreach (ISymbol p in parent.Flatten().Where(p => !(p is GroupSymbol)))
195        changed |= AddAllowedChildSymbolToDictionaries(p, child, argumentIndex);
196
197      if (changed) {
198        ClearCaches();
199        OnChanged();
200      }
201    }
202
203
204    private bool AddAllowedChildSymbolToDictionaries(ISymbol parent, ISymbol child, int argumentIndex) {
205      List<string> childSymbols;
206      if (!allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
207        childSymbols = new List<string>();
208        allowedChildSymbols.Add(parent.Name, childSymbols);
209      }
210      if (childSymbols.Contains(child.Name)) return false;
211
212
213      var key = Tuple.Create(parent.Name, argumentIndex);
214      if (!allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols)) {
215        childSymbols = new List<string>();
216        allowedChildSymbolsPerIndex.Add(key, childSymbols);
217      }
218
219      if (childSymbols.Contains(child.Name)) return false;
220
221      childSymbols.Add(child.Name);
222      return true;
223    }
224
225    public virtual void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child) {
226      bool changed = false;
227      List<string> childSymbols;
228      if (allowedChildSymbols.TryGetValue(child.Name, out childSymbols)) {
229        changed |= childSymbols.Remove(child.Name);
230      }
231
232      for (int argumentIndex = 0; argumentIndex < GetMaximumSubtreeCount(parent); argumentIndex++) {
233        var key = Tuple.Create(parent.Name, argumentIndex);
234        if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols))
235          changed |= childSymbols.Remove(child.Name);
236      }
237
238      if (changed) {
239        ClearCaches();
240        OnChanged();
241      }
242    }
243
244    public virtual void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
245      bool changed = false;
246
247      suppressEvents = true;
248      List<string> childSymbols;
249      if (allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
250        if (childSymbols.Remove(child.Name)) {
251          for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
252            if (i != argumentIndex) AddAllowedChildSymbol(parent, child, i);
253          }
254          changed = true;
255        }
256      }
257      suppressEvents = false;
258
259      var key = Tuple.Create(parent.Name, argumentIndex);
260      if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols))
261        changed |= childSymbols.Remove(child.Name);
262
263      if (changed) {
264        ClearCaches();
265        OnChanged();
266      }
267    }
268
269    public virtual void SetSubtreeCount(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
270      var symbols = symbol.Flatten().Where(s => !(s is GroupSymbol));
271      if (symbols.Any(s => s.MinimumArity > minimumSubtreeCount)) throw new ArgumentException("Invalid minimum subtree count " + minimumSubtreeCount + " for " + symbol);
272      if (symbols.Any(s => s.MaximumArity < maximumSubtreeCount)) throw new ArgumentException("Invalid maximum subtree count " + maximumSubtreeCount + " for " + symbol);
273
274      foreach (ISymbol s in symbols)
275        SetSubTreeCountInDictionaries(s, minimumSubtreeCount, maximumSubtreeCount);
276
277      ClearCaches();
278      OnChanged();
279    }
280
281    private void SetSubTreeCountInDictionaries(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
282      for (int i = maximumSubtreeCount; i < GetMaximumSubtreeCount(symbol); i++) {
283        var key = Tuple.Create(symbol.Name, i);
284        allowedChildSymbolsPerIndex.Remove(key);
285      }
286
287      symbolSubtreeCount[symbol.Name] = Tuple.Create(minimumSubtreeCount, maximumSubtreeCount);
288    }
289    #endregion
290
291    public virtual IEnumerable<ISymbol> Symbols {
292      get { return symbols.Values; }
293    }
294    public virtual IEnumerable<ISymbol> AllowedSymbols {
295      get { return Symbols.Where(s => s.Enabled); }
296    }
297    public virtual bool ContainsSymbol(ISymbol symbol) {
298      return symbols.ContainsKey(symbol.Name);
299    }
300
301    private readonly Dictionary<Tuple<string, string>, bool> cachedIsAllowedChildSymbol = new Dictionary<Tuple<string, string>, bool>();
302    public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child) {
303      if (allowedChildSymbols.Count == 0) return false;
304      if (!child.Enabled) return false;
305
306      bool result;
307      var key = Tuple.Create(parent.Name, child.Name);
308      if (cachedIsAllowedChildSymbol.TryGetValue(key, out result)) return result;
309
310      // value has to be calculated and cached make sure this is done in only one thread
311      lock (cachedIsAllowedChildSymbol) {
312        // in case the value has been calculated on another thread in the meanwhile
313        if (cachedIsAllowedChildSymbol.TryGetValue(key, out result)) return result;
314
315        List<string> temp;
316        if (allowedChildSymbols.TryGetValue(parent.Name, out temp)) {
317          for (int i = 0; i < temp.Count; i++) {
318            var symbol = GetSymbol(temp[i]);
319            foreach (var s in symbol.Flatten())
320              if (s.Name == child.Name) {
321                cachedIsAllowedChildSymbol.Add(key, true);
322                return true;
323              }
324          }
325        }
326        cachedIsAllowedChildSymbol.Add(key, false);
327        return false;
328      }
329    }
330
331    private readonly Dictionary<Tuple<string, string, int>, bool> cachedIsAllowedChildSymbolIndex = new Dictionary<Tuple<string, string, int>, bool>();
332    public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
333      if (!child.Enabled) return false;
334      if (IsAllowedChildSymbol(parent, child)) return true;
335      if (allowedChildSymbolsPerIndex.Count == 0) return false;
336
337      bool result;
338      var key = Tuple.Create(parent.Name, child.Name, argumentIndex);
339      if (cachedIsAllowedChildSymbolIndex.TryGetValue(key, out result)) return result;
340
341      // value has to be calculated and cached make sure this is done in only one thread
342      lock (cachedIsAllowedChildSymbolIndex) {
343        // in case the value has been calculated on another thread in the meanwhile
344        if (cachedIsAllowedChildSymbolIndex.TryGetValue(key, out result)) return result;
345
346        List<string> temp;
347        if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, argumentIndex), out temp)) {
348          for (int i = 0; i < temp.Count; i++) {
349            var symbol = GetSymbol(temp[i]);
350            foreach (var s in symbol.Flatten())
351              if (s.Name == child.Name) {
352                cachedIsAllowedChildSymbolIndex.Add(key, true);
353                return true;
354              }
355          }
356        }
357        cachedIsAllowedChildSymbolIndex.Add(key, false);
358        return false;
359      }
360    }
361
362    public IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent) {
363      foreach (ISymbol child in AllowedSymbols) {
364        if (IsAllowedChildSymbol(parent, child)) yield return child;
365      }
366    }
367
368    public IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent, int argumentIndex) {
369      foreach (ISymbol child in AllowedSymbols) {
370        if (IsAllowedChildSymbol(parent, child, argumentIndex)) yield return child;
371      }
372    }
373
374    public virtual int GetMinimumSubtreeCount(ISymbol symbol) {
375      return symbolSubtreeCount[symbol.Name].Item1;
376    }
377    public virtual int GetMaximumSubtreeCount(ISymbol symbol) {
378      return symbolSubtreeCount[symbol.Name].Item2;
379    }
380
381    protected void ClearCaches() {
382      cachedMinExpressionLength.Clear();
383      cachedMaxExpressionLength.Clear();
384      cachedMinExpressionDepth.Clear();
385      cachedMaxExpressionDepth.Clear();
386
387      cachedIsAllowedChildSymbol.Clear();
388      cachedIsAllowedChildSymbolIndex.Clear();
389    }
390
391    private readonly Dictionary<string, int> cachedMinExpressionLength = new Dictionary<string, int>();
392    public int GetMinimumExpressionLength(ISymbol symbol) {
393      int res;
394      if (cachedMinExpressionLength.TryGetValue(symbol.Name, out res))
395        return res;
396
397      // value has to be calculated and cached make sure this is done in only one thread
398      lock (cachedMinExpressionLength) {
399        // in case the value has been calculated on another thread in the meanwhile
400        if (cachedMinExpressionLength.TryGetValue(symbol.Name, out res)) return res;
401
402        GrammarUtils.CalculateMinimumExpressionLengths(this, cachedMinExpressionLength);
403        return cachedMinExpressionLength[symbol.Name];
404      }
405    }
406
407
408    private readonly Dictionary<Tuple<string, int>, int> cachedMaxExpressionLength = new Dictionary<Tuple<string, int>, int>();
409    public int GetMaximumExpressionLength(ISymbol symbol, int maxDepth) {
410      int temp;
411      var key = Tuple.Create(symbol.Name, maxDepth);
412      if (cachedMaxExpressionLength.TryGetValue(key, out temp)) return temp;
413      // value has to be calculated and cached make sure this is done in only one thread
414      lock (cachedMaxExpressionLength) {
415        // in case the value has been calculated on another thread in the meanwhile
416        if (cachedMaxExpressionLength.TryGetValue(key, out temp)) return temp;
417
418        cachedMaxExpressionLength[key] = int.MaxValue; // prevent infinite recursion
419        long sumOfMaxTrees = 1 + (from argIndex in Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
420                                  let maxForSlot = (long)(from s in GetAllowedChildSymbols(symbol, argIndex)
421                                                          where s.InitialFrequency > 0.0
422                                                          where GetMinimumExpressionDepth(s) < maxDepth
423                                                          select GetMaximumExpressionLength(s, maxDepth - 1)).DefaultIfEmpty(0).Max()
424                                  select maxForSlot).DefaultIfEmpty(0).Sum();
425        cachedMaxExpressionLength[key] = (int)Math.Min(sumOfMaxTrees, int.MaxValue);
426        return cachedMaxExpressionLength[key];
427      }
428    }
429
430    private readonly Dictionary<string, int> cachedMinExpressionDepth = new Dictionary<string, int>();
431    public int GetMinimumExpressionDepth(ISymbol symbol) {
432      int res;
433      if (cachedMinExpressionDepth.TryGetValue(symbol.Name, out res))
434        return res;
435
436      // value has to be calculated and cached make sure this is done in only one thread
437      lock (cachedMinExpressionDepth) {
438        // in case the value has been calculated on another thread in the meanwhile
439        if (cachedMinExpressionDepth.TryGetValue(symbol.Name, out res)) return res;
440
441        GrammarUtils.CalculateMinimumExpressionDepth(this, cachedMinExpressionDepth);
442        return cachedMinExpressionDepth[symbol.Name];
443      }
444    }
445
446    private readonly Dictionary<string, int> cachedMaxExpressionDepth = new Dictionary<string, int>();
447    public int GetMaximumExpressionDepth(ISymbol symbol) {
448      int temp;
449      if (cachedMaxExpressionDepth.TryGetValue(symbol.Name, out temp)) return temp;
450      // value has to be calculated and cached make sure this is done in only one thread
451      lock (cachedMaxExpressionDepth) {
452        // in case the value has been calculated on another thread in the meanwhile
453        if (cachedMaxExpressionDepth.TryGetValue(symbol.Name, out temp)) return temp;
454
455        cachedMaxExpressionDepth[symbol.Name] = int.MaxValue;
456        long maxDepth = 1 + (from argIndex in Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
457                             let maxForSlot = (long)(from s in GetAllowedChildSymbols(symbol, argIndex)
458                                                     where s.InitialFrequency > 0.0
459                                                     select GetMaximumExpressionDepth(s)).DefaultIfEmpty(0).Max()
460                             select maxForSlot).DefaultIfEmpty(0).Max();
461        cachedMaxExpressionDepth[symbol.Name] = (int)Math.Min(maxDepth, int.MaxValue);
462        return cachedMaxExpressionDepth[symbol.Name];
463      }
464    }
465
466    public event EventHandler Changed;
467    protected virtual void OnChanged() {
468      if (suppressEvents) return;
469      var handler = Changed;
470      if (handler != null) handler(this, EventArgs.Empty);
471    }
472  }
473}
Note: See TracBrowser for help on using the repository browser.