#region License Information
/* HeuristicLab
* Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Encodings.SymbolicExpressionTreeEncoding {
/// <summary>
/// The default symbolic expression grammar stores symbols and syntactic constraints for symbols.
/// Symbols are treated as equivalent if they have the same name.
/// Syntactic constraints limit the number of allowed subtrees for a node with a symbol and which symbols are allowed
/// in the subtrees of a symbol (can be specified for each subtree index separately).
/// </summary>
[StorableClass]
public abstract class SymbolicExpressionGrammarBase : NamedItem, ISymbolicExpressionGrammarBase {
  #region properties for separation between implementation and persistence
  // Internally every table is keyed by symbol name. For persistence the getters translate the
  // names into the symbol objects (via GetSymbol) and the setters translate them back into names.

  /// <summary>Persistence view of <see cref="symbols"/>: the symbol objects without their name keys.</summary>
  [Storable(Name = "Symbols")]
  private IEnumerable<ISymbol> StorableSymbols {
    get { return symbols.Values.ToArray(); }
    set { symbols = value.ToDictionary(sym => sym.Name); }
  }

  /// <summary>Persistence view of <see cref="symbolSubtreeCount"/>, keyed by symbol object instead of name.</summary>
  [Storable(Name = "SymbolSubtreeCount")]
  private IEnumerable<KeyValuePair<ISymbol, Tuple<int, int>>> StorableSymbolSubtreeCount {
    get { return symbolSubtreeCount.Select(x => new KeyValuePair<ISymbol, Tuple<int, int>>(GetSymbol(x.Key), x.Value)).ToArray(); }
    set { symbolSubtreeCount = value.ToDictionary(x => x.Key.Name, x => x.Value); }
  }

  /// <summary>Persistence view of <see cref="allowedChildSymbols"/>, using symbol objects for parent and children.</summary>
  [Storable(Name = "AllowedChildSymbols")]
  private IEnumerable<KeyValuePair<ISymbol, IEnumerable<ISymbol>>> StorableAllowedChildSymbols {
    get { return allowedChildSymbols.Select(x => new KeyValuePair<ISymbol, IEnumerable<ISymbol>>(GetSymbol(x.Key), x.Value.Select(y => GetSymbol(y)).ToArray())).ToArray(); }
    set { allowedChildSymbols = value.ToDictionary(x => x.Key.Name, x => x.Value.Select(y => y.Name).ToList()); }
  }

  /// <summary>Persistence view of <see cref="allowedChildSymbolsPerIndex"/>, using symbol objects for parent and children.</summary>
  [Storable(Name = "AllowedChildSymbolsPerIndex")]
  private IEnumerable<KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>> StorableAllowedChildSymbolsPerIndex {
    get { return allowedChildSymbolsPerIndex.Select(x => new KeyValuePair<Tuple<ISymbol, int>, IEnumerable<ISymbol>>(Tuple.Create(GetSymbol(x.Key.Item1), x.Key.Item2), x.Value.Select(y => GetSymbol(y)).ToArray())).ToArray(); }
    set { allowedChildSymbolsPerIndex = value.ToDictionary(x => Tuple.Create(x.Key.Item1.Name, x.Key.Item2), x => x.Value.Select(y => y.Name).ToList()); }
  }
  #endregion

  /// <summary>All symbols of the grammar, keyed by symbol name.</summary>
  protected Dictionary<string, ISymbol> symbols;
  /// <summary>(minimum, maximum) subtree count per symbol name.</summary>
  protected Dictionary<string, Tuple<int, int>> symbolSubtreeCount;
  /// <summary>Names of the child symbols that are allowed at every argument index, per parent symbol name.</summary>
  protected Dictionary<string, List<string>> allowedChildSymbols;
  /// <summary>Names of the child symbols that are allowed at one specific argument index, per (parent name, index).</summary>
  protected Dictionary<Tuple<string, int>, List<string>> allowedChildSymbolsPerIndex;

  public override bool CanChangeName {
    get { return false; }
  }
  public override bool CanChangeDescription {
    get { return false; }
  }

  /// <summary>
  /// Constructor used for deserialization. Only the caches are created here; the grammar tables
  /// are assigned through the storable properties.
  /// </summary>
  [StorableConstructor]
  protected SymbolicExpressionGrammarBase(bool deserializing)
    : base(deserializing) {
    cachedMinExpressionLength = new Dictionary<string, int>();
    cachedMaxExpressionLength = new Dictionary<string, int>();
    cachedMinExpressionDepth = new Dictionary<string, int>();
  }

  /// <summary>Subscribes to the name events of all symbols once deserialization has finished.</summary>
  [StorableHook(HookType.AfterDeserialization)]
  private void AfterDeserialization() {
    foreach (ISymbol symbol in symbols.Values)
      RegisterSymbolEvents(symbol);
  }

  /// <summary>
  /// Cloning constructor. Symbols are cloned through <paramref name="cloner"/>, the constraint tables
  /// are copied (the name lists are duplicated so that the clone can be modified independently),
  /// and the caches start out empty.
  /// </summary>
  protected SymbolicExpressionGrammarBase(SymbolicExpressionGrammarBase original, Cloner cloner)
    : base(original, cloner) {
    cachedMinExpressionLength = new Dictionary<string, int>();
    cachedMaxExpressionLength = new Dictionary<string, int>();
    cachedMinExpressionDepth = new Dictionary<string, int>();

    symbols = original.symbols.ToDictionary(x => x.Key, y => (ISymbol)cloner.Clone(y.Value));
    symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>(original.symbolSubtreeCount);

    allowedChildSymbols = new Dictionary<string, List<string>>();
    foreach (var element in original.allowedChildSymbols)
      allowedChildSymbols.Add(element.Key, new List<string>(element.Value));

    allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
    foreach (var element in original.allowedChildSymbolsPerIndex)
      allowedChildSymbolsPerIndex.Add(element.Key, new List<string>(element.Value));

    foreach (ISymbol symbol in symbols.Values)
      RegisterSymbolEvents(symbol);
  }

  /// <summary>Creates an empty grammar with the given name and description.</summary>
  protected SymbolicExpressionGrammarBase(string name, string description)
    : base(name, description) {
    cachedMinExpressionLength = new Dictionary<string, int>();
    cachedMaxExpressionLength = new Dictionary<string, int>();
    cachedMinExpressionDepth = new Dictionary<string, int>();

    symbols = new Dictionary<string, ISymbol>();
    symbolSubtreeCount = new Dictionary<string, Tuple<int, int>>();
    allowedChildSymbols = new Dictionary<string, List<string>>();
    allowedChildSymbolsPerIndex = new Dictionary<Tuple<string, int>, List<string>>();
  }

  #region protected grammar manipulation methods
  /// <summary>
  /// Adds a symbol to the grammar with a subtree count of (0, 0) and clears the caches.
  /// </summary>
  /// <exception cref="ArgumentException">Thrown if <see cref="ContainsSymbol"/> reports the symbol as already present.</exception>
  protected void AddSymbol(ISymbol symbol) {
    if (ContainsSymbol(symbol)) throw new ArgumentException("Symbol " + symbol + " is already defined.");
    // Insert first and subscribe afterwards, so that a failing insert cannot leave
    // event handlers attached to a symbol that is not part of the grammar.
    symbols.Add(symbol.Name, symbol);
    symbolSubtreeCount.Add(symbol.Name, Tuple.Create(0, 0));
    RegisterSymbolEvents(symbol);
    ClearCaches();
  }

  /// <summary>Subscribes to the NameChanging and NameChanged events of the symbol.</summary>
  private void RegisterSymbolEvents(ISymbol symbol) {
    symbol.NameChanging += new EventHandler<CancelEventArgs<string>>(Symbol_NameChanging);
    symbol.NameChanged += new EventHandler(Symbol_NameChanged);
  }

  /// <summary>Unsubscribes from the NameChanging and NameChanged events of the symbol.</summary>
  private void DeregisterSymbolEvents(ISymbol symbol) {
    symbol.NameChanging -= new EventHandler<CancelEventArgs<string>>(Symbol_NameChanging);
    symbol.NameChanged -= new EventHandler(Symbol_NameChanged);
  }

  /// <summary>Cancels a rename if the requested name is already used as a key in <see cref="symbols"/>.</summary>
  private void Symbol_NameChanging(object sender, CancelEventArgs<string> e) {
    if (symbols.ContainsKey(e.Value)) e.Cancel = true;
  }

  /// <summary>
  /// Re-keys all tables after a symbol has been renamed: the symbol's own entries are moved
  /// from the old to the new name, and every occurrence of the old name in the allowed-child
  /// lists of all symbols is replaced by the new name.
  /// </summary>
  private void Symbol_NameChanged(object sender, EventArgs e) {
    ISymbol symbol = (ISymbol)sender;
    // The symbol already reports its new name; the old name is the key it is still stored under.
    string oldName = symbols.First(x => x.Value == symbol).Key;
    string newName = symbol.Name;

    symbols.Remove(oldName);
    symbols.Add(newName, symbol);

    var subtreeCount = symbolSubtreeCount[oldName];
    symbolSubtreeCount.Remove(oldName);
    symbolSubtreeCount.Add(newName, subtreeCount);

    // Move the entries in which the renamed symbol acts as parent.
    List<string> allowedChilds;
    if (allowedChildSymbols.TryGetValue(oldName, out allowedChilds)) {
      allowedChildSymbols.Remove(oldName);
      allowedChildSymbols.Add(newName, allowedChilds);
    }
    for (int i = 0; i < GetMaximumSubtreeCount(symbol); i++) {
      if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(oldName, i), out allowedChilds)) {
        allowedChildSymbolsPerIndex.Remove(Tuple.Create(oldName, i));
        allowedChildSymbolsPerIndex.Add(Tuple.Create(newName, i), allowedChilds);
      }
    }

    // Replace the old name wherever the renamed symbol acts as child.
    foreach (var parent in Symbols) {
      if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
        if (allowedChilds.Remove(oldName))
          allowedChilds.Add(newName);

      for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
        if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
          if (allowedChilds.Remove(oldName)) allowedChilds.Add(newName);
      }
    }

    ClearCaches();
  }

  /// <summary>
  /// Removes a symbol together with its subtree count and its own allowed-child entries, removes its
  /// name from the allowed-child lists of the remaining symbols, unsubscribes from its events and
  /// clears the caches.
  /// </summary>
  protected void RemoveSymbol(ISymbol symbol) {
    symbols.Remove(symbol.Name);
    allowedChildSymbols.Remove(symbol.Name);
    // The subtree count entry is still needed here to know how many per-index entries exist,
    // so it is removed only after this loop.
    for (int i = 0; i < GetMaximumSubtreeCount(symbol); i++)
      allowedChildSymbolsPerIndex.Remove(Tuple.Create(symbol.Name, i));
    symbolSubtreeCount.Remove(symbol.Name);

    foreach (var parent in Symbols) {
      List<string> allowedChilds;
      if (allowedChildSymbols.TryGetValue(parent.Name, out allowedChilds))
        allowedChilds.Remove(symbol.Name);

      for (int i = 0; i < GetMaximumSubtreeCount(parent); i++) {
        if (allowedChildSymbolsPerIndex.TryGetValue(Tuple.Create(parent.Name, i), out allowedChilds))
          allowedChilds.Remove(symbol.Name);
      }
    }

    DeregisterSymbolEvents(symbol);
    ClearCaches();
  }

  /// <summary>Returns the symbol stored under the given name, or null if there is none.</summary>
  public virtual ISymbol GetSymbol(string symbolName) {
    ISymbol symbol;
    if (symbols.TryGetValue(symbolName, out symbol)) return symbol;
    return null;
  }

  /// <summary>Allows <paramref name="child"/> below <paramref name="parent"/> at every argument index.</summary>
  /// <exception cref="ArgumentException">Thrown if the child is already allowed for the parent.</exception>
  protected void AddAllowedChildSymbol(ISymbol parent, ISymbol child) {
    List<string> childSymbols;
    if (!allowedChildSymbols.TryGetValue(parent.Name, out childSymbols)) {
      childSymbols = new List<string>();
      allowedChildSymbols.Add(parent.Name, childSymbols);
    }
    if (childSymbols.Contains(child.Name))
      throw new ArgumentException("Symbol " + child + " is already allowed as child of " + parent + ".");
    childSymbols.Add(child.Name);
    ClearCaches();
  }

  /// <summary>Allows <paramref name="child"/> below <paramref name="parent"/> at the given argument index only.</summary>
  /// <exception cref="ArgumentException">
  /// Thrown if the child is already allowed for the parent at every index, or already allowed at this index.
  /// </exception>
  protected void AddAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
    // Validate before touching the table so that a rejected call leaves no empty entry behind.
    if (IsAllowedChildSymbol(parent, child))
      throw new ArgumentException("Symbol " + child + " is already allowed as child of " + parent + " for all argument indices.");

    var key = Tuple.Create(parent.Name, argumentIndex);
    List<string> childSymbols;
    if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols)) {
      if (childSymbols.Contains(child.Name))
        throw new ArgumentException("Symbol " + child + " is already allowed as child of " + parent + " at argument index " + argumentIndex + ".");
    } else {
      childSymbols = new List<string>();
      allowedChildSymbolsPerIndex.Add(key, childSymbols);
    }

    childSymbols.Add(child.Name);
    ClearCaches();
  }

  /// <summary>
  /// Removes <paramref name="child"/> from the symbols allowed below <paramref name="parent"/> at every index.
  /// Does nothing if no such entry exists; the caches are cleared only when something was removed.
  /// </summary>
  protected void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child) {
    List<string> childSymbols;
    // The table is keyed by the parent's name; the child's name is only an element of the list.
    if (allowedChildSymbols.TryGetValue(parent.Name, out childSymbols) && childSymbols.Remove(child.Name))
      ClearCaches();
  }

  /// <summary>
  /// Removes <paramref name="child"/> from the symbols allowed below <paramref name="parent"/> at the given index.
  /// Does nothing if no such entry exists; the caches are cleared only when something was removed.
  /// </summary>
  protected void RemoveAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
    var key = Tuple.Create(parent.Name, argumentIndex);
    List<string> childSymbols;
    if (allowedChildSymbolsPerIndex.TryGetValue(key, out childSymbols) && childSymbols.Remove(child.Name))
      ClearCaches();
  }

  /// <summary>
  /// Sets the (minimum, maximum) subtree count of a symbol and drops the per-index constraints
  /// for argument indices that are no longer below the new maximum.
  /// </summary>
  protected void SetSubtreeCount(ISymbol symbol, int minimumSubtreeCount, int maximumSubtreeCount) {
    for (int i = GetMaximumSubtreeCount(symbol) - 1; i >= maximumSubtreeCount; i--) {
      var key = Tuple.Create(symbol.Name, i);
      allowedChildSymbolsPerIndex.Remove(key);
    }

    symbolSubtreeCount[symbol.Name] = Tuple.Create(minimumSubtreeCount, maximumSubtreeCount);
    ClearCaches();
  }
  #endregion

  #region ISymbolicExpressionGrammarBase Members
  /// <summary>All symbols of the grammar.</summary>
  public virtual IEnumerable<ISymbol> Symbols {
    get { return symbols.Values; }
  }

  /// <summary>The symbols whose InitialFrequency is not (almost) zero.</summary>
  public virtual IEnumerable<ISymbol> AllowedSymbols {
    get { return Symbols.Where(s => !s.InitialFrequency.IsAlmost(0.0)); }
  }

  /// <summary>Returns true if a symbol with the same name is part of the grammar.</summary>
  public virtual bool ContainsSymbol(ISymbol symbol) {
    return symbols.ContainsKey(symbol.Name);
  }

  /// <summary>Returns true if <paramref name="child"/> is allowed below <paramref name="parent"/> at every argument index.</summary>
  public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child) {
    List<string> temp;
    if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
      if (temp.Contains(child.Name)) return true;
    return false;
  }

  /// <summary>
  /// Returns true if <paramref name="child"/> is allowed below <paramref name="parent"/> at the given
  /// argument index, either through an index-independent rule or through a rule for that index.
  /// </summary>
  public virtual bool IsAllowedChildSymbol(ISymbol parent, ISymbol child, int argumentIndex) {
    List<string> temp;
    if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
      if (temp.Contains(child.Name)) return true;

    var key = Tuple.Create(parent.Name, argumentIndex);
    if (allowedChildSymbolsPerIndex.TryGetValue(key, out temp))
      return temp.Contains(child.Name);
    return false;
  }

  /// <summary>Returns the allowed symbols (see <see cref="AllowedSymbols"/>) that may appear below <paramref name="parent"/> at every index.</summary>
  public virtual IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent) {
    return from s in AllowedSymbols where IsAllowedChildSymbol(parent, s) select s;
  }

  /// <summary>
  /// Returns the symbols that may appear below <paramref name="parent"/> at the given argument index:
  /// the union of the index-independent and the index-specific child names, resolved through <see cref="GetSymbol"/>.
  /// </summary>
  public virtual IEnumerable<ISymbol> GetAllowedChildSymbols(ISymbol parent, int argumentIndex) {
    // NOTE(review): unlike the overload without index, this result is not restricted to AllowedSymbols
    // (symbols with an InitialFrequency of zero are included) - confirm that this is intended.
    var result = Enumerable.Empty<string>();

    List<string> temp;
    if (allowedChildSymbols.TryGetValue(parent.Name, out temp))
      result = result.Union(temp);
    var key = Tuple.Create(parent.Name, argumentIndex);
    if (allowedChildSymbolsPerIndex.TryGetValue(key, out temp))
      result = result.Union(temp);

    return result.Select(x => GetSymbol(x));
  }

  /// <summary>Returns the minimum subtree count registered for the symbol's name.</summary>
  public virtual int GetMinimumSubtreeCount(ISymbol symbol) {
    return symbolSubtreeCount[symbol.Name].Item1;
  }

  /// <summary>Returns the maximum subtree count registered for the symbol's name.</summary>
  public virtual int GetMaximumSubtreeCount(ISymbol symbol) {
    return symbolSubtreeCount[symbol.Name].Item2;
  }

  /// <summary>Empties all expression length/depth caches; called after every change of the grammar.</summary>
  private void ClearCaches() {
    cachedMinExpressionLength.Clear();
    cachedMaxExpressionLength.Clear();
    cachedMinExpressionDepth.Clear();
  }

  // Cache for GetMinimumExpressionLength, keyed by symbol name.
  private Dictionary<string, int> cachedMinExpressionLength;
  /// <summary>
  /// Returns 1 plus, for each of the symbol's required argument slots (minimum subtree count), the
  /// smallest minimum expression length among the allowed symbols usable in that slot (0 if there are none).
  /// The result is capped at int.MaxValue and cached by symbol name.
  /// </summary>
  public int GetMinimumExpressionLength(ISymbol symbol) {
    int temp;
    if (!cachedMinExpressionLength.TryGetValue(symbol.Name, out temp)) {
      // Sentinel: a recursive request for the same symbol hits the cache instead of recursing forever.
      cachedMinExpressionLength[symbol.Name] = int.MaxValue; // prevent infinite recursion
      // Summation is done in long because sentinel values may take part in the sum.
      long sumOfMinExpressionLengths = 1 + (from argIndex in Enumerable.Range(0, GetMinimumSubtreeCount(symbol))
                                            let minForSlot = (long)(from s in AllowedSymbols
                                                                    where IsAllowedChildSymbol(symbol, s, argIndex)
                                                                    select GetMinimumExpressionLength(s)).DefaultIfEmpty(0).Min()
                                            select minForSlot).DefaultIfEmpty(0).Sum();

      cachedMinExpressionLength[symbol.Name] = (int)Math.Min(sumOfMinExpressionLengths, int.MaxValue);
      return cachedMinExpressionLength[symbol.Name];
    }
    return temp;
  }

  // Cache for GetMaximumExpressionLength, keyed by symbol name.
  private Dictionary<string, int> cachedMaxExpressionLength;
  /// <summary>
  /// Returns 1 plus, for each of the symbol's argument slots (maximum subtree count), the largest
  /// maximum expression length among the allowed symbols usable in that slot (0 if there are none).
  /// The result is capped at int.MaxValue, which is also what a symbol that can contain itself yields,
  /// and is cached by symbol name.
  /// </summary>
  public int GetMaximumExpressionLength(ISymbol symbol) {
    int temp;
    if (!cachedMaxExpressionLength.TryGetValue(symbol.Name, out temp)) {
      // Sentinel: a recursive request for the same symbol hits the cache instead of recursing forever.
      cachedMaxExpressionLength[symbol.Name] = int.MaxValue; // prevent infinite recursion
      // Summation is done in long because sentinel values may take part in the sum.
      long sumOfMaxTrees = 1 + (from argIndex in Enumerable.Range(0, GetMaximumSubtreeCount(symbol))
                                let maxForSlot = (long)(from s in AllowedSymbols
                                                        where IsAllowedChildSymbol(symbol, s, argIndex)
                                                        select GetMaximumExpressionLength(s)).DefaultIfEmpty(0).Max()
                                select maxForSlot).DefaultIfEmpty(0).Sum();

      cachedMaxExpressionLength[symbol.Name] = (int)Math.Min(sumOfMaxTrees, int.MaxValue);
      return cachedMaxExpressionLength[symbol.Name];
    }
    return temp;
  }

  // Cache for GetMinimumExpressionDepth, keyed by symbol name.
  private Dictionary<string, int> cachedMinExpressionDepth;
  /// <summary>
  /// Returns 1 plus the largest value, over the symbol's required argument slots (minimum subtree count),
  /// of the smallest minimum expression depth among the allowed symbols usable in that slot (0 if there are none).
  /// The result is capped at int.MaxValue and cached by symbol name.
  /// </summary>
  public int GetMinimumExpressionDepth(ISymbol symbol) {
    int temp;
    if (!cachedMinExpressionDepth.TryGetValue(symbol.Name, out temp)) {
      // Sentinel: a recursive request for the same symbol hits the cache instead of recursing forever.
      cachedMinExpressionDepth[symbol.Name] = int.MaxValue; // prevent infinite recursion
      // Computed in long so that adding 1 to a sentinel value cannot overflow.
      long minDepth = 1 + (from argIndex in Enumerable.Range(0, GetMinimumSubtreeCount(symbol))
                           let minForSlot = (long)(from s in AllowedSymbols
                                                   where IsAllowedChildSymbol(symbol, s, argIndex)
                                                   select GetMinimumExpressionDepth(s)).DefaultIfEmpty(0).Min()
                           select minForSlot).DefaultIfEmpty(0).Max();

      cachedMinExpressionDepth[symbol.Name] = (int)Math.Min(minDepth, int.MaxValue);
      return cachedMinExpressionDepth[symbol.Name];
    }
    return temp;
  }
  #endregion
}
}