#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
  /// <summary>
  /// The default symbolic expression grammar stores symbols and syntactic constraints for symbols.
  /// Symbols are treated as equivalent if they have the same name.
  /// Syntactic constraints limit the number of allowed sub trees for a node with a symbol and which symbols are allowed
  /// in the sub-trees of a symbol (can be specified for each sub-tree index separately).
  /// </summary>
  [StorableClass]
  public abstract class SymbolicExpressionGrammarBase : NamedItem, ISymbolicExpressionGrammarBase {

    #region properties for separation between implementation and persistence
    // Internally every table is keyed by symbol name. For persistence the names are mapped back
    // to the ISymbol instances (via GetSymbol), and on load the instances are mapped back to names.

    [Storable(Name = "Symbols")]
    private IEnumerable<ISymbol> StorableSymbols {
      get { return symbols.Values.ToArray(); }
      set { symbols = value.ToDictionary(sym => sym.Name); }
    }

    [Storable(Name = "SymbolSubtreeCount")]
    private IEnumerable<KeyValuePair<ISymbol, Tuple<int, int>>> StorableSymbolSubtreeCount {
      get {
        return symbolSubtreeCount
          .Select(x => new KeyValuePair<ISymbol, Tuple<int, int>>(GetSymbol(x.Key), x.Value))
          .ToArray();
      }
      set { symbolSubtreeCount = value.ToDictionary(x => x.Key.Name, x => x.Value); }
    }

    [Storable(Name = "AllowedChildSymbols")]
    private IEnumerable<KeyValuePair<ISymbol, IEnumerable<ISymbol>>> StorableAllowedChildSymbols {
      get {
        return allowedChildSymbols
          .Select(x => new KeyValuePair<ISymbol, IEnumerable<ISymbol>>(
            GetSymbol(x.Key),
            x.Value.Select(y => GetSymbol(y)).ToArray()))
          .ToArray();
      }
      set { allowedChildSymbols = value.ToDictionary(x => x.Key.Name, x => x.Value.Select(y => y.Name).ToList()); }
    }

    [Storable(Name = "AllowedChildSymbolsPerIndex")]
    private IEnumerable<KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>> StorableAllowedChildSymbolsPerIndex {
      get {
        return allowedChildSymbolsPerIndex
          .Select(x => new KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>(
            Tuple.Create(GetSymbol(x.Key.Item1), x.Key.Item2),
            x.Value.Select(y => GetSymbol(y)).ToArray()))
          .ToArray();
      }
      set {
        allowedChildSymbolsPerIndex = value.ToDictionary(
          x => Tuple.Create(x.Key.Item1.Name, x.Key.Item2),
          x => x.Value.Select(y => y.Name).ToList());
      }
    }
    #endregion

    /// <summary>All symbols of the grammar, keyed by symbol name.</summary>
    protected Dictionary<string, ISymbol> symbols;
    /// <summary>(minimum, maximum) number of subtrees per symbol name.</summary>
    protected Dictionary<string, Tuple<int, int>> symbolSubtreeCount;
    /// <summary>Child symbol names allowed at every argument index of a parent symbol name.</summary>
    protected Dictionary<string, List<string>> allowedChildSymbols;
    /// <summary>Child symbol names allowed at one specific (parent name, argument index).</summary>
    protected Dictionary<Tuple<string, int>, List<string>> allowedChildSymbolsPerIndex;

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    [StorableConstructor]
    protected SymbolicExpressionGrammarBase(bool deserializing)
      : base(deserializing) {
      // The caches are not persisted; the grammar tables are restored through the Storable properties.
      cachedMinExpressionLength = new Dictionary<string, int>();
      cachedMaxExpressionLength = new Dictionary<string, int>();
      cachedMinExpressionDepth = new Dictionary<string, int>();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Event subscriptions are not persisted, so re-attach the rename handlers.
      foreach (ISymbol symbol in symbols.Values)
        RegisterSymbolEvents(symbol);
    }

    /// <summary>
    /// Cloning constructor: clones every symbol through <paramref name="cloner"/> and copies all
    /// constraint tables so that the clone does not share mutable lists with the original.
    /// </summary>
    protected SymbolicExpressionGrammarBase(SymbolicExpressionGrammarBase original, Cloner cloner)
      : base(original, cloner) {
      cachedMinExpressionLength = new Dictionary<string, int>();
      cachedMaxExpressionLength = new Dictionary<string, int>();
      cachedMinExpressionDepth = new Dictionary<string, int>();

      symbols = original.symbols.ToDictionary(x => x.Key, y => (ISymbol)cloner.Clone(y.Value));
      // Tuple<int, int> is immutable, so a shallow dictionary copy is sufficient here.
      symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>(original.symbolSubtreeCount);

      allowedChildSymbols = new Dictionary<string, List<string>>();
      foreach (var element in original.allowedChildSymbols)
        allowedChildSymbols.Add(element.Key, new List<string>(element.Value));

      allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
      foreach (var element in original.allowedChildSymbolsPerIndex)
        allowedChildSymbolsPerIndex.Add(element.Key, new List<string>(element.Value));

      foreach (ISymbol symbol in symbols.Values)
        RegisterSymbolEvents(symbol);
    }

    protected SymbolicExpressionGrammarBase(string name, string description)
      : base(name, description) {
      cachedMinExpressionLength = new Dictionary<string, int>();
      cachedMaxExpressionLength = new Dictionary<string, int>();
      cachedMinExpressionDepth = new Dictionary<string, int>();

      symbols = new Dictionary<string, ISymbol>();
      symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
      allowedChildSymbols = new Dictionary<string, List<string>>();
      allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
    }

    #region protected grammar manipulation methods
    /// <summary>
    /// Adds <paramref name="symbol"/> with a subtree count of (0, 0).
    /// </summary>
    /// <exception cref="ArgumentException">A symbol with the same name is already defined.</exception>
    protected void AddSymbol(ISymbol symbol) {
      if (ContainsSymbol(symbol)) throw new ArgumentException("Symbol " + symbol + " is already defined.");
      RegisterSymbolEvents(symbol);
      symbols.Add(symbol.Name, symbol);
      symbolSubtreeCount.Add(symbol.Name, Tuple.Create(0, 0));
      ClearCaches();
    }

    private void RegisterSymbolEvents(ISymbol symbol) {
      symbol.NameChanging += new EventHandler<CancelEventArgs<string>>(Symbol_NameChanging);
      symbol.NameChanged += new EventHandler(Symbol_NameChanged);
    }

    private void DeregisterSymbolEvents(ISymbol symbol) {
      symbol.NameChanging -= new EventHandler<CancelEventArgs<string>>(Symbol_NameChanging);
      symbol.NameChanged -= new EventHandler(Symbol_NameChanged);
    }

    /// <summary>Vetoes a rename whose target name is already used by a symbol of this grammar.</summary>
    private void Symbol_NameChanging(object sender, CancelEventArgs<string> e) {
      if (symbols.ContainsKey(e.Value)) e.Cancel = true;
    }

    /// <summary>
    /// Re-keys every name-based table after a contained symbol was renamed: the symbol entry,
    /// its subtree count, its own allowed-child lists, and its occurrences in other symbols' lists.
    /// </summary>
    private void Symbol_NameChanged(object sender, EventArgs e) {
      ISymbol symbol = (ISymbol)sender;
      // The symbol already carries its new name, so find the old key by reference.
      string oldName = symbols.Where(x => x.Value == symbol).First().Key;
      string newName = symbol.Name;

      symbols.Remove(oldName);
      symbols.Add(newName, symbol);

      var subtreeCount = symbolSubtreeCount[oldName];
      symbolSubtreeCount.Remove(oldName);
      symbolSubtreeCount.Add(newName, subtreeCount);

      List<string> allowedChilds;
      if (allowedChildSymbols.TryGetValue(oldName, out allowedChilds)) {
        allowedChildSymbols.Remove(oldName);
        allowedChildSymbols.Add(newName, allowedChilds);
      }

      // symbolSubtreeCount has been re-keyed above, so GetMaximumSubtreeCount works with the new name.
      for (int i = 0; i < GetMaximumSubtreeCount(symbol); i++) {
        if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(oldName, i), out allowedChilds)) {
          allowedChildSymbolsPerIndex.Remove(Tuple.Create(oldName, i));
          allowedChildSymbolsPerIndex.Add(Tuple.Create(newName, i), allowedChilds);
        }
      }

      foreach (var parent in Symbols) {
        if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
          if (allowedChilds.Remove(oldName))
            allowedChilds.Add(newName);

        for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
          if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
            if (allowedChilds.Remove(oldName))
              allowedChilds.Add(newName);
        }
      }

      ClearCaches();
    }

    /// <summary>
    /// Removes the symbol with the name of <paramref name="symbol"/> together with all constraints
    /// in which it appears as parent or child.
    /// </summary>
    protected void RemoveSymbol(ISymbol symbol) {
      // Symbols are identified by name, so the argument may be a different (equivalent) instance.
      // Unsubscribe from the instance the grammar actually registered handlers on.
      ISymbol storedSymbol;
      if (!symbols.TryGetValue(symbol.Name, out storedSymbol))
        storedSymbol = symbol;

      symbols.Remove(symbol.Name);
      allowedChildSymbols.Remove(symbol.Name);
      for (int i = 0; i < GetMaximumSubtreeCount(symbol); i++)
        allowedChildSymbolsPerIndex.Remove(Tuple.Create(symbol.Name, i));
      symbolSubtreeCount.Remove(symbol.Name);

      foreach (var parent in Symbols) {
        List<string> allowedChilds;
        if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
          allowedChilds.Remove(symbol.Name);

        for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
          if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
            allowedChilds.Remove(symbol.Name);
        }
      }
      DeregisterSymbolEvents(storedSymbol);
      ClearCaches();
    }

    /// <summary>Returns the symbol with the given name, or <c>null</c> if none exists.</summary>
    public virtual ISymbol GetSymbol(string symbolName) {
      ISymbol symbol;
      if (symbols.TryGetValue(symbolName, out symbol)) return symbol;
      return null;
    }

    /// <summary>Allows <paramref name="child"/> below <paramref name="parent"/> at every argument index.</summary>
    /// <exception cref="ArgumentException">The child is already allowed for the parent.</exception>
    protected void AddAllowedChildSymbol(ISymbol parent, ISymbol child) {
      List<string> childSymbols;
      if (!allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
        childSymbols = new List<string>();
        allowedChildSymbols.Add(parent.Name, childSymbols);
      }
      if (childSymbols.Contains(child.Name))
        throw new ArgumentException("Symbol " + child + " is already an allowed child of " + parent + ".");
      childSymbols.Add(child.Name);
      ClearCaches();
    }

    /// <summary>
    /// Allows <paramref name="child"/> below <paramref name="parent"/> at <paramref name="argumentIndex"/> only.
    /// </summary>
    /// <exception cref="ArgumentException">
    /// The child is already allowed at every index, or already allowed at this index.
    /// </exception>
    protected void AddAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
      // Validate before touching the table so a rejected call leaves no empty per-index entry behind.
      if (IsAllowedChildSymbol(parent, child))
        throw new ArgumentException("Symbol " + child + " is already an allowed child of " + parent + " at every index.");

      var key = Tuple.Create(parent.Name, argumentIndex);
      List<string> childSymbols;
      if (!allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols)) {
        childSymbols = new List<string>();
        allowedChildSymbolsPerIndex.Add(key, childSymbols);
      } else if (childSymbols.Contains(child.Name)) {
        throw new ArgumentException("Symbol " + child + " is already an allowed child of " + parent + " at index " + argumentIndex + ".");
      }
      childSymbols.Add(child.Name);
      ClearCaches();
    }

    /// <summary>Removes the index-independent permission for <paramref name="child"/> below <paramref name="parent"/>.</summary>
    protected void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child) {
      // The table is keyed by the parent's name (the original looked it up by the child's name).
      List<string> childSymbols;
      if (allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
        if (childSymbols.Remove(child.Name)) ClearCaches();
      }
    }

    /// <summary>Removes the permission for <paramref name="child"/> below <paramref name="parent"/> at <paramref name="argumentIndex"/>.</summary>
    protected void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
      var key = Tuple.Create(parent.Name, argumentIndex);
      List<string> childSymbols;
      if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols)) {
        if (childSymbols.Remove(child.Name)) ClearCaches();
      }
    }

    /// <summary>
    /// Sets the (minimum, maximum) subtree count of <paramref name="symbol"/> and drops per-index
    /// constraints for indices that are no longer valid (>= <paramref name="maximumSubtreeCount"/>).
    /// </summary>
    protected void SetSubtreeCount(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
      for (int i = GetMaximumSubtreeCount(symbol) - 1; i >= maximumSubtreeCount; i--) {
        var key = Tuple.Create(symbol.Name, i);
        allowedChildSymbolsPerIndex.Remove(key);
      }

      symbolSubtreeCount[symbol.Name] = Tuple.Create(minimumSubtreeCount, maximumSubtreeCount);
      ClearCaches();
    }
    #endregion

    #region ISymbolicExpressionGrammarBase Members
    public virtual IEnumerable<ISymbol> Symbols {
      get { return symbols.Values; }
    }

    /// <summary>Symbols whose initial frequency is not (almost) zero.</summary>
    public virtual IEnumerable<ISymbol> AllowedSymbols {
      get { return Symbols.Where(s => !s.InitialFrequency.IsAlmost(0.0)); }
    }

    public virtual bool ContainsSymbol(ISymbol symbol) {
      return symbols.ContainsKey(symbol.Name);
    }

    /// <summary>True if <paramref name="child"/> is allowed below <paramref name="parent"/> at every index.</summary>
    public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child) {
      List<string> temp;
      if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
        if (temp.Contains(child.Name)) return true;
      return false;
    }

    /// <summary>
    /// True if <paramref name="child"/> is allowed below <paramref name="parent"/> at every index
    /// or specifically at <paramref name="argumentIndex"/>.
    /// </summary>
    public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
      List<string> temp;
      if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
        if (temp.Contains(child.Name)) return true;

      var key = Tuple.Create(parent.Name, argumentIndex);
      if (allowedChildSymbolsPerIndex.TryGetValue(key, out temp))
        return temp.Contains(child.Name);
      return false;
    }

    /// <summary>Allowed symbols (non-zero initial frequency) that may appear below <paramref name="parent"/> at every index.</summary>
    public virtual IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent) {
      return from s in AllowedSymbols where IsAllowedChildSymbol(parent, s) select s;
    }

    /// <summary>
    /// Symbols allowed below <paramref name="parent"/> at <paramref name="argumentIndex"/>, i.e. the union of
    /// the index-independent and the index-specific permissions.
    /// NOTE(review): unlike the overload without an index, this does not filter by AllowedSymbols.
    /// </summary>
    public virtual IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent, int argumentIndex) {
      var result = Enumerable.Empty<string>();

      List<string> temp;
      if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
        result = result.Union(temp);
      var key = Tuple.Create(parent.Name, argumentIndex);
      if (allowedChildSymbolsPerIndex.TryGetValue(key, out temp))
        result = result.Union(temp);

      return result.Select(x => GetSymbol(x));
    }

    public virtual int GetMinimumSubtreeCount(ISymbol symbol) {
      return symbolSubtreeCount[symbol.Name].Item1;
    }
    public virtual int GetMaximumSubtreeCount(ISymbol symbol) {
      return symbolSubtreeCount[symbol.Name].Item2;
    }

    /// <summary>Invalidates the memoized length/depth results; called after every grammar mutation.</summary>
    private void ClearCaches() {
      cachedMinExpressionLength.Clear();
      cachedMaxExpressionLength.Clear();
      cachedMinExpressionDepth.Clear();
    }

    private Dictionary<string, int> cachedMinExpressionLength;
    /// <summary>
    /// Minimum number of nodes of a tree rooted at <paramref name="symbol"/>: 1 plus, for each of the first
    /// GetMinimumSubtreeCount(symbol) argument slots, the smallest minimum length among the allowed symbols
    /// permitted in that slot (0 if none). The result is memoized and capped at int.MaxValue.
    /// </summary>
    public int GetMinimumExpressionLength(ISymbol symbol) {
      int temp;
      if (!cachedMinExpressionLength.TryGetValue(symbol.Name, out temp)) {
        // Placeholder while computing, so recursive references to this symbol see int.MaxValue instead of recursing forever.
        cachedMinExpressionLength[symbol.Name] = int.MaxValue;
        long sumOfMinExpressionLengths = 1 + (from argIndex in Enumerable.Range(0, GetMinimumSubtreeCount(symbol))
                                              let minForSlot = (long)(from s in AllowedSymbols
                                                                      where IsAllowedChildSymbol(symbol, s, argIndex)
                                                                      select GetMinimumExpressionLength(s)).DefaultIfEmpty(0).Min()
                                              select minForSlot).DefaultIfEmpty(0).Sum();

        cachedMinExpressionLength[symbol.Name] = (int)Math.Min(sumOfMinExpressionLengths, int.MaxValue);
        return cachedMinExpressionLength[symbol.Name];
      }
      return temp;
    }

    private Dictionary<string, int> cachedMaxExpressionLength;
    /// <summary>
    /// Maximum number of nodes of a tree rooted at <paramref name="symbol"/>: 1 plus, for each of the
    /// GetMaximumSubtreeCount(symbol) argument slots, the largest maximum length among the allowed symbols
    /// permitted in that slot (0 if none). A symbol that can (indirectly) contain itself yields int.MaxValue.
    /// The result is memoized and capped at int.MaxValue.
    /// </summary>
    public int GetMaximumExpressionLength(ISymbol symbol) {
      int temp;
      if (!cachedMaxExpressionLength.TryGetValue(symbol.Name, out temp)) {
        // Placeholder while computing, so recursive references to this symbol see int.MaxValue instead of recursing forever.
        cachedMaxExpressionLength[symbol.Name] = int.MaxValue;
        long sumOfMaxTrees = 1 + (from argIndex in Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
                                  let maxForSlot = (long)(from s in AllowedSymbols
                                                          where IsAllowedChildSymbol(symbol, s, argIndex)
                                                          select GetMaximumExpressionLength(s)).DefaultIfEmpty(0).Max()
                                  select maxForSlot).DefaultIfEmpty(0).Sum();
        cachedMaxExpressionLength[symbol.Name] = (int)Math.Min(sumOfMaxTrees, int.MaxValue);
        return cachedMaxExpressionLength[symbol.Name];
      }
      return temp;
    }

    private Dictionary<string, int> cachedMinExpressionDepth;
    /// <summary>
    /// Minimum depth of a tree rooted at <paramref name="symbol"/>: 1 plus the largest, over the first
    /// GetMinimumSubtreeCount(symbol) argument slots, of the smallest minimum depth among the allowed symbols
    /// permitted in that slot (0 if none). The result is memoized and capped at int.MaxValue.
    /// </summary>
    public int GetMinimumExpressionDepth(ISymbol symbol) {
      int temp;
      if (!cachedMinExpressionDepth.TryGetValue(symbol.Name, out temp)) {
        // Placeholder while computing, so recursive references to this symbol see int.MaxValue instead of recursing forever.
        cachedMinExpressionDepth[symbol.Name] = int.MaxValue;
        long minDepth = 1 + (from argIndex in Enumerable.Range(0, GetMinimumSubtreeCount(symbol))
                             let minForSlot = (long)(from s in AllowedSymbols
                                                     where IsAllowedChildSymbol(symbol, s, argIndex)
                                                     select GetMinimumExpressionDepth(s)).DefaultIfEmpty(0).Min()
                             select minForSlot).DefaultIfEmpty(0).Max();
        cachedMinExpressionDepth[symbol.Name] = (int)Math.Min(minDepth, int.MaxValue);
        return cachedMinExpressionDepth[symbol.Name];
      }
      return temp;
    }
    #endregion
  }
}