#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Simplifier for symbolic expressions
/// </summary>
public class TreeSimplifier {
// Shared symbol instances. The Make* helpers in this class call CreateTreeNode() on them
// (e.g. gtSymbol.CreateTreeNode()) whenever a new node of the corresponding kind is needed.
private static readonly Addition addSymbol = new Addition();
private static readonly Multiplication mulSymbol = new Multiplication();
private static readonly Division divSymbol = new Division();
private static readonly Constant constSymbol = new Constant();
private static readonly Absolute absSymbol = new Absolute();
private static readonly Logarithm logSymbol = new Logarithm();
private static readonly Exponential expSymbol = new Exponential();
private static readonly Root rootSymbol = new Root();
private static readonly Square sqrSymbol = new Square();
private static readonly SquareRoot sqrtSymbol = new SquareRoot();
private static readonly AnalyticQuotient aqSymbol = new AnalyticQuotient();
private static readonly Cube cubeSymbol = new Cube();
private static readonly CubeRoot cubeRootSymbol = new CubeRoot();
private static readonly Power powSymbol = new Power();
private static readonly Sine sineSymbol = new Sine();
private static readonly Cosine cosineSymbol = new Cosine();
private static readonly Tangent tanSymbol = new Tangent();
private static readonly IfThenElse ifThenElseSymbol = new IfThenElse();
private static readonly And andSymbol = new And();
private static readonly Or orSymbol = new Or();
private static readonly Not notSymbol = new Not();
private static readonly GreaterThan gtSymbol = new GreaterThan();
private static readonly LessThan ltSymbol = new LessThan();
private static readonly Integral integralSymbol = new Integral();
private static readonly LaggedVariable laggedVariableSymbol = new LaggedVariable();
private static readonly TimeLag timeLagSymbol = new TimeLag();
// Deprecated instance constructor: the attribute below directs callers to the static Simplify method.
[Obsolete("Use static method TreeSimplifier.Simplify instead")]
public TreeSimplifier() { }
/// <summary>
/// Creates a simplified version of <paramref name="originalTree"/>.
/// All work is done on a clone of the tree's root: function invocations are macro-expanded first,
/// then the expanded expression is simplified and wrapped in a fresh program root.
/// </summary>
/// <param name="originalTree">The tree to simplify.</param>
/// <returns>A new tree whose program root has the simplified expression as its only subtree.</returns>
public static ISymbolicExpressionTree Simplify(ISymbolicExpressionTree originalTree) {
  var clone = (ISymbolicExpressionTreeNode)originalTree.Root.Clone();
  // macro expand (initially no argument trees)
  var macroExpandedTree = MacroExpand(clone, clone.GetSubtree(0), new List<ISymbolicExpressionTreeNode>());
  ISymbolicExpressionTreeNode rootNode = (new ProgramRootSymbol()).CreateTreeNode();
  rootNode.AddSubtree(GetSimplifiedTree(macroExpandedTree));
#if DEBUG
  // check that each node is only referenced once
  var nodes = rootNode.IterateNodesPrefix().ToArray();
  foreach (var n in nodes) {
    if (nodes.Count(ni => ni == n) > 1) {
      throw new InvalidOperationException("The simplified tree references the same node instance more than once.");
    }
  }
#endif
  return new SymbolicExpressionTree(rootNode);
}
/// <summary>
/// Recursively replaces function invocations below <paramref name="node"/> by the body of the
/// invoked function and replaces argument nodes by the matching entry of <paramref name="argumentTrees"/>.
/// The subtrees of <paramref name="node"/> are detached and, for ordinary nodes, re-attached in expanded form.
/// </summary>
/// <param name="root">Root node that is searched for function definitions.</param>
/// <param name="node">Node to expand (modified in place unless it is an invocation or an argument).</param>
/// <param name="argumentTrees">Already expanded trees used as arguments for invocations.</param>
/// <returns>The expanded node; a clone for arguments and invoked function bodies.</returns>
private static ISymbolicExpressionTreeNode MacroExpand(ISymbolicExpressionTreeNode root, ISymbolicExpressionTreeNode node,
  IList<ISymbolicExpressionTreeNode> argumentTrees) {
  var subtrees = new List<ISymbolicExpressionTreeNode>(node.Subtrees);
  while (node.SubtreeCount > 0) node.RemoveSubtree(0);

  var invokeSym = node.Symbol as InvokeFunction;
  if (invokeSym != null) {
    var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
    var macroExpandedArguments = new List<ISymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    // Expand a copy of the function body. Expanding the definition in place would bake the arguments
    // of this call into the definition, so a later invocation of the same function would silently
    // reuse them and the same node instance would end up in the result twice.
    var bodyCopy = (ISymbolicExpressionTreeNode)defunNode.Clone();
    return MacroExpand(root, bodyCopy, macroExpandedArguments);
  }

  var argSym = node.Symbol as Argument;
  if (argSym != null) {
    // return the correct argument sub-tree (already macro-expanded)
    return (ISymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  }

  // recursive application
  foreach (var subtree in subtrees) {
    node.AddSubtree(MacroExpand(root, subtree, argumentTrees));
  }
  return node;
}
/// <summary>
/// Searches the direct subtrees of <paramref name="root"/> for the definition of the named function.
/// </summary>
/// <param name="root">Node whose direct subtrees are searched.</param>
/// <param name="functionName">Name of the function to look up.</param>
/// <returns>The first subtree of the matching definition node.</returns>
/// <exception cref="ArgumentException">No definition with the given name exists.</exception>
private static ISymbolicExpressionTreeNode FindFunctionDefinition(ISymbolicExpressionTreeNode root, string functionName) {
  // NOTE(review): the element type was lost from the source; DefunTreeNode is assumed because
  // the loop body reads FunctionName - confirm against the encoding library.
  foreach (var subtree in root.Subtrees.OfType<DefunTreeNode>()) {
    if (subtree.FunctionName == functionName) return subtree.GetSubtree(0);
  }
  throw new ArgumentException("Definition of function " + functionName + " not found.");
}
#region symbol predicates
// Small classification helpers used by the simplification routines.
// Most of them inspect node.Symbol; IsVariableBase, IsFactor and IsBinFactor inspect the node type itself.
// arithmetic
private static bool IsDivision(ISymbolicExpressionTreeNode node) {
return node.Symbol is Division;
}
private static bool IsMultiplication(ISymbolicExpressionTreeNode node) {
return node.Symbol is Multiplication;
}
private static bool IsSubtraction(ISymbolicExpressionTreeNode node) {
return node.Symbol is Subtraction;
}
private static bool IsAddition(ISymbolicExpressionTreeNode node) {
return node.Symbol is Addition;
}
private static bool IsAverage(ISymbolicExpressionTreeNode node) {
return node.Symbol is Average;
}
private static bool IsAbsolute(ISymbolicExpressionTreeNode node) {
return node.Symbol is Absolute;
}
// exponential
private static bool IsLog(ISymbolicExpressionTreeNode node) {
return node.Symbol is Logarithm;
}
private static bool IsExp(ISymbolicExpressionTreeNode node) {
return node.Symbol is Exponential;
}
private static bool IsRoot(ISymbolicExpressionTreeNode node) {
return node.Symbol is Root;
}
private static bool IsSquare(ISymbolicExpressionTreeNode node) {
return node.Symbol is Square;
}
private static bool IsSquareRoot(ISymbolicExpressionTreeNode node) {
return node.Symbol is SquareRoot;
}
private static bool IsCube(ISymbolicExpressionTreeNode node) {
return node.Symbol is Cube;
}
private static bool IsCubeRoot(ISymbolicExpressionTreeNode node) {
return node.Symbol is CubeRoot;
}
private static bool IsPower(ISymbolicExpressionTreeNode node) {
return node.Symbol is Power;
}
// trigonometric
private static bool IsSine(ISymbolicExpressionTreeNode node) {
return node.Symbol is Sine;
}
private static bool IsCosine(ISymbolicExpressionTreeNode node) {
return node.Symbol is Cosine;
}
private static bool IsTangent(ISymbolicExpressionTreeNode node) {
return node.Symbol is Tangent;
}
private static bool IsAnalyticalQuotient(ISymbolicExpressionTreeNode node) {
return node.Symbol is AnalyticQuotient;
}
// boolean
private static bool IsIfThenElse(ISymbolicExpressionTreeNode node) {
return node.Symbol is IfThenElse;
}
private static bool IsAnd(ISymbolicExpressionTreeNode node) {
return node.Symbol is And;
}
private static bool IsOr(ISymbolicExpressionTreeNode node) {
return node.Symbol is Or;
}
private static bool IsNot(ISymbolicExpressionTreeNode node) {
return node.Symbol is Not;
}
// comparison
private static bool IsGreaterThan(ISymbolicExpressionTreeNode node) {
return node.Symbol is GreaterThan;
}
private static bool IsLessThan(ISymbolicExpressionTreeNode node) {
return node.Symbol is LessThan;
}
// True for comparison (>, <) and for And/Or nodes. Note that Not is not part of this list,
// so callers that wrap non-boolean operands in "> 0" also wrap Not nodes.
private static bool IsBoolean(ISymbolicExpressionTreeNode node) {
return
node.Symbol is GreaterThan ||
node.Symbol is LessThan ||
node.Symbol is And ||
node.Symbol is Or;
}
// terminals
private static bool IsVariable(ISymbolicExpressionTreeNode node) {
return node.Symbol is Variable;
}
// The following three checks test the type of the node, not the type of its symbol.
private static bool IsVariableBase(ISymbolicExpressionTreeNode node) {
return node is VariableTreeNodeBase;
}
private static bool IsFactor(ISymbolicExpressionTreeNode node) {
return node is FactorVariableTreeNode;
}
private static bool IsBinFactor(ISymbolicExpressionTreeNode node) {
return node is BinaryFactorVariableTreeNode;
}
private static bool IsConstant(ISymbolicExpressionTreeNode node) {
return node.Symbol is Constant;
}
// dynamic
private static bool IsTimeLag(ISymbolicExpressionTreeNode node) {
return node.Symbol is TimeLag;
}
private static bool IsIntegral(ISymbolicExpressionTreeNode node) {
return node.Symbol is Integral;
}
#endregion
/// <summary>
/// Creates a new simplified tree
/// </summary>
/// <param name="original">Root node of the expression to simplify.</param>
/// <returns>
/// A clone for constants and variable nodes; otherwise the result of the simplification routine
/// that matches the node's symbol (SimplifyAny for symbols without a dedicated routine).
/// </returns>
public static ISymbolicExpressionTreeNode GetSimplifiedTree(ISymbolicExpressionTreeNode original) {
  // terminals are simply copied
  if (IsConstant(original) || IsVariableBase(original)) {
    return (ISymbolicExpressionTreeNode)original.Clone();
  }

  // arithmetic
  if (IsAbsolute(original)) return SimplifyAbsolute(original);
  if (IsAddition(original)) return SimplifyAddition(original);
  if (IsSubtraction(original)) return SimplifySubtraction(original);
  if (IsMultiplication(original)) return SimplifyMultiplication(original);
  if (IsDivision(original)) return SimplifyDivision(original);
  if (IsAverage(original)) return SimplifyAverage(original);

  // exponential and power functions
  if (IsLog(original)) return SimplifyLog(original);
  if (IsExp(original)) return SimplifyExp(original);
  if (IsSquare(original)) return SimplifySquare(original);
  if (IsSquareRoot(original)) return SimplifySquareRoot(original);
  if (IsCube(original)) return SimplifyCube(original);
  if (IsCubeRoot(original)) return SimplifyCubeRoot(original);
  if (IsPower(original)) return SimplifyPower(original);
  if (IsRoot(original)) return SimplifyRoot(original);

  // trigonometric and analytic quotient
  if (IsSine(original)) return SimplifySine(original);
  if (IsCosine(original)) return SimplifyCosine(original);
  if (IsTangent(original)) return SimplifyTangent(original);
  if (IsAnalyticalQuotient(original)) return SimplifyAnalyticalQuotient(original);

  // conditionals, comparisons and boolean operators
  if (IsIfThenElse(original)) return SimplifyIfThenElse(original);
  if (IsGreaterThan(original)) return SimplifyGreaterThan(original);
  if (IsLessThan(original)) return SimplifyLessThan(original);
  if (IsAnd(original)) return SimplifyAnd(original);
  if (IsOr(original)) return SimplifyOr(original);
  if (IsNot(original)) return SimplifyNot(original);

  // dynamic symbols
  if (IsTimeLag(original)) return SimplifyTimeLag(original);
  if (IsIntegral(original)) return SimplifyIntegral(original);

  // everything else: only the subtrees are simplified
  return SimplifyAny(original);
}
#region specific simplification routines
/// <summary>
/// Fallback for symbols without a dedicated simplification rule: the node itself is copied
/// unchanged and only its subtrees are simplified.
/// </summary>
/// <param name="original">Node to process; its subtrees are detached temporarily and re-attached.</param>
/// <returns>A copy of <paramref name="original"/> whose subtrees are the simplified subtrees.</returns>
private static ISymbolicExpressionTreeNode SimplifyAny(ISymbolicExpressionTreeNode original) {
  // can't simplify this function but simplify all subtrees
  var subtrees = new List<ISymbolicExpressionTreeNode>(original.Subtrees);
  // detach the subtrees first so that Clone() copies the node without its children
  while (original.SubtreeCount > 0) original.RemoveSubtree(0);
  var clone = (ISymbolicExpressionTreeNode)original.Clone();
  var simplifiedSubtrees = new List<ISymbolicExpressionTreeNode>(subtrees.Count);
  foreach (var subtree in subtrees) {
    simplifiedSubtrees.Add(GetSimplifiedTree(subtree));
    // restore the original node
    original.AddSubtree(subtree);
  }
  foreach (var simplifiedSubtree in simplifiedSubtrees) {
    clone.AddSubtree(simplifiedSubtree);
  }
  if (simplifiedSubtrees.TrueForAll(t => IsConstant(t))) {
    // use the returned node instead of discarding it, so that constant folding takes effect
    // as soon as SimplifyConstantExpression returns something other than its argument
    return SimplifyConstantExpression(clone);
  }
  return clone;
}
// Placeholder for folding an expression whose subtrees are all constants.
// Currently returns the given node unchanged.
private static ISymbolicExpressionTreeNode SimplifyConstantExpression(ISymbolicExpressionTreeNode original) {
// not yet implemented
return original;
}
/// <summary>
/// Rewrites avg(x0..xn) as sum(x0..xn) / n using simplified operands.
/// An average of a single operand is just that (simplified) operand.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyAverage(ISymbolicExpressionTreeNode original) {
  if (original.Subtrees.Count() == 1) {
    return GetSimplifiedTree(original.GetSubtree(0));
  }

  // numerator: sum of all simplified operands
  var simplifiedOperands = original.Subtrees.Select(operand => GetSimplifiedTree(operand));
  var total = simplifiedOperands.Aggregate((acc, next) => MakeSum(acc, next));
  // denominator: number of operands
  var divisor = MakeConstant(original.Subtrees.Count());
  return MakeFraction(total, divisor);
}
/// <summary>
/// Rewrites a division as a product: a single operand x becomes 1/x and
/// x0 / x1 / .. / xn becomes x0 * 1/(x1 * .. * xn). All operands are simplified first.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyDivision(ISymbolicExpressionTreeNode original) {
  if (original.Subtrees.Count() == 1) {
    return Invert(GetSimplifiedTree(original.GetSubtree(0)));
  }

  var dividend = original.GetSubtree(0);
  var firstDivisor = original.GetSubtree(1);

  var numerator = GetSimplifiedTree(dividend);
  // multiply all divisors together so that only a single inversion is necessary
  var denominator = GetSimplifiedTree(firstDivisor);
  foreach (var divisor in original.Subtrees.Skip(2)) {
    denominator = MakeProduct(denominator, GetSimplifiedTree(divisor));
  }
  return MakeProduct(numerator, Invert(denominator));
}
/// <summary>
/// Simplifies all factors and combines them pairwise, left to right, with MakeProduct.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyMultiplication(ISymbolicExpressionTreeNode original) {
  if (original.Subtrees.Count() == 1) {
    return GetSimplifiedTree(original.GetSubtree(0));
  }
  var simplifiedFactors = original.Subtrees.Select(factor => GetSimplifiedTree(factor));
  return simplifiedFactors.Aggregate((product, factor) => MakeProduct(product, factor));
}
/// <summary>
/// Rewrites a subtraction as an addition of negated terms: a single operand x becomes -x and
/// x0 - x1 - .. - xn becomes x0 + (-x1) + .. + (-xn). All operands are simplified first.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifySubtraction(ISymbolicExpressionTreeNode original) {
  if (original.Subtrees.Count() == 1) {
    return Negate(GetSimplifiedTree(original.GetSubtree(0)));
  }

  var minuend = original.Subtrees.First();
  var result = GetSimplifiedTree(minuend);
  foreach (var subtrahend in original.Subtrees.Skip(1)) {
    result = MakeSum(result, Negate(GetSimplifiedTree(subtrahend)));
  }
  return result;
}
/// <summary>
/// Simplifies all terms and combines them pairwise, left to right, with MakeSum.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyAddition(ISymbolicExpressionTreeNode original) {
  if (original.Subtrees.Count() == 1) {
    return GetSimplifiedTree(original.GetSubtree(0));
  }
  var simplifiedTerms = original.Subtrees.Select(term => GetSimplifiedTree(term));
  return simplifiedTerms.Aggregate((sum, term) => MakeSum(sum, term));
}
/// <summary>Simplifies the operand and rebuilds the absolute value via MakeAbs.</summary>
private static ISymbolicExpressionTreeNode SimplifyAbsolute(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeAbs(operand);
}
/// <summary>Simplifies the operand and rebuilds the negation via MakeNot.</summary>
private static ISymbolicExpressionTreeNode SimplifyNot(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeNot(operand);
}
/// <summary>Simplifies all operands and combines them pairwise, left to right, with MakeOr.</summary>
private static ISymbolicExpressionTreeNode SimplifyOr(ISymbolicExpressionTreeNode original) {
  var simplifiedOperands = original.Subtrees.Select(operand => GetSimplifiedTree(operand));
  return simplifiedOperands.Aggregate((left, right) => MakeOr(left, right));
}
/// <summary>Simplifies all operands and combines them pairwise, left to right, with MakeAnd.</summary>
private static ISymbolicExpressionTreeNode SimplifyAnd(ISymbolicExpressionTreeNode original) {
  var simplifiedOperands = original.Subtrees.Select(operand => GetSimplifiedTree(operand));
  return simplifiedOperands.Aggregate((left, right) => MakeAnd(left, right));
}
/// <summary>Simplifies both sides and rebuilds the comparison via MakeLessThan.</summary>
private static ISymbolicExpressionTreeNode SimplifyLessThan(ISymbolicExpressionTreeNode original) {
  var left = GetSimplifiedTree(original.GetSubtree(0));
  var right = GetSimplifiedTree(original.GetSubtree(1));
  return MakeLessThan(left, right);
}
/// <summary>Simplifies both sides and rebuilds the comparison via MakeGreaterThan.</summary>
private static ISymbolicExpressionTreeNode SimplifyGreaterThan(ISymbolicExpressionTreeNode original) {
  var left = GetSimplifiedTree(original.GetSubtree(0));
  var right = GetSimplifiedTree(original.GetSubtree(1));
  return MakeGreaterThan(left, right);
}
/// <summary>Simplifies condition, true branch and false branch (in this order) and rebuilds the node via MakeIfThenElse.</summary>
private static ISymbolicExpressionTreeNode SimplifyIfThenElse(ISymbolicExpressionTreeNode original) {
  var condition = GetSimplifiedTree(original.GetSubtree(0));
  var trueBranch = GetSimplifiedTree(original.GetSubtree(1));
  var falseBranch = GetSimplifiedTree(original.GetSubtree(2));
  return MakeIfThenElse(condition, trueBranch, falseBranch);
}
/// <summary>Simplifies the operand and rebuilds the tangent via MakeTangent.</summary>
private static ISymbolicExpressionTreeNode SimplifyTangent(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeTangent(operand);
}
/// <summary>Simplifies the operand and rebuilds the cosine via MakeCosine.</summary>
private static ISymbolicExpressionTreeNode SimplifyCosine(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeCosine(operand);
}
/// <summary>Simplifies the operand and rebuilds the sine via MakeSine.</summary>
private static ISymbolicExpressionTreeNode SimplifySine(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeSine(operand);
}
/// <summary>Simplifies the operand and rebuilds the exponential via MakeExp.</summary>
private static ISymbolicExpressionTreeNode SimplifyExp(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeExp(operand);
}
/// <summary>Simplifies the operand and rebuilds the square via MakeSquare.</summary>
private static ISymbolicExpressionTreeNode SimplifySquare(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeSquare(operand);
}
/// <summary>Simplifies the operand and rebuilds the square root via MakeSquareRoot.</summary>
private static ISymbolicExpressionTreeNode SimplifySquareRoot(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeSquareRoot(operand);
}
/// <summary>Simplifies the operand and rebuilds the cube via MakeCube.</summary>
private static ISymbolicExpressionTreeNode SimplifyCube(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeCube(operand);
}
/// <summary>Simplifies the operand and rebuilds the cube root via MakeCubeRoot.</summary>
private static ISymbolicExpressionTreeNode SimplifyCubeRoot(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeCubeRoot(operand);
}
/// <summary>Simplifies the operand and rebuilds the logarithm via MakeLog.</summary>
private static ISymbolicExpressionTreeNode SimplifyLog(ISymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  return MakeLog(operand);
}
/// <summary>Simplifies both operands (first, then second) and rebuilds the root via MakeRoot.</summary>
private static ISymbolicExpressionTreeNode SimplifyRoot(ISymbolicExpressionTreeNode original) {
  var first = GetSimplifiedTree(original.GetSubtree(0));
  var second = GetSimplifiedTree(original.GetSubtree(1));
  return MakeRoot(first, second);
}
/// <summary>Simplifies both operands (first, then second) and rebuilds the power via MakePower.</summary>
private static ISymbolicExpressionTreeNode SimplifyPower(ISymbolicExpressionTreeNode original) {
  var first = GetSimplifiedTree(original.GetSubtree(0));
  var second = GetSimplifiedTree(original.GetSubtree(1));
  return MakePower(first, second);
}
/// <summary>Simplifies both operands (first, then second) and rebuilds the node via MakeAnalyticalQuotient.</summary>
private static ISymbolicExpressionTreeNode SimplifyAnalyticalQuotient(ISymbolicExpressionTreeNode original) {
  var first = GetSimplifiedTree(original.GetSubtree(0));
  var second = GetSimplifiedTree(original.GetSubtree(1));
  return MakeAnalyticalQuotient(first, second);
}
/// <summary>
/// Simplifies the operand of a time-lag node. If the simplified operand contains a variable
/// condition the lag node is kept (MakeTimeLag); otherwise the lag is passed to
/// AddLagToDynamicNodes together with the operand.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyTimeLag(ISymbolicExpressionTreeNode original) {
  var lagNode = original as ILaggedTreeNode;
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  if (ContainsVariableCondition(operand)) {
    return MakeTimeLag(operand, lagNode.Lag);
  }
  return AddLagToDynamicNodes(operand, lagNode.Lag);
}
/// <summary>
/// Simplifies the operand of an integral node. A constant operand c is replaced by the
/// simplified product c * (-Lag); any other operand is handed to MakeIntegral.
/// </summary>
private static ISymbolicExpressionTreeNode SimplifyIntegral(ISymbolicExpressionTreeNode original) {
  var lagNode = original as ILaggedTreeNode;
  var operand = GetSimplifiedTree(original.GetSubtree(0));
  if (!IsConstant(operand)) {
    return MakeIntegral(operand, lagNode.Lag);
  }
  // NOTE(review): MakeIntegral expands |Lag| == 1 into a sum of two terms, while this branch
  // scales a constant by -Lag only - confirm the factor against the interpreter's Integral semantics.
  var scaled = MakeProduct(operand, MakeConstant(-lagNode.Lag));
  return GetSimplifiedTree(scaled);
}
#endregion
#region low level tree restructuring
/// <summary>
/// Wraps <paramref name="subtree"/> in a time-lag node with the given lag.
/// The subtree is returned as is when the lag is zero or the subtree is a constant.
/// </summary>
private static ISymbolicExpressionTreeNode MakeTimeLag(ISymbolicExpressionTreeNode subtree, int lag) {
  if (lag == 0 || IsConstant(subtree)) {
    return subtree;
  }
  var wrapper = (LaggedTreeNode)timeLagSymbol.CreateTreeNode();
  wrapper.Lag = lag;
  wrapper.AddSubtree(subtree);
  return wrapper;
}
/// <summary>
/// Builds an integral over <paramref name="subtree"/> for the given lag.
/// Lag 0 returns the subtree itself, lag +1/-1 is expanded to the sum of the subtree and a
/// lagged clone of it, and every other lag produces an integral node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeIntegral(ISymbolicExpressionTreeNode subtree, int lag) {
  if (lag == 0) {
    return subtree;
  }
  if (lag == 1 || lag == -1) {
    var shifted = AddLagToDynamicNodes((ISymbolicExpressionTreeNode)subtree.Clone(), lag);
    return MakeSum(subtree, shifted);
  }
  var integralNode = (LaggedTreeNode)integralSymbol.CreateTreeNode();
  integralNode.Lag = lag;
  integralNode.AddSubtree(subtree);
  return integralNode;
}
/// <summary>
/// Builds the logical negation of <paramref name="t"/>.
/// Constants are folded (values greater than 0 become -1.0, all others 1.0), a double negation
/// yields the inner operand, and operands that are not boolean nodes are compared with "&gt; 0" first.
/// </summary>
private static ISymbolicExpressionTreeNode MakeNot(ISymbolicExpressionTreeNode t) {
  if (IsConstant(t)) {
    var constant = t as ConstantTreeNode;
    return MakeConstant(constant.Value > 0 ? -1.0 : 1.0);
  }
  if (IsNot(t)) {
    // not(not(x)) => x
    return t.GetSubtree(0);
  }

  ISymbolicExpressionTreeNode operand = t;
  if (!IsBoolean(t)) {
    // turn the numeric operand into a condition: t > 0
    var gtNode = gtSymbol.CreateTreeNode();
    gtNode.AddSubtree(t);
    gtNode.AddSubtree(MakeConstant(0.0));
    operand = gtNode;
  }
  var notNode = notSymbol.CreateTreeNode();
  notNode.AddSubtree(operand);
  return notNode;
}
/// <summary>
/// Builds the disjunction of <paramref name="a"/> and <paramref name="b"/>; constants greater
/// than 0 count as true. Two constants are folded to 1.0 / -1.0, a true constant makes the
/// result 1.0, and a false constant is dropped (leaving an Or node with a single operand).
/// </summary>
private static ISymbolicExpressionTreeNode MakeOr(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  var aIsConstant = IsConstant(a);
  var bIsConstant = IsConstant(b);

  if (aIsConstant && bIsConstant) {
    var left = a as ConstantTreeNode;
    var right = b as ConstantTreeNode;
    var anyTrue = left.Value > 0.0 || right.Value > 0.0;
    return MakeConstant(anyTrue ? 1.0 : -1.0);
  }

  // normalize: a single constant operand is always handled as the second operand
  if (aIsConstant) {
    return MakeOr(b, a);
  }

  if (bIsConstant) {
    var constant = b as ConstantTreeNode;
    // boolean expression is necessarily true
    if (constant.Value > 0.0) return MakeConstant(1.0);
  }

  var orNode = orSymbol.CreateTreeNode();
  orNode.AddSubtree(a);
  // a (false) constant has no effect on the result of the condition, so it is left out
  if (!bIsConstant) orNode.AddSubtree(b);
  return orNode;
}
/// <summary>
/// Builds the conjunction of <paramref name="a"/> and <paramref name="b"/>; constants greater
/// than 0 count as true. Two constants are folded to 1.0 / -1.0, a constant that is not true
/// makes the result -1.0, and a true constant is dropped (leaving an And node with a single operand).
/// </summary>
private static ISymbolicExpressionTreeNode MakeAnd(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  var aIsConstant = IsConstant(a);
  var bIsConstant = IsConstant(b);

  if (aIsConstant && bIsConstant) {
    var left = a as ConstantTreeNode;
    var right = b as ConstantTreeNode;
    var allTrue = left.Value > 0.0 && right.Value > 0.0;
    return MakeConstant(allTrue ? 1.0 : -1.0);
  }

  // normalize: a single constant operand is always handled as the second operand
  if (aIsConstant) {
    return MakeAnd(b, a);
  }

  if (bIsConstant) {
    var constant = b as ConstantTreeNode;
    // boolean expression is necessarily false
    if (!(constant.Value > 0.0)) return MakeConstant(-1.0);
  }

  var andNode = andSymbol.CreateTreeNode();
  andNode.AddSubtree(a);
  // a (true) constant has no effect on the result of the condition, so it is left out
  if (!bIsConstant) andNode.AddSubtree(b);
  return andNode;
}
/// <summary>
/// Builds "leftSide &lt; rightSide". If both sides are constants the comparison is evaluated
/// immediately and 1.0 (true) or -1.0 (false) is returned.
/// </summary>
private static ISymbolicExpressionTreeNode MakeLessThan(ISymbolicExpressionTreeNode leftSide,
  ISymbolicExpressionTreeNode rightSide) {
  if (!(IsConstant(leftSide) && IsConstant(rightSide))) {
    var comparison = ltSymbol.CreateTreeNode();
    comparison.AddSubtree(leftSide);
    comparison.AddSubtree(rightSide);
    return comparison;
  }
  var left = leftSide as ConstantTreeNode;
  var right = rightSide as ConstantTreeNode;
  return MakeConstant(left.Value < right.Value ? 1.0 : -1.0);
}
/// <summary>
/// Builds "leftSide &gt; rightSide". If both sides are constants the comparison is evaluated
/// immediately and 1.0 (true) or -1.0 (false) is returned.
/// </summary>
private static ISymbolicExpressionTreeNode MakeGreaterThan(ISymbolicExpressionTreeNode leftSide,
  ISymbolicExpressionTreeNode rightSide) {
  if (!(IsConstant(leftSide) && IsConstant(rightSide))) {
    var comparison = gtSymbol.CreateTreeNode();
    comparison.AddSubtree(leftSide);
    comparison.AddSubtree(rightSide);
    return comparison;
  }
  var left = leftSide as ConstantTreeNode;
  var right = rightSide as ConstantTreeNode;
  return MakeConstant(left.Value > right.Value ? 1.0 : -1.0);
}
/// <summary>
/// Builds an if-then-else node. A constant condition selects one branch directly (greater than 0
/// selects the true branch); a condition that is not a boolean node is compared with "&gt; 0" first.
/// </summary>
private static ISymbolicExpressionTreeNode MakeIfThenElse(ISymbolicExpressionTreeNode condition,
  ISymbolicExpressionTreeNode trueBranch, ISymbolicExpressionTreeNode falseBranch) {
  if (IsConstant(condition)) {
    var constant = condition as ConstantTreeNode;
    return constant.Value > 0.0 ? trueBranch : falseBranch;
  }

  var ifNode = ifThenElseSymbol.CreateTreeNode();
  ISymbolicExpressionTreeNode test = condition;
  if (!IsBoolean(condition)) {
    // turn the numeric condition into a boolean one: condition > 0
    var gtNode = gtSymbol.CreateTreeNode();
    gtNode.AddSubtree(condition);
    gtNode.AddSubtree(MakeConstant(0.0));
    test = gtNode;
  }
  ifNode.AddSubtree(test);
  ifNode.AddSubtree(trueBranch);
  ifNode.AddSubtree(falseBranch);
  return ifNode;
}
/// <summary>
/// Builds sin(node). Constants, factor weights and binary-factor weights are folded by applying
/// Math.Sin directly; everything else is wrapped in a sine node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeSine(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(Math.Sin(constant.Value));
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => Math.Sin(w));
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    var weight = Math.Sin(binFactor.Weight);
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, weight);
  }
  var sineNode = sineSymbol.CreateTreeNode();
  sineNode.AddSubtree(node);
  return sineNode;
}
/// <summary>
/// Builds tan(node). Constants, factor weights and binary-factor weights are folded by applying
/// Math.Tan directly; everything else is wrapped in a tangent node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeTangent(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(Math.Tan(constant.Value));
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => Math.Tan(w));
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    var weight = Math.Tan(binFactor.Weight);
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, weight);
  }
  var tangentNode = tanSymbol.CreateTreeNode();
  tangentNode.AddSubtree(node);
  return tangentNode;
}
/// <summary>
/// Builds cos(node). Constants and factor weights are folded with Math.Cos. A binary factor with
/// weight w is rewritten as binfactor(cos(w) - 1) + 1; everything else is wrapped in a cosine node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeCosine(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(Math.Cos(constant.Value));
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => Math.Cos(w));
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    // cos(0) = 1 see similar case for Exp(binfactor)
    var shifted = MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue,
      Math.Cos(binFactor.Weight) - 1);
    return MakeSum(shifted, MakeConstant(1.0));
  }
  var cosineNode = cosineSymbol.CreateTreeNode();
  cosineNode.AddSubtree(node);
  return cosineNode;
}
/// <summary>
/// Builds exp(node) and applies the following rewrites:
/// constants and factor weights are folded; a binary factor with weight w becomes
/// binfactor(exp(w) - 1) + 1; exp(log(x)) becomes x; exp(x0 + .. + xn) becomes exp(x0) * .. * exp(xn);
/// exp(x0 - x1 - .. - xn) becomes exp(x0) / exp(x1) / .. / exp(xn) and exp(-x) becomes 1 / exp(x).
/// Everything else is wrapped in an exponential node.
/// </summary>
/// <param name="node">The (already simplified) argument of the exponential.</param>
/// <returns>A node equivalent to exp(node).</returns>
private static ISymbolicExpressionTreeNode MakeExp(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Exp(constT.Value));
  } else if (IsFactor(node)) {
    var factNode = node as FactorVariableTreeNode;
    return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Exp(w)));
  } else if (IsBinFactor(node)) {
    // exp( binfactor w val=a) = if(val=a) exp(w) else exp(0) = binfactor( (exp(w) - 1) val a) + 1
    var binFactor = node as BinaryFactorVariableTreeNode;
    return
      MakeSum(MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Exp(binFactor.Weight) - 1), MakeConstant(1.0));
  } else if (IsLog(node)) {
    return node.GetSubtree(0);
  } else if (IsAddition(node)) {
    return node.Subtrees.Select(s => MakeExp(s)).Aggregate((s, t) => MakeProduct(s, t));
  } else if (IsSubtraction(node)) {
    var exponentials = node.Subtrees.Select(s => MakeExp(s)).ToList();
    // a subtraction with a single operand is a negation: exp(-x) = 1 / exp(x)
    if (exponentials.Count == 1) return Invert(exponentials[0]);
    // exp(a - b) = exp(a) / exp(b); multiplying by the negated exponential (-exp(b)) would be wrong
    return exponentials.Aggregate((s, t) => MakeFraction(s, t));
  } else {
    var expNode = expSymbol.CreateTreeNode();
    expNode.AddSubtree(node);
    return expNode;
  }
}
/// <summary>
/// Builds log(node). Constants and factor weights are folded with Math.Log, log(exp(x)) becomes x
/// and log(sqrt(x)) becomes log(x) / 2; everything else is wrapped in a logarithm node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeLog(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(Math.Log(constant.Value));
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => Math.Log(w));
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsExp(node)) {
    // log(exp(x)) = x
    return node.GetSubtree(0);
  }
  if (IsSquareRoot(node)) {
    // log(sqrt(x)) = log(x) / 2
    var innerLog = MakeLog(node.GetSubtree(0));
    return MakeFraction(innerLog, MakeConstant(2.0));
  }
  var logNode = logSymbol.CreateTreeNode();
  logNode.AddSubtree(node);
  return logNode;
}
/// <summary>
/// Builds sqr(node). Constants, factor weights and binary-factor weights are squared directly.
/// Further rewrites: sqr(sqrt(x)) = x, sqr(x * y) = sqr(x) * sqr(y), sqr(abs(x)) = sqr(x),
/// sqr(exp(x)) = exp(2x), sqr(sqr(x)) = x^4 and sqr(cube(x)) = x^6.
/// Everything else is wrapped in a square node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeSquare(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(constant.Value * constant.Value);
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => w * w);
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    var weight = binFactor.Weight * binFactor.Weight;
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, weight);
  }
  if (IsSquareRoot(node)) {
    return node.GetSubtree(0);
  }
  if (IsMultiplication(node)) {
    // sqr( x * y ) = sqr(x) * sqr(y)
    var product = mulSymbol.CreateTreeNode();
    foreach (var factor in node.Subtrees) {
      product.AddSubtree(MakeSquare(factor));
    }
    return product;
  }
  if (IsAbsolute(node)) {
    // sqr(abs(x)) = sqr(x)
    return MakeSquare(node.GetSubtree(0));
  }
  if (IsExp(node)) {
    // sqr(exp(x)) = exp(2x)
    var doubledExponent = MakeProduct(node.GetSubtree(0), MakeConstant(2.0));
    return MakeExp(doubledExponent);
  }
  if (IsSquare(node)) {
    return MakePower(node.GetSubtree(0), MakeConstant(4));
  }
  if (IsCube(node)) {
    return MakePower(node.GetSubtree(0), MakeConstant(6));
  }
  var squareNode = sqrSymbol.CreateTreeNode();
  squareNode.AddSubtree(node);
  return squareNode;
}
/// <summary>
/// Builds cube(node). Constants, factor weights and binary-factor weights are cubed directly.
/// Further rewrites: cube(cuberoot(x)) = x, cube(exp(x)) = exp(3x), cube(sqr(x)) = x^6 and
/// cube(cube(x)) = x^9. Everything else is wrapped in a cube node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeCube(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(constant.Value * constant.Value * constant.Value);
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => w * w * w);
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    var weight = binFactor.Weight * binFactor.Weight * binFactor.Weight;
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, weight);
  }
  if (IsCubeRoot(node)) {
    // NOTE: not really accurate because cuberoot(x) with negative x is evaluated to NaN and after this simplification we evaluate as x
    return node.GetSubtree(0);
  }
  if (IsExp(node)) {
    // cube(exp(x)) = exp(3x)
    var tripledExponent = MakeProduct(node.GetSubtree(0), MakeConstant(3));
    return MakeExp(tripledExponent);
  }
  if (IsSquare(node)) {
    return MakePower(node.GetSubtree(0), MakeConstant(6));
  }
  if (IsCube(node)) {
    return MakePower(node.GetSubtree(0), MakeConstant(9));
  }
  var cubeNode = cubeSymbol.CreateTreeNode();
  cubeNode.AddSubtree(node);
  return cubeNode;
}
/// <summary>
/// Builds abs(node). Constants, factor weights and binary-factor weights are folded with Math.Abs.
/// Square, exponential, square-root and cube-root nodes are returned unchanged, and the absolute
/// value of a product or quotient is distributed over its operands.
/// Everything else is wrapped in an absolute-value node.
/// </summary>
private static ISymbolicExpressionTreeNode MakeAbs(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constant = node as ConstantTreeNode;
    return MakeConstant(Math.Abs(constant.Value));
  }
  if (IsFactor(node)) {
    var factor = node as FactorVariableTreeNode;
    var weights = factor.Weights.Select(w => Math.Abs(w));
    return MakeFactor(factor.Symbol, factor.VariableName, weights);
  }
  if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    var weight = Math.Abs(binFactor.Weight);
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, weight);
  }
  if (IsSquare(node) || IsExp(node) || IsSquareRoot(node) || IsCubeRoot(node)) {
    // abs(sqr(x)) = sqr(x), abs(exp(x)) = exp(x) ...
    return node;
  }
  if (IsMultiplication(node)) {
    // abs(x * y) = abs(x) * abs(y)
    var product = mulSymbol.CreateTreeNode();
    foreach (var operand in node.Subtrees) {
      product.AddSubtree(MakeAbs(operand));
    }
    return product;
  }
  if (IsDivision(node)) {
    // abs(x / y) = abs(x) / abs(y)
    var quotient = divSymbol.CreateTreeNode();
    foreach (var operand in node.Subtrees) {
      quotient.AddSubtree(MakeAbs(operand));
    }
    return quotient;
  }
  var absNode = absSymbol.CreateTreeNode();
  absNode.AddSubtree(node);
  return absNode;
}
/// <summary>
/// Builds the analytic quotient aq(a, b) = a / sqrt(1 + b²). Only folding of the denominator is
/// performed: for a constant, factor or binary-factor <paramref name="b"/> the denominator is
/// computed up front and the result is an ordinary fraction; otherwise an analytic-quotient node is created.
/// </summary>
/// <param name="a">Numerator.</param>
/// <param name="b">Argument of the denominator.</param>
/// <returns>A node equivalent to aq(a, b).</returns>
private static ISymbolicExpressionTreeNode MakeAnalyticalQuotient(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(b)) {
    var c = b as ConstantTreeNode;
    return MakeFraction(a, MakeConstant(Math.Sqrt(1.0 + c.Value * c.Value)));
  } else if (IsFactor(b)) {
    var factNode = b as FactorVariableTreeNode;
    return MakeFraction(a, MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Sqrt(1.0 + w * w))));
  } else if (IsBinFactor(b)) {
    // A binary factor evaluates to w for the matching value and to 0 otherwise, so the denominator
    // must be sqrt(1 + w²) or sqrt(1 + 0) = 1 respectively: binfactor(sqrt(1 + w²) - 1) + 1
    // (see the same pattern for Exp and Cosine). A plain binfactor(sqrt(1 + w²)) would evaluate
    // to 0 for non-matching values and cause a division by zero.
    var binFactor = b as BinaryFactorVariableTreeNode;
    var denominator = MakeSum(
      MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Sqrt(1.0 + binFactor.Weight * binFactor.Weight) - 1.0),
      MakeConstant(1.0));
    return MakeFraction(a, denominator);
  } else {
    var aqNode = aqSymbol.CreateTreeNode();
    aqNode.AddSubtree(a);
    aqNode.AddSubtree(b);
    return aqNode;
  }
}
/// <summary>
/// Builds sqrt(node). Constants, factor weights and binary-factor weights are folded with
/// Math.Sqrt, sqrt(sqr(x)) is rewritten to abs(x), and everything else is wrapped in a square-root node.
/// </summary>
/// <param name="node">The (already simplified) argument of the square root.</param>
/// <returns>A node equivalent to sqrt(node).</returns>
private static ISymbolicExpressionTreeNode MakeSquareRoot(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Sqrt(constT.Value));
  } else if (IsFactor(node)) {
    var factNode = node as FactorVariableTreeNode;
    return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Sqrt(w)));
  } else if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Sqrt(binFactor.Weight));
  } else if (IsSquare(node)) {
    // sqrt(sqr(x)) = abs(x); returning x itself would flip the sign of the result for negative x
    return MakeAbs(node.GetSubtree(0));
  } else {
    var sqrtNode = sqrtSymbol.CreateTreeNode();
    sqrtNode.AddSubtree(node);
    return sqrtNode;
  }
}
/// <summary>
/// Builds cuberoot(node). Constants, factor weights and binary-factor weights are folded with
/// Math.Pow(x, 1/3), cuberoot(cube(x)) becomes x, and everything else is wrapped in a cube-root node.
/// </summary>
/// <param name="node">The (already simplified) argument of the cube root.</param>
/// <returns>A node equivalent to cuberoot(node).</returns>
private static ISymbolicExpressionTreeNode MakeCubeRoot(ISymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Pow(constT.Value, 1.0 / 3.0));
  } else if (IsFactor(node)) {
    var factNode = node as FactorVariableTreeNode;
    return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => Math.Pow(w, 1.0 / 3.0)));
  } else if (IsBinFactor(node)) {
    var binFactor = node as BinaryFactorVariableTreeNode;
    // plain cube root of the weight, consistent with the constant and factor cases
    // (an additional square root must not be applied here)
    return MakeBinFactor(binFactor.Symbol, binFactor.VariableName, binFactor.VariableValue, Math.Pow(binFactor.Weight, 1.0 / 3.0));
  } else if (IsCube(node)) {
    return node.GetSubtree(0);
  } else {
    var cubeRootNode = cubeRootSymbol.CreateTreeNode();
    cubeRootNode.AddSubtree(node);
    return cubeRootNode;
  }
}
/// <summary>
/// Creates a simplified tree for root(a, b), evaluated as a^(1 / round(b)).
/// </summary>
/// <param name="a">The radicand.</param>
/// <param name="b">The degree of the root; constant values and factor weights are rounded before use.</param>
/// <returns>A folded constant or factor where possible, otherwise a (possibly inverted) root node.</returns>
private static ISymbolicExpressionTreeNode MakeRoot(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    var radicand = ((ConstantTreeNode)a).Value;
    var degree = ((ConstantTreeNode)b).Value;
    return MakeConstant(Math.Pow(radicand, 1.0 / Math.Round(degree)));
  }
  if (IsFactor(a) && IsConstant(b)) {
    var factor = (FactorVariableTreeNode)a;
    var degreeNode = (ConstantTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName,
      factor.Weights.Select(w => Math.Pow(w, 1.0 / Math.Round(degreeNode.Value))));
  }
  if (IsBinFactor(a) && IsConstant(b)) {
    var binary = (BinaryFactorVariableTreeNode)a;
    var degreeNode = (ConstantTreeNode)b;
    var weight = Math.Pow(binary.Weight, 1.0 / Math.Round(degreeNode.Value));
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, weight);
  }
  if (IsConstant(a) && IsFactor(b)) {
    // the weights of the factor act as (rounded) degrees
    var radicandNode = (ConstantTreeNode)a;
    var factor = (FactorVariableTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName,
      factor.Weights.Select(w => Math.Pow(radicandNode.Value, 1.0 / Math.Round(w))));
  }
  if (IsConstant(a) && IsBinFactor(b)) {
    var radicandNode = (ConstantTreeNode)a;
    var binary = (BinaryFactorVariableTreeNode)b;
    var weight = Math.Pow(radicandNode.Value, 1.0 / Math.Round(binary.Weight));
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, weight);
  }
  if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
    // combine the weights of both factors element-wise
    var radicands = (FactorVariableTreeNode)a;
    var degrees = (FactorVariableTreeNode)b;
    return MakeFactor(radicands.Symbol, radicands.VariableName,
      radicands.Weights.Zip(degrees.Weights, (u, v) => Math.Pow(u, 1.0 / Math.Round(v))));
  }
  if (IsConstant(b)) {
    var degree = Math.Round(((ConstantTreeNode)b).Value);
    // root(a, 1) => a
    if (degree == 1.0) return a;
    // root(a, 0) is not defined
    if (degree == 0.0) return MakeConstant(double.NaN);
    // root(a, -1) => a^(-1/1) => 1/a
    if (degree == -1.0) return MakeFraction(MakeConstant(1.0), a);

    var root = rootSymbol.CreateTreeNode();
    root.AddSubtree(a);
    if (degree < 0) {
      // root(a, -b) => a^(-1/b) => (1/a)^(1/b) => root(1, b) / root(a, b) => 1 / root(a, b)
      root.AddSubtree(MakeConstant(-1.0 * degree));
      return MakeFraction(MakeConstant(1.0), root);
    }
    // keep the root but store the rounded degree
    root.AddSubtree(MakeConstant(degree));
    return root;
  }
  // nothing can be folded => plain root node
  var plainRoot = rootSymbol.CreateTreeNode();
  plainRoot.AddSubtree(a);
  plainRoot.AddSubtree(b);
  return plainRoot;
}
/// <summary>
/// Creates a simplified tree for power(a, b), evaluated as a^round(b).
/// </summary>
/// <param name="a">The base.</param>
/// <param name="b">The exponent; constant values and factor weights are rounded before use.</param>
/// <returns>A folded constant or factor where possible, otherwise a (possibly inverted) power node.</returns>
private static ISymbolicExpressionTreeNode MakePower(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    var baseValue = ((ConstantTreeNode)a).Value;
    var exponentValue = ((ConstantTreeNode)b).Value;
    return MakeConstant(Math.Pow(baseValue, Math.Round(exponentValue)));
  }
  if (IsFactor(a) && IsConstant(b)) {
    var factor = (FactorVariableTreeNode)a;
    var exponentNode = (ConstantTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName,
      factor.Weights.Select(w => Math.Pow(w, Math.Round(exponentNode.Value))));
  }
  if (IsBinFactor(a) && IsConstant(b)) {
    var binary = (BinaryFactorVariableTreeNode)a;
    var exponentNode = (ConstantTreeNode)b;
    var weight = Math.Pow(binary.Weight, Math.Round(exponentNode.Value));
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, weight);
  }
  if (IsConstant(a) && IsFactor(b)) {
    // the weights of the factor act as (rounded) exponents
    var baseNode = (ConstantTreeNode)a;
    var factor = (FactorVariableTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName,
      factor.Weights.Select(w => Math.Pow(baseNode.Value, Math.Round(w))));
  }
  if (IsConstant(a) && IsBinFactor(b)) {
    var baseNode = (ConstantTreeNode)a;
    var binary = (BinaryFactorVariableTreeNode)b;
    var weight = Math.Pow(baseNode.Value, Math.Round(binary.Weight));
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, weight);
  }
  if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
    // combine the weights of both factors element-wise
    var bases = (FactorVariableTreeNode)a;
    var exponents = (FactorVariableTreeNode)b;
    return MakeFactor(bases.Symbol, bases.VariableName,
      bases.Weights.Zip(exponents.Weights, (u, v) => Math.Pow(u, Math.Round(v))));
  }
  if (IsConstant(b)) {
    double rounded = Math.Round(((ConstantTreeNode)b).Value);
    // a^0 => 1
    if (rounded == 0.0) return MakeConstant(1.0);
    // a^1 => a
    if (rounded == 1.0) return a;
    // a^-1 => 1/a
    if (rounded == -1.0) return MakeFraction(MakeConstant(1.0), a);

    var power = powSymbol.CreateTreeNode();
    power.AddSubtree(a);
    if (rounded < 0) {
      // a^-b => (1/a)^b => 1/(a^b)
      power.AddSubtree(MakeConstant(-1.0 * rounded));
      return MakeFraction(MakeConstant(1.0), power);
    }
    // keep the power but store the rounded exponent
    power.AddSubtree(MakeConstant(rounded));
    return power;
  }
  // nothing can be folded => plain power node
  var plainPower = powSymbol.CreateTreeNode();
  plainPower.AddSubtree(a);
  plainPower.AddSubtree(b);
  return plainPower;
}
// MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree
/// <summary>
/// Creates a simplified tree for a / b.
/// </summary>
/// <param name="a">The simplified numerator. May be modified in place (variable weights are rescaled).</param>
/// <param name="b">The simplified denominator.</param>
/// <returns>The simplified quotient.</returns>
/// <exception cref="ArgumentException">
/// Thrown when a factor is divided by a binary factor of the same variable whose value is not
/// among the values known to the factor symbol.
/// </exception>
private static ISymbolicExpressionTreeNode MakeFraction(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
  } else if ((IsConstant(a) && ((ConstantTreeNode)a).Value != 1.0)) {
    // a / x => (a * 1/a) / (x * 1/a) => 1 / (x * 1/a)
    return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
  } else if (IsVariableBase(a) && IsConstant(b)) {
    // merge constant values into variable weights
    var constB = ((ConstantTreeNode)b).Value;
    ((VariableTreeNodeBase)a).Weight /= constB;
    return a;
  } else if (IsFactor(a) && IsConstant(b)) {
    // divide every weight of the factor by the constant
    var factNode = a as FactorVariableTreeNode;
    var constNode = b as ConstantTreeNode;
    return MakeFactor(factNode.Symbol, factNode.VariableName, factNode.Weights.Select(w => w / constNode.Value));
  } else if (IsBinFactor(a) && IsConstant(b)) {
    var factNode = a as BinaryFactorVariableTreeNode;
    var constNode = b as ConstantTreeNode;
    return MakeBinFactor(factNode.Symbol, factNode.VariableName, factNode.VariableValue, factNode.Weight / constNode.Value);
  } else if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
    // element-wise division of the weights
    var node0 = a as FactorVariableTreeNode;
    var node1 = b as FactorVariableTreeNode;
    return MakeFactor(node0.Symbol, node0.VariableName, node0.Weights.Zip(node1.Weights, (u, v) => u / v));
  } else if (IsFactor(a) && IsBinFactor(b) && ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
    var node0 = a as FactorVariableTreeNode;
    var node1 = b as BinaryFactorVariableTreeNode;
    var varValues = node0.Symbol.GetVariableValues(node0.VariableName).ToArray();
    var wi = Array.IndexOf(varValues, node1.VariableValue);
    if (wi < 0) throw new ArgumentException("The value of the binary factor is not a known value of the factor variable.");
    var newWeights = new double[varValues.Length];
    node0.Weights.CopyTo(newWeights, 0);
    // only the weight of the matching value is divided by the binary factor's weight;
    // all other weights are deliberately divided by zero
    // (presumably because the binary factor evaluates to zero for every other value)
    for (int i = 0; i < newWeights.Length; i++)
      if (wi == i) newWeights[i] /= node1.Weight;
      else newWeights[i] /= 0.0;
    return MakeFactor(node0.Symbol, node0.VariableName, newWeights);
  } else if (IsFactor(a)) {
    // factor / x => 1 / (x * 1/factor)
    return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
  } else if (IsVariableBase(a) && IsVariableBase(b) && AreSameTypeAndVariable(a, b) && !IsBinFactor(b)) {
    // cancel variables (not allowed for bin factors because of division by zero)
    // Cast to the common base type: the guard above also admits e.g. two lagged variables
    // with the same name and lag, for which an 'as VariableTreeNode' cast would yield null.
    var aVar = (VariableTreeNodeBase)a;
    var bVar = (VariableTreeNodeBase)b;
    return MakeConstant(aVar.Weight / bVar.Weight);
  } else if (IsAddition(a) && IsConstant(b)) {
    // (a1 + a2 + ...) / c => a1/c + a2/c + ...
    return a.Subtrees
      .Select(x => GetSimplifiedTree(x))
      .Select(x => MakeFraction(x, GetSimplifiedTree(b)))
      .Aggregate((c, d) => MakeSum(c, d));
  } else if (IsMultiplication(a) && IsConstant(b)) {
    // (a1 * a2 * ...) / c => (a1 * a2 * ...) * 1/c
    return MakeProduct(a, Invert(b));
  } else if (IsDivision(a) && IsConstant(b)) {
    // (a1 / a2) / c => (a1 / (a2 * c))
    return MakeFraction(a.GetSubtree(0), MakeProduct(a.GetSubtree(1), b));
  } else if (IsDivision(a) && IsDivision(b)) {
    // (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
    return MakeFraction(MakeProduct(a.GetSubtree(0), b.GetSubtree(1)), MakeProduct(a.GetSubtree(1), b.GetSubtree(0)));
  } else if (IsDivision(a)) {
    // (a1 / a2) / b => (a1 / (a2 * b))
    return MakeFraction(a.GetSubtree(0), MakeProduct(a.GetSubtree(1), b));
  } else if (IsDivision(b)) {
    // a / (b1 / b2) => (a * b2) / b1
    return MakeFraction(MakeProduct(a, b.GetSubtree(1)), b.GetSubtree(0));
  } else if (IsAnalyticalQuotient(a)) {
    // aq(a1, a2) / b => aq(a1, a2 * b)
    return MakeAnalyticalQuotient(a.GetSubtree(0), MakeProduct(a.GetSubtree(1), b));
  } else {
    // no simplification applicable => plain division node
    var div = divSymbol.CreateTreeNode();
    div.AddSubtree(a);
    div.AddSubtree(b);
    return div;
  }
}
/// <summary>
/// Creates a simplified tree for a + b.
/// </summary>
/// <param name="a">The first summand. May be modified in place (constant folding).</param>
/// <param name="b">The second summand.</param>
/// <returns>The simplified sum; constants are kept as the last subtree of additions.</returns>
/// <exception cref="ArgumentException">
/// Thrown when a binary factor is added to a factor of the same variable but its value is not
/// among the values known to the factor symbol.
/// </exception>
private static ISymbolicExpressionTreeNode MakeSum(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants: accumulate into a and reuse it as the result
    var accumulator = (ConstantTreeNode)a;
    accumulator.Value += ((ConstantTreeNode)b).Value;
    return a;
  }
  if (IsConstant(a)) {
    // c + x => x + c
    // b is not constant => make sure constant is on the right
    return MakeSum(b, a);
  }
  if (IsConstant(b) && ((ConstantTreeNode)b).Value == 0.0) {
    // x + 0 => x
    return a;
  }
  if (IsFactor(a) && IsConstant(b)) {
    // add the constant to every weight of the factor
    var factor = (FactorVariableTreeNode)a;
    var offset = (ConstantTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select((w) => w + offset.Value));
  }
  if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
    // element-wise sum of the weights
    var left = (FactorVariableTreeNode)a;
    var right = (FactorVariableTreeNode)b;
    return MakeFactor(left.Symbol, left.VariableName, left.Weights.Zip(right.Weights, (u, v) => u + v));
  }
  if (IsBinFactor(a) && IsFactor(b)) {
    // make sure the factor is on the left
    return MakeSum(b, a);
  }
  if (IsFactor(a) && IsBinFactor(b) &&
      ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
    // add the weight of the binary factor to the weight of the matching value
    var factor = (FactorVariableTreeNode)a;
    var binary = (BinaryFactorVariableTreeNode)b;
    var values = factor.Symbol.GetVariableValues(factor.VariableName).ToArray();
    var index = Array.IndexOf(values, binary.VariableValue);
    if (index < 0) throw new ArgumentException();
    var summedWeights = new double[values.Length];
    factor.Weights.CopyTo(summedWeights, 0);
    summedWeights[index] += binary.Weight;
    return MakeFactor(factor.Symbol, factor.VariableName, summedWeights);
  }
  if (IsAddition(a) && IsAddition(b)) {
    // merge additions
    var merged = addSymbol.CreateTreeNode();
    // add all sub trees except for the last
    int headCountA = a.Subtrees.Count() - 1;
    for (int i = 0; i < headCountA; i++) merged.AddSubtree(a.GetSubtree(i));
    int headCountB = b.Subtrees.Count() - 1;
    for (int i = 0; i < headCountB; i++) merged.AddSubtree(b.GetSubtree(i));
    // the last sub trees may be constants which have to be folded / kept at the end
    var tailA = a.Subtrees.Last();
    var tailB = b.Subtrees.Last();
    if (IsConstant(tailA) && IsConstant(tailB)) {
      merged.AddSubtree(MakeSum(tailA, tailB));
    } else if (IsConstant(tailA)) {
      merged.AddSubtree(tailB);
      merged.AddSubtree(tailA);
    } else {
      merged.AddSubtree(tailA);
      merged.AddSubtree(tailB);
    }
    MergeVariablesInSum(merged);
    return merged.Subtrees.Count() == 1 ? merged.GetSubtree(0) : merged;
  }
  if (IsAddition(b)) {
    // make sure the addition is on the left
    return MakeSum(b, a);
  }
  if (IsAddition(a) && IsConstant(b)) {
    // a is an addition and b is a constant => append b to a and make sure the constants are merged
    var extended = addSymbol.CreateTreeNode();
    // add all sub trees except for the last
    int headCount = a.Subtrees.Count() - 1;
    for (int i = 0; i < headCount; i++) extended.AddSubtree(a.GetSubtree(i));
    var tail = a.Subtrees.Last();
    if (IsConstant(tail)) {
      extended.AddSubtree(MakeSum(tail, b));
    } else {
      extended.AddSubtree(tail);
      extended.AddSubtree(b);
    }
    return extended;
  }
  // remaining cases: build a new addition (b first if a is already an addition) and merge variables
  var sum = addSymbol.CreateTreeNode();
  if (IsAddition(a)) {
    // a is already an addition => append b
    sum.AddSubtree(b);
    foreach (var term in a.Subtrees) {
      sum.AddSubtree(term);
    }
  } else {
    sum.AddSubtree(a);
    sum.AddSubtree(b);
  }
  MergeVariablesInSum(sum);
  return sum.Subtrees.Count() == 1 ? sum.GetSubtree(0) : sum;
}
// makes sure variable symbols in sums are combined
/// <summary>
/// Rebuilds the subtrees of <paramref name="sum"/> in place: terminal variable nodes that share a
/// group id (see GroupId) are merged into one node, all constants are folded into a single
/// constant that is appended last, and all remaining subtrees are kept unchanged.
/// </summary>
/// <param name="sum">The addition node whose subtrees are merged. It is modified in place.</param>
private static void MergeVariablesInSum(ISymbolicExpressionTreeNode sum) {
  // detach all subtrees; they are re-added below in merged form
  var subtrees = new List<ISymbolicExpressionTreeNode>(sum.Subtrees);
  while (sum.Subtrees.Any()) sum.RemoveSubtree(0);
  // terminal variable nodes, grouped by kind and variable
  var groupedVarNodes = from node in subtrees.OfType<IVariableTreeNode>()
                        where node.SubtreeCount == 0
                        group node by GroupId(node) into g
                        select g;
  // sum of all constants (0.0 if there are none)
  var constant = (from node in subtrees.OfType<ConstantTreeNode>()
                  select node.Value).DefaultIfEmpty(0.0).Sum();
  // everything that is neither a terminal variable node nor a constant
  var unchangedSubtrees = subtrees.Where(t => t.SubtreeCount > 0 || !(t is IVariableTreeNode) && !(t is ConstantTreeNode));
  foreach (var variableNodeGroup in groupedVarNodes) {
    var firstNode = variableNodeGroup.First();
    if (firstNode is VariableTreeNodeBase) {
      // the first node of the group represents the group and receives the summed weight
      var representative = firstNode as VariableTreeNodeBase;
      var weightSum = variableNodeGroup.Cast<VariableTreeNodeBase>().Select(t => t.Weight).Sum();
      representative.Weight = weightSum;
      sum.AddSubtree(representative);
    } else if (firstNode is FactorVariableTreeNode) {
      // accumulate the weights of the other factors element-wise into the first one
      var representative = firstNode as FactorVariableTreeNode;
      foreach (var node in variableNodeGroup.Skip(1).Cast<FactorVariableTreeNode>()) {
        for (int j = 0; j < representative.Weights.Length; j++) {
          representative.Weights[j] += node.Weights[j];
        }
      }
      sum.AddSubtree(representative);
    }
    // NOTE(review): a group of any other IVariableTreeNode type would be dropped here - confirm that none exists
  }
  foreach (var unchangedSubtree in unchangedSubtrees)
    sum.AddSubtree(unchangedSubtree);
  // the folded constant goes last and is omitted if it is zero
  if (constant != 0.0) {
    sum.AddSubtree(MakeConstant(constant));
  }
}
// nodes referencing variables can be grouped if they have the same group id,
// i.e. the same node kind and variable name (plus the variable value for binary factors
// and the lag for lagged variables)
/// <summary>
/// Returns a key that identifies which variable nodes may be merged with each other.
/// </summary>
/// <param name="node">The variable node to classify.</param>
/// <returns>A string made up of the node kind, the variable name and, where applicable, the value or lag.</returns>
/// <exception cref="NotSupportedException">Thrown for variable node types other than the four handled here.</exception>
private static string GroupId(IVariableTreeNode node) {
  var plainVariable = node as VariableTreeNode;
  if (plainVariable != null) {
    return "var " + plainVariable.VariableName;
  }

  var binaryFactor = node as BinaryFactorVariableTreeNode;
  if (binaryFactor != null) {
    return "binfactor " + binaryFactor.VariableName + " " + binaryFactor.VariableValue;
  }

  var factor = node as FactorVariableTreeNode;
  if (factor != null) {
    return "factor " + factor.VariableName;
  }

  var laggedVariable = node as LaggedVariableTreeNode;
  if (laggedVariable != null) {
    return "lagged " + laggedVariable.VariableName + " " + laggedVariable.Lag;
  }

  throw new NotSupportedException();
}
/// <summary>
/// Creates a simplified tree for a * b.
/// </summary>
/// <param name="a">The first factor. May be modified in place (variable weights, appended subtrees).</param>
/// <param name="b">The second factor.</param>
/// <returns>The simplified product.</returns>
/// <exception cref="ArgumentException">
/// Thrown when a factor is multiplied with a binary factor of the same variable whose value is not
/// among the values known to the factor symbol.
/// </exception>
private static ISymbolicExpressionTreeNode MakeProduct(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  if (IsConstant(a) && IsConstant(b)) {
    // fold constants
    var lhs = ((ConstantTreeNode)a).Value;
    var rhs = ((ConstantTreeNode)b).Value;
    return MakeConstant(lhs * rhs);
  }
  if (IsConstant(a)) {
    // a * $ => $ * a
    return MakeProduct(b, a);
  }
  if (IsFactor(a) && IsFactor(b) && AreSameTypeAndVariable(a, b)) {
    // element-wise product of the weights
    var left = (FactorVariableTreeNode)a;
    var right = (FactorVariableTreeNode)b;
    return MakeFactor(left.Symbol, left.VariableName, left.Weights.Zip(right.Weights, (u, v) => u * v));
  }
  if (IsBinFactor(a) && IsBinFactor(b) && AreSameTypeAndVariable(a, b)) {
    var left = (BinaryFactorVariableTreeNode)a;
    var right = (BinaryFactorVariableTreeNode)b;
    return MakeBinFactor(left.Symbol, left.VariableName, left.VariableValue, left.Weight * right.Weight);
  }
  if (IsFactor(a) && IsConstant(b)) {
    // scale every weight of the factor
    var factor = (FactorVariableTreeNode)a;
    var scale = (ConstantTreeNode)b;
    return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select(w => w * scale.Value));
  }
  if (IsBinFactor(a) && IsConstant(b)) {
    var binary = (BinaryFactorVariableTreeNode)a;
    var scale = (ConstantTreeNode)b;
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, binary.Weight * scale.Value);
  }
  if (IsBinFactor(a) && IsFactor(b)) {
    // make sure the factor is on the left
    return MakeProduct(b, a);
  }
  if (IsFactor(a) && IsBinFactor(b) &&
      ((IVariableTreeNode)a).VariableName == ((IVariableTreeNode)b).VariableName) {
    // the result is a binary factor weighted with the product of the matching weights
    var factor = (FactorVariableTreeNode)a;
    var binary = (BinaryFactorVariableTreeNode)b;
    var values = factor.Symbol.GetVariableValues(factor.VariableName).ToArray();
    var index = Array.IndexOf(values, binary.VariableValue);
    if (index < 0) throw new ArgumentException();
    return MakeBinFactor(binary.Symbol, binary.VariableName, binary.VariableValue, binary.Weight * factor.Weights[index]);
  }
  if (IsConstant(b)) {
    var multiplier = (ConstantTreeNode)b;
    // $ * 1.0 => $
    if (multiplier.Value == 1.0) return a;
    // $ * 0.0 => 0
    if (multiplier.Value == 0.0) return MakeConstant(0);
    if (IsVariableBase(a)) {
      // multiply constants into variables weights
      ((VariableTreeNodeBase)a).Weight *= multiplier.Value;
      return a;
    }
  }
  if (IsAddition(a) && (IsConstant(b) || IsFactor(b) || IsBinFactor(b))) {
    // multiply constants into additions
    var scaledTerms = a.Subtrees.Select(x => MakeProduct(GetSimplifiedTree(x), GetSimplifiedTree(b)));
    return scaledTerms.Aggregate((c, d) => MakeSum(c, d));
  }
  if (IsDivision(a) && IsDivision(b)) {
    // (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
    var numerator = MakeProduct(a.GetSubtree(0), b.GetSubtree(0));
    var denominator = MakeProduct(a.GetSubtree(1), b.GetSubtree(1));
    return MakeFraction(numerator, denominator);
  }
  if (IsDivision(a)) {
    // (a1 / a2) * b => (a1 * b) / a2
    return MakeFraction(MakeProduct(a.GetSubtree(0), b), a.GetSubtree(1));
  }
  if (IsDivision(b)) {
    // a * (b1 / b2) => (b1 * a) / b2
    return MakeFraction(MakeProduct(b.GetSubtree(0), a), b.GetSubtree(1));
  }
  if (IsMultiplication(a) && IsMultiplication(b)) {
    // merge multiplications (make sure constants are merged)
    var merged = mulSymbol.CreateTreeNode();
    int countA = a.Subtrees.Count();
    for (int i = 0; i < countA; i++) merged.AddSubtree(a.GetSubtree(i));
    int countB = b.Subtrees.Count();
    for (int i = 0; i < countB; i++) merged.AddSubtree(b.GetSubtree(i));
    MergeVariablesAndConstantsInProduct(merged);
    return merged;
  }
  if (IsMultiplication(b)) {
    // make sure the multiplication is on the left
    return MakeProduct(b, a);
  }
  if (IsMultiplication(a)) {
    // a is already an multiplication => append b
    a.AddSubtree(GetSimplifiedTree(b));
    MergeVariablesAndConstantsInProduct(a);
    return a;
  }
  if (IsAbsolute(a) && IsAbsolute(b)) {
    // |x| * |y| => |x * y|
    return MakeAbs(MakeProduct(a.GetSubtree(0), b.GetSubtree(0)));
  }
  if (IsAbsolute(a) && IsConstant(b)) {
    // |x| * c => |x * |c||, multiplied by -1 unless c is positive
    var constant = (ConstantTreeNode)b;
    var magnitude = Math.Abs(constant.Value);
    if (constant.Value > 0) {
      return MakeAbs(MakeProduct(a.GetSubtree(0), MakeConstant(magnitude)));
    }
    var negated = mulSymbol.CreateTreeNode();
    negated.AddSubtree(MakeAbs(MakeProduct(a.GetSubtree(0), MakeConstant(magnitude))));
    negated.AddSubtree(MakeConstant(-1.0));
    return negated;
  }
  if (IsAnalyticalQuotient(a)) {
    // aq(a1, a2) * b => aq(a1 * b, a2)
    return MakeAnalyticalQuotient(MakeProduct(a.GetSubtree(0), b), a.GetSubtree(1));
  }
  // no special case applies => new multiplication with merged variables and constants
  var product = mulSymbol.CreateTreeNode();
  product.AddSubtree(a);
  product.AddSubtree(b);
  MergeVariablesAndConstantsInProduct(product);
  return product;
}
#endregion
#region helper functions
/// <summary>
/// Checks whether <paramref name="node"/> or any node below it uses a VariableCondition symbol.
/// </summary>
/// <param name="node">The root of the tree to search (depth-first).</param>
/// <returns>True as soon as a variable condition is found, false otherwise.</returns>
private static bool ContainsVariableCondition(ISymbolicExpressionTreeNode node) {
  return node.Symbol is VariableCondition
    || node.Subtrees.Any(child => ContainsVariableCondition(child));
}
/// <summary>
/// Pushes a time lag down into a tree: the lag of every lagged node is increased by
/// <paramref name="lag"/> and every plain variable is replaced by a lagged variable.
/// </summary>
/// <param name="node">The root of the tree to transform. Nodes are modified in place where possible.</param>
/// <param name="lag">The lag to add.</param>
/// <returns>The transformed node (a new node if <paramref name="node"/> was a plain variable).</returns>
/// <exception cref="NotSupportedException">Thrown when the tree contains a variable condition node.</exception>
private static ISymbolicExpressionTreeNode AddLagToDynamicNodes(ISymbolicExpressionTreeNode node, int lag) {
  var laggedTreeNode = node as ILaggedTreeNode;
  var variableNode = node as VariableTreeNode;
  var variableConditionNode = node as VariableConditionTreeNode;
  if (laggedTreeNode != null)
    laggedTreeNode.Lag += lag;
  else if (variableNode != null) {
    // replace the variable by a lagged variable that references the same input
    var laggedVariableNode = (LaggedVariableTreeNode)laggedVariableSymbol.CreateTreeNode();
    laggedVariableNode.Lag = lag;
    laggedVariableNode.VariableName = variableNode.VariableName;
    // keep the weight of the replaced variable; otherwise the lagged node would carry
    // whatever weight a freshly created node has and the expression would change its value
    laggedVariableNode.Weight = variableNode.Weight;
    return laggedVariableNode;
  } else if (variableConditionNode != null) {
    throw new NotSupportedException("Removal of time lags around variable condition symbols is not allowed.");
  }
  // detach all subtrees and re-attach their transformed versions in the same order
  var subtrees = new List<ISymbolicExpressionTreeNode>(node.Subtrees);
  while (node.SubtreeCount > 0) node.RemoveSubtree(0);
  foreach (var subtree in subtrees) {
    node.AddSubtree(AddLagToDynamicNodes(subtree, lag));
  }
  return node;
}
/// <summary>
/// Checks whether two variable nodes belong to the same group, i.e. have the same group id.
/// </summary>
/// <param name="a">The first node; must be an IVariableTreeNode.</param>
/// <param name="b">The second node; must be an IVariableTreeNode.</param>
/// <returns>True if both group ids are equal.</returns>
private static bool AreSameTypeAndVariable(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b) {
  var groupOfA = GroupId((IVariableTreeNode)a);
  var groupOfB = GroupId((IVariableTreeNode)b);
  return groupOfA == groupOfB;
}
// helper to combine the constant factors in products and to combine variables (powers of 2, 3...)
/// <summary>
/// Rebuilds the subtrees of <paramref name="prod"/> in place: the weights of all variable nodes and
/// all constants are folded into a single constant, terminal variable nodes that share a
/// group id (see GroupId) are combined, and all remaining subtrees are kept unchanged.
/// </summary>
/// <param name="prod">The multiplication node whose subtrees are merged. It is modified in place.</param>
private static void MergeVariablesAndConstantsInProduct(ISymbolicExpressionTreeNode prod) {
  // detach all subtrees; they are re-added below in merged form
  var subtrees = new List<ISymbolicExpressionTreeNode>(prod.Subtrees);
  while (prod.Subtrees.Any()) prod.RemoveSubtree(0);
  // terminal variable nodes, grouped by kind and variable; smallest groups first
  var groupedVarNodes = from node in subtrees.OfType<IVariableTreeNode>()
                        where node.SubtreeCount == 0
                        group node by GroupId(node) into g
                        orderby g.Count()
                        select g;
  // product of all variable weights and all constants (1.0 if there are none);
  // evaluated eagerly, i.e. before the weights are reset to 1.0 below
  var constantProduct = (from node in subtrees.OfType<VariableTreeNodeBase>()
                         select node.Weight)
    .Concat(from node in subtrees.OfType<ConstantTreeNode>()
            select node.Value)
    .DefaultIfEmpty(1.0)
    .Aggregate((c1, c2) => c1 * c2);
  // everything that is neither a terminal variable node nor a constant
  var unchangedSubtrees = from tree in subtrees
                          where tree.SubtreeCount > 0 || !(tree is IVariableTreeNode) && !(tree is ConstantTreeNode)
                          select tree;
  foreach (var variableNodeGroup in groupedVarNodes) {
    var firstNode = variableNodeGroup.First();
    if (firstNode is VariableTreeNodeBase) {
      // the weight has already been collected in constantProduct
      var representative = (VariableTreeNodeBase)firstNode;
      representative.Weight = 1.0;
      if (variableNodeGroup.Count() > 1) {
        // x * x * ... => nested product of clones of the weight-free representative
        var poly = mulSymbol.CreateTreeNode();
        for (int p = 0; p < variableNodeGroup.Count(); p++) {
          poly.AddSubtree((ISymbolicExpressionTreeNode)representative.Clone());
        }
        prod.AddSubtree(poly);
      } else {
        prod.AddSubtree(representative);
      }
    } else if (firstNode is FactorVariableTreeNode) {
      // multiply the weights of the other factors element-wise into the first one
      var representative = (FactorVariableTreeNode)firstNode;
      foreach (var node in variableNodeGroup.Skip(1).Cast<FactorVariableTreeNode>()) {
        for (int j = 0; j < representative.Weights.Length; j++) {
          representative.Weights[j] *= node.Weights[j];
        }
      }
      // absorb the constant into the weights of the factor
      for (int j = 0; j < representative.Weights.Length; j++) {
        representative.Weights[j] *= constantProduct;
      }
      constantProduct = 1.0;
      // if the product already contains a factor it is not necessary to multiply a constant below
      prod.AddSubtree(representative);
    }
  }
  foreach (var unchangedSubtree in unchangedSubtrees)
    prod.AddSubtree(unchangedSubtree);
  // the folded constant goes last and is omitted if it is one
  if (constantProduct != 1.0) {
    prod.AddSubtree(MakeConstant(constantProduct));
  }
}
/// <summary>
/// x => x * -1
/// Is only used in cases where it is not necessary to create new tree nodes. Manipulates x directly.
/// </summary>
/// <param name="x">The tree to negate. It is modified in place wherever possible.</param>
/// <returns>-x</returns>
private static ISymbolicExpressionTreeNode Negate(ISymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    ((ConstantTreeNode)x).Value *= -1;
  } else if (IsVariableBase(x)) {
    var variableTree = (VariableTreeNodeBase)x;
    variableTree.Weight *= -1.0;
  } else if (IsFactor(x)) {
    // flip the sign of every weight
    var factorNode = (FactorVariableTreeNode)x;
    for (int i = 0; i < factorNode.Weights.Length; i++) factorNode.Weights[i] *= -1;
  } else if (IsBinFactor(x)) {
    var factorNode = (BinaryFactorVariableTreeNode)x;
    factorNode.Weight *= -1;
  } else if (IsAddition(x)) {
    // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
    var subtrees = new List<ISymbolicExpressionTreeNode>(x.Subtrees);
    while (x.Subtrees.Any()) x.RemoveSubtree(0);
    foreach (var subtree in subtrees) {
      x.AddSubtree(Negate(subtree));
    }
  } else if (IsMultiplication(x) || IsDivision(x)) {
    // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
    var lastSubTree = x.Subtrees.Last();
    x.RemoveSubtree(x.SubtreeCount - 1);
    x.AddSubtree(Negate(lastSubTree)); // last is maybe a constant, prefer to negate the constant
  } else {
    // any other function
    return MakeProduct(x, MakeConstant(-1));
  }
  return x;
}
/// <summary>
/// x => 1/x
/// Must create new tree nodes
/// </summary>
/// <param name="x">The tree to invert.</param>
/// <returns>A tree representing 1/x.</returns>
private static ISymbolicExpressionTreeNode Invert(ISymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    // 1 / c is folded into a new constant
    var constant = (ConstantTreeNode)x;
    return MakeConstant(1.0 / constant.Value);
  }
  if (IsFactor(x)) {
    // invert every weight of the factor
    var factor = (FactorVariableTreeNode)x;
    return MakeFactor(factor.Symbol, factor.VariableName, factor.Weights.Select(w => 1.0 / w));
  }
  if (IsDivision(x)) {
    // 1 / (x0 / x1) => x1 / x0
    return MakeFraction(x.GetSubtree(1), x.GetSubtree(0));
  }
  // any other function
  return MakeFraction(MakeConstant(1), x);
}
/// <summary>
/// Creates a new constant node with the given value.
/// </summary>
/// <param name="value">The value of the constant.</param>
/// <returns>The new constant node.</returns>
private static ISymbolicExpressionTreeNode MakeConstant(double value) {
  var node = (ConstantTreeNode)constSymbol.CreateTreeNode();
  node.Value = value;
  return node;
}
/// <summary>
/// Creates a new factor variable node for the given variable.
/// </summary>
/// <param name="sy">The factor symbol used to create the node.</param>
/// <param name="variableName">The name of the variable the node references.</param>
/// <param name="weights">The weights of the node; the sequence is enumerated once and copied into an array.</param>
/// <returns>The new factor variable node.</returns>
private static ISymbolicExpressionTreeNode MakeFactor(FactorVariable sy, string variableName, IEnumerable<double> weights) {
  var tree = (FactorVariableTreeNode)sy.CreateTreeNode();
  tree.VariableName = variableName;
  tree.Weights = weights.ToArray();
  return tree;
}
/// <summary>
/// Creates a new binary factor variable node.
/// </summary>
/// <param name="sy">The binary factor symbol used to create the node.</param>
/// <param name="variableName">The name of the variable the node references.</param>
/// <param name="variableValue">The variable value the node is bound to.</param>
/// <param name="weight">The weight of the node.</param>
/// <returns>The new binary factor variable node.</returns>
private static ISymbolicExpressionTreeNode MakeBinFactor(BinaryFactorVariable sy, string variableName, string variableValue, double weight) {
  var node = (BinaryFactorVariableTreeNode)sy.CreateTreeNode();
  node.VariableName = variableName;
  node.VariableValue = variableValue;
  node.Weight = weight;
  return node;
}
#endregion
}
}