#region License Information
/* HeuristicLab
* Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Simplistic simplifier for arithmetic expressions
/// </summary>
public class SymbolicSimplifier {
// Symbol instances held by the simplifier. The Make* helpers call CreateTreeNode() on them
// whenever a new node of the corresponding kind has to be built.
// arithmetic
private Addition addSymbol = new Addition();
private Subtraction subSymbol = new Subtraction();
private Multiplication mulSymbol = new Multiplication();
private Division divSymbol = new Division();
// terminals
private Constant constSymbol = new Constant();
private Variable varSymbol = new Variable();
// exponential
private Logarithm logSymbol = new Logarithm();
private Exponential expSymbol = new Exponential();
private Root rootSymbol = new Root();
private Power powSymbol = new Power();
// trigonometric
private Sine sineSymbol = new Sine();
private Cosine cosineSymbol = new Cosine();
private Tangent tanSymbol = new Tangent();
// boolean
private IfThenElse ifThenElseSymbol = new IfThenElse();
private And andSymbol = new And();
private Or orSymbol = new Or();
// NOTE(review): notSymbol is not referenced by any code visible in this file (MakeNot builds a product instead)
private Not notSymbol = new Not();
// comparison
private GreaterThan gtSymbol = new GreaterThan();
private LessThan ltSymbol = new LessThan();
/// <summary>
/// Creates a simplified version of the given tree. The work is done on a clone of the
/// original root: function invocations are macro-expanded first and the expanded
/// expression is then simplified.
/// </summary>
/// <param name="originalTree">The tree to simplify; its first root subtree is taken as the expression to expand.</param>
/// <returns>A new tree with a ProgramRootSymbol root whose only subtree is the simplified expression.</returns>
public SymbolicExpressionTree Simplify(SymbolicExpressionTree originalTree) {
  var clone = (SymbolicExpressionTreeNode)originalTree.Root.Clone();
  // macro expand (initially no argument trees)
  var macroExpandedTree = MacroExpand(clone, clone.SubTrees[0], new List<SymbolicExpressionTreeNode>());
  SymbolicExpressionTreeNode rootNode = (new ProgramRootSymbol()).CreateTreeNode();
  rootNode.AddSubTree(GetSimplifiedTree(macroExpandedTree));
  return new SymbolicExpressionTree(rootNode);
}
// Replaces InvokeFunction nodes by the (expanded) body of the invoked function and Argument nodes
// by a clone of the matching argument tree; all other nodes are kept and their subtrees expanded recursively.
// root: node whose subtrees are searched for function definitions
// node: the node to expand (its subtree list is rebuilt in place)
// the argumentTrees list contains already expanded trees used as arguments for invocations
private SymbolicExpressionTreeNode MacroExpand(SymbolicExpressionTreeNode root, SymbolicExpressionTreeNode node, IList<SymbolicExpressionTreeNode> argumentTrees) {
  // detach all children; they are re-attached (expanded) or consumed as invocation arguments below
  var subtrees = new List<SymbolicExpressionTreeNode>(node.SubTrees);
  while (node.SubTrees.Count > 0) node.RemoveSubTree(0);
  if (node.Symbol is InvokeFunction) {
    var invokeSym = node.Symbol as InvokeFunction;
    var defunNode = FindFunctionDefinition(root, invokeSym.FunctionName);
    var macroExpandedArguments = new List<SymbolicExpressionTreeNode>();
    foreach (var subtree in subtrees) {
      macroExpandedArguments.Add(MacroExpand(root, subtree, argumentTrees));
    }
    // Expand a private copy of the function body. Expanding the definition itself would rewrite it
    // in place, so a later invocation of the same function would get the arguments of this call
    // and the same node object would appear at several places in the tree.
    var functionBody = (SymbolicExpressionTreeNode)defunNode.Clone();
    return MacroExpand(root, functionBody, macroExpandedArguments);
  } else if (node.Symbol is Argument) {
    var argSym = node.Symbol as Argument;
    // return the correct argument sub-tree (already macro-expanded)
    return (SymbolicExpressionTreeNode)argumentTrees[argSym.ArgumentIndex].Clone();
  } else {
    // recursive application
    foreach (var subtree in subtrees) {
      node.AddSubTree(MacroExpand(root, subtree, argumentTrees));
    }
    return node;
  }
}
// Searches the subtrees of root for a function definition with the given name and returns
// the first subtree of that definition (used by MacroExpand as the function body).
// Throws an ArgumentException when no definition with that name exists.
// NOTE(review): the generic type argument of OfType() is missing in this copy of the file;
// presumably the tree node type of function-defining branches - confirm against the original source.
private SymbolicExpressionTreeNode FindFunctionDefinition(SymbolicExpressionTreeNode root, string functionName) {
foreach (var subtree in root.SubTrees.OfType()) {
if (subtree.FunctionName == functionName) return subtree.SubTrees[0];
}
throw new ArgumentException("Definition of function " + functionName + " not found.");
}
#region symbol predicates
// arithmetic
// true if the symbol of the node is a Division
private bool IsDivision(SymbolicExpressionTreeNode node) {
  bool isDivision = node.Symbol is Division;
  return isDivision;
}
// true if the symbol of the node is a Multiplication
private bool IsMultiplication(SymbolicExpressionTreeNode node) {
  bool isMultiplication = node.Symbol is Multiplication;
  return isMultiplication;
}
// true if the symbol of the node is a Subtraction
private bool IsSubtraction(SymbolicExpressionTreeNode node) {
  bool isSubtraction = node.Symbol is Subtraction;
  return isSubtraction;
}
// true if the symbol of the node is an Addition
private bool IsAddition(SymbolicExpressionTreeNode node) {
  bool isAddition = node.Symbol is Addition;
  return isAddition;
}
// true if the symbol of the node is an Average
private bool IsAverage(SymbolicExpressionTreeNode node) {
  bool isAverage = node.Symbol is Average;
  return isAverage;
}
// exponential
// true if the symbol of the node is a Logarithm
private bool IsLog(SymbolicExpressionTreeNode node) {
  bool isLogarithm = node.Symbol is Logarithm;
  return isLogarithm;
}
// true if the symbol of the node is an Exponential
private bool IsExp(SymbolicExpressionTreeNode node) {
  bool isExponential = node.Symbol is Exponential;
  return isExponential;
}
// true if the symbol of the node is a Root
private bool IsRoot(SymbolicExpressionTreeNode node) {
  bool isRoot = node.Symbol is Root;
  return isRoot;
}
// true if the symbol of the node is a Power
private bool IsPower(SymbolicExpressionTreeNode node) {
  bool isPower = node.Symbol is Power;
  return isPower;
}
// trigonometric
// true if the symbol of the node is a Sine
private bool IsSine(SymbolicExpressionTreeNode node) {
  bool isSine = node.Symbol is Sine;
  return isSine;
}
// true if the symbol of the node is a Cosine
private bool IsCosine(SymbolicExpressionTreeNode node) {
  bool isCosine = node.Symbol is Cosine;
  return isCosine;
}
// true if the symbol of the node is a Tangent
private bool IsTangent(SymbolicExpressionTreeNode node) {
  bool isTangent = node.Symbol is Tangent;
  return isTangent;
}
// boolean
// true if the symbol of the node is an IfThenElse
private bool IsIfThenElse(SymbolicExpressionTreeNode node) {
  bool isIfThenElse = node.Symbol is IfThenElse;
  return isIfThenElse;
}
// true if the symbol of the node is an And
private bool IsAnd(SymbolicExpressionTreeNode node) {
  bool isAnd = node.Symbol is And;
  return isAnd;
}
// true if the symbol of the node is an Or
private bool IsOr(SymbolicExpressionTreeNode node) {
  bool isOr = node.Symbol is Or;
  return isOr;
}
// true if the symbol of the node is a Not
private bool IsNot(SymbolicExpressionTreeNode node) {
  bool isNot = node.Symbol is Not;
  return isNot;
}
// comparison
// true if the symbol of the node is a GreaterThan
private bool IsGreaterThan(SymbolicExpressionTreeNode node) {
  bool isGreaterThan = node.Symbol is GreaterThan;
  return isGreaterThan;
}
// true if the symbol of the node is a LessThan
private bool IsLessThan(SymbolicExpressionTreeNode node) {
  bool isLessThan = node.Symbol is LessThan;
  return isLessThan;
}
// true if the node is a comparison (GreaterThan, LessThan) or a logical connective (And, Or)
private bool IsBoolean(SymbolicExpressionTreeNode node) {
  var symbol = node.Symbol;
  if (symbol is GreaterThan) return true;
  if (symbol is LessThan) return true;
  if (symbol is And) return true;
  return symbol is Or;
}
// terminals
// true if the symbol of the node is a Variable
private bool IsVariable(SymbolicExpressionTreeNode node) {
  bool isVariable = node.Symbol is Variable;
  return isVariable;
}
// true if the symbol of the node is a Constant
private bool IsConstant(SymbolicExpressionTreeNode node) {
  bool isConstant = node.Symbol is Constant;
  return isConstant;
}
#endregion
/// <summary>
/// Creates a new simplified tree
/// </summary>
/// <param name="original">Root node of the (sub-)tree that should be simplified.</param>
/// <returns>The result of the simplification routine that matches the symbol of <paramref name="original"/>.</returns>
public SymbolicExpressionTreeNode GetSimplifiedTree(SymbolicExpressionTreeNode original) {
  // terminals cannot be simplified any further and are only copied
  if (IsConstant(original) || IsVariable(original)) return (SymbolicExpressionTreeNode)original.Clone();

  // arithmetic
  if (IsAddition(original)) return SimplifyAddition(original);
  if (IsSubtraction(original)) return SimplifySubtraction(original);
  if (IsMultiplication(original)) return SimplifyMultiplication(original);
  if (IsDivision(original)) return SimplifyDivision(original);
  if (IsAverage(original)) return SimplifyAverage(original);

  // exponential
  if (IsLog(original)) return SimplifyLog(original);
  if (IsExp(original)) return SimplifyExp(original);
  if (IsRoot(original)) return SimplifyRoot(original);
  if (IsPower(original)) return SimplifyPower(original);

  // trigonometric
  if (IsSine(original)) return SimplifySine(original);
  if (IsCosine(original)) return SimplifyCosine(original);
  if (IsTangent(original)) return SimplifyTangent(original);

  // conditionals, comparisons and boolean connectives
  if (IsIfThenElse(original)) return SimplifyIfThenElse(original);
  if (IsGreaterThan(original)) return SimplifyGreaterThan(original);
  if (IsLessThan(original)) return SimplifyLessThan(original);
  if (IsAnd(original)) return SimplifyAnd(original);
  if (IsOr(original)) return SimplifyOr(original);
  if (IsNot(original)) return SimplifyNot(original);

  // no dedicated routine: only the subtrees are simplified
  return SimplifyAny(original);
}
#region specific simplification routines
// Fallback for symbols without a dedicated simplification routine:
// returns a copy of the node whose subtrees are the simplified subtrees of the original.
// The original node has the same subtrees afterwards as before the call.
private SymbolicExpressionTreeNode SimplifyAny(SymbolicExpressionTreeNode original) {
  // can't simplify this function but simplify all subtrees
  var subTrees = new List<SymbolicExpressionTreeNode>(original.SubTrees);
  // detach the subtrees so that Clone() copies only the node itself ...
  while (original.SubTrees.Count > 0) original.RemoveSubTree(0);
  var clone = (SymbolicExpressionTreeNode)original.Clone();
  // ... and re-attach them right away, so the original is intact even if simplification below throws
  foreach (var subTree in subTrees) {
    original.AddSubTree(subTree);
  }
  var simplifiedSubTrees = new List<SymbolicExpressionTreeNode>();
  foreach (var subTree in subTrees) {
    simplifiedSubTrees.Add(GetSimplifiedTree(subTree));
  }
  foreach (var simplifiedSubtree in simplifiedSubTrees) {
    clone.AddSubTree(simplifiedSubtree);
  }
  if (simplifiedSubTrees.TrueForAll(t => IsConstant(t))) {
    // use the returned node (previously the result was discarded)
    return SimplifyConstantExpression(clone);
  }
  return clone;
}
// Placeholder for folding an expression whose arguments are all constants.
// Currently returns the given node unchanged.
private SymbolicExpressionTreeNode SimplifyConstantExpression(SymbolicExpressionTreeNode original) {
// not yet implemented
return original;
}
// avg(x) => x
// avg(x0..xn) => sum(x0..xn) / n, with every xi simplified first
private SymbolicExpressionTreeNode SimplifyAverage(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) return GetSimplifiedTree(original.SubTrees[0]);

  Trace.Assert(original.SubTrees.Count > 1);
  var simplifiedArguments = from argument in original.SubTrees
                            select GetSimplifiedTree(argument);
  var total = simplifiedArguments.Aggregate((accumulated, next) => MakeSum(accumulated, next));
  var argumentCount = MakeConstant(original.SubTrees.Count);
  return MakeFraction(total, argumentCount);
}
// div(x) => 1 / x
// div(x0, x1, .., xn) => x0 * 1 / (x1 * .. * xn), with every xi simplified first
private SymbolicExpressionTreeNode SimplifyDivision(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Invert(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    // simplify expressions x0..xn
    // make multiplication (x0 * 1/(x1 * x1 * .. * xn))
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize once: the Select query is deferred, so First() and Skip(1) would each
    // re-run the (recursive) simplification of the first operand.
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    var denominator = simplifiedTrees.Skip(1).Aggregate((a, b) => MakeProduct(a, b));
    return MakeProduct(simplifiedTrees.First(), Invert(denominator));
  }
}
// mul(x) => x
// mul(x0..xn) => product of the simplified factors, combined pairwise with MakeProduct
private SymbolicExpressionTreeNode SimplifyMultiplication(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) return GetSimplifiedTree(original.SubTrees[0]);

  Trace.Assert(original.SubTrees.Count > 1);
  var simplifiedFactors = from factor in original.SubTrees
                          select GetSimplifiedTree(factor);
  return simplifiedFactors.Aggregate((product, next) => MakeProduct(product, next));
}
// sub(x) => -x
// sub(x0, x1, .., xn) => x0 + (-x1) + .. + (-xn), with every xi simplified first
private SymbolicExpressionTreeNode SimplifySubtraction(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) {
    return Negate(GetSimplifiedTree(original.SubTrees[0]));
  } else {
    // simplify expressions x0..xn
    // make addition (x0,-x1..-xn)
    Trace.Assert(original.SubTrees.Count > 1);
    // Materialize once: the Select query is deferred, so Take(1) and Skip(1) would each
    // re-run the (recursive) simplification of the first operand.
    var simplifiedTrees = original.SubTrees.Select(x => GetSimplifiedTree(x)).ToList();
    return simplifiedTrees.Take(1)
      .Concat(simplifiedTrees.Skip(1).Select(x => Negate(x)))
      .Aggregate((a, b) => MakeSum(a, b));
  }
}
// add(x) => x
// add(x0..xn) => sum of the simplified terms, combined pairwise with MakeSum
private SymbolicExpressionTreeNode SimplifyAddition(SymbolicExpressionTreeNode original) {
  if (original.SubTrees.Count == 1) return GetSimplifiedTree(original.SubTrees[0]);

  Trace.Assert(original.SubTrees.Count > 1);
  var simplifiedTerms = from term in original.SubTrees
                        select GetSimplifiedTree(term);
  return simplifiedTerms.Aggregate((sum, next) => MakeSum(sum, next));
}
// simplifies the operand and passes it to MakeNot
private SymbolicExpressionTreeNode SimplifyNot(SymbolicExpressionTreeNode original) {
  var operand = GetSimplifiedTree(original.SubTrees[0]);
  return MakeNot(operand);
}
// simplifies all operands and combines them pairwise (left to right) with MakeOr
private SymbolicExpressionTreeNode SimplifyOr(SymbolicExpressionTreeNode original) {
  var simplifiedOperands = from operand in original.SubTrees
                           select GetSimplifiedTree(operand);
  return simplifiedOperands.Aggregate((left, right) => MakeOr(left, right));
}
// simplifies all operands and combines them pairwise (left to right) with MakeAnd
private SymbolicExpressionTreeNode SimplifyAnd(SymbolicExpressionTreeNode original) {
  var simplifiedOperands = from operand in original.SubTrees
                           select GetSimplifiedTree(operand);
  return simplifiedOperands.Aggregate((left, right) => MakeAnd(left, right));
}
// simplifies both sides of the comparison and passes them to MakeLessThan
private SymbolicExpressionTreeNode SimplifyLessThan(SymbolicExpressionTreeNode original) {
  var leftSide = GetSimplifiedTree(original.SubTrees[0]);
  var rightSide = GetSimplifiedTree(original.SubTrees[1]);
  return MakeLessThan(leftSide, rightSide);
}
// simplifies both sides of the comparison and passes them to MakeGreaterThan
private SymbolicExpressionTreeNode SimplifyGreaterThan(SymbolicExpressionTreeNode original) {
  var leftSide = GetSimplifiedTree(original.SubTrees[0]);
  var rightSide = GetSimplifiedTree(original.SubTrees[1]);
  return MakeGreaterThan(leftSide, rightSide);
}
// simplifies condition, true branch and false branch (in that order) and passes them to MakeIfThenElse
private SymbolicExpressionTreeNode SimplifyIfThenElse(SymbolicExpressionTreeNode original) {
  var condition = GetSimplifiedTree(original.SubTrees[0]);
  var trueBranch = GetSimplifiedTree(original.SubTrees[1]);
  var falseBranch = GetSimplifiedTree(original.SubTrees[2]);
  return MakeIfThenElse(condition, trueBranch, falseBranch);
}
// simplifies the argument and passes it to MakeTangent
private SymbolicExpressionTreeNode SimplifyTangent(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeTangent(argument);
}
// simplifies the argument and passes it to MakeCosine
private SymbolicExpressionTreeNode SimplifyCosine(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeCosine(argument);
}
// simplifies the argument and passes it to MakeSine
private SymbolicExpressionTreeNode SimplifySine(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeSine(argument);
}
// simplifies the argument and passes it to MakeExp
private SymbolicExpressionTreeNode SimplifyExp(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeExp(argument);
}
// simplifies the argument and passes it to MakeLog
private SymbolicExpressionTreeNode SimplifyLog(SymbolicExpressionTreeNode original) {
  var argument = GetSimplifiedTree(original.SubTrees[0]);
  return MakeLog(argument);
}
// simplifies both subtrees (first, then second) and passes them to MakeRoot
private SymbolicExpressionTreeNode SimplifyRoot(SymbolicExpressionTreeNode original) {
  var firstArgument = GetSimplifiedTree(original.SubTrees[0]);
  var secondArgument = GetSimplifiedTree(original.SubTrees[1]);
  return MakeRoot(firstArgument, secondArgument);
}
// simplifies both subtrees (first, then second) and passes them to MakePower
private SymbolicExpressionTreeNode SimplifyPower(SymbolicExpressionTreeNode original) {
  var firstArgument = GetSimplifiedTree(original.SubTrees[0]);
  var secondArgument = GetSimplifiedTree(original.SubTrees[1]);
  return MakePower(firstArgument, secondArgument);
}
#endregion
#region low level tree restructuring
// not(t) is expressed as t * -1
private SymbolicExpressionTreeNode MakeNot(SymbolicExpressionTreeNode t) {
  var minusOne = MakeConstant(-1.0);
  return MakeProduct(t, minusOne);
}
// Builds a simplified "a or b". Constants count as true when their value is > 0;
// folded results are the constants 1.0 (true) and -1.0 (false).
private SymbolicExpressionTreeNode MakeOr(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool aIsConstant = IsConstant(a);
  bool bIsConstant = IsConstant(b);

  if (aIsConstant && bIsConstant) {
    // both operands known => fold
    var constA = a as ConstantTreeNode;
    var constB = b as ConstantTreeNode;
    bool anyTrue = constA.Value > 0.0 || constB.Value > 0.0;
    return MakeConstant(anyTrue ? 1.0 : -1.0);
  }

  // make sure a single constant ends up on the right
  if (aIsConstant) return MakeOr(b, a);

  if (bIsConstant) {
    var constT = b as ConstantTreeNode;
    // boolean expression is necessarily true
    if (constT.Value > 0.0) return MakeConstant(1.0);

    // the constant value has no effect on the result of the boolean condition so we can drop the constant term
    var reducedOr = orSymbol.CreateTreeNode();
    reducedOr.AddSubTree(a);
    return reducedOr;
  }

  var orNode = orSymbol.CreateTreeNode();
  orNode.AddSubTree(a);
  orNode.AddSubTree(b);
  return orNode;
}
// Builds a simplified "a and b". Constants count as true when their value is > 0;
// folded results are the constants 1.0 (true) and -1.0 (false).
private SymbolicExpressionTreeNode MakeAnd(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  bool aIsConstant = IsConstant(a);
  bool bIsConstant = IsConstant(b);

  if (aIsConstant && bIsConstant) {
    // both operands known => fold
    var constA = a as ConstantTreeNode;
    var constB = b as ConstantTreeNode;
    bool bothTrue = constA.Value > 0.0 && constB.Value > 0.0;
    return MakeConstant(bothTrue ? 1.0 : -1.0);
  }

  // make sure a single constant ends up on the right
  if (aIsConstant) return MakeAnd(b, a);

  if (bIsConstant) {
    var constB = b as ConstantTreeNode;
    // boolean expression is necessarily false
    if (!(constB.Value > 0.0)) return MakeConstant(-1.0);

    // the constant value has no effect on the result of the boolean condition so we can drop the constant term
    var reducedAnd = andSymbol.CreateTreeNode();
    reducedAnd.AddSubTree(a);
    return reducedAnd;
  }

  var andNode = andSymbol.CreateTreeNode();
  andNode.AddSubTree(a);
  andNode.AddSubTree(b);
  return andNode;
}
// Builds "leftSide < rightSide"; two constants are folded to 1.0 (true) or -1.0 (false).
private SymbolicExpressionTreeNode MakeLessThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
  bool bothConstant = IsConstant(leftSide) && IsConstant(rightSide);
  if (!bothConstant) {
    var comparison = ltSymbol.CreateTreeNode();
    comparison.AddSubTree(leftSide);
    comparison.AddSubTree(rightSide);
    return comparison;
  }

  var lsConst = leftSide as ConstantTreeNode;
  var rsConst = rightSide as ConstantTreeNode;
  return MakeConstant(lsConst.Value < rsConst.Value ? 1.0 : -1.0);
}
// Builds "leftSide > rightSide"; two constants are folded to 1.0 (true) or -1.0 (false).
private SymbolicExpressionTreeNode MakeGreaterThan(SymbolicExpressionTreeNode leftSide, SymbolicExpressionTreeNode rightSide) {
  bool bothConstant = IsConstant(leftSide) && IsConstant(rightSide);
  if (!bothConstant) {
    var comparison = gtSymbol.CreateTreeNode();
    comparison.AddSubTree(leftSide);
    comparison.AddSubTree(rightSide);
    return comparison;
  }

  var lsConst = leftSide as ConstantTreeNode;
  var rsConst = rightSide as ConstantTreeNode;
  return MakeConstant(lsConst.Value > rsConst.Value ? 1.0 : -1.0);
}
// Builds a simplified if-then-else.
// A constant condition selects one branch directly (value > 0 => true branch).
// A condition that is not a boolean node (see IsBoolean) is wrapped as "condition > 0".
private SymbolicExpressionTreeNode MakeIfThenElse(SymbolicExpressionTreeNode condition, SymbolicExpressionTreeNode trueBranch, SymbolicExpressionTreeNode falseBranch) {
  if (IsConstant(condition)) {
    var constT = condition as ConstantTreeNode;
    return constT.Value > 0.0 ? trueBranch : falseBranch;
  }

  var ifNode = ifThenElseSymbol.CreateTreeNode();
  SymbolicExpressionTreeNode test;
  if (IsBoolean(condition)) {
    test = condition;
  } else {
    var greaterThanZero = gtSymbol.CreateTreeNode();
    greaterThanZero.AddSubTree(condition);
    greaterThanZero.AddSubTree(MakeConstant(0.0));
    test = greaterThanZero;
  }
  ifNode.AddSubTree(test);
  ifNode.AddSubTree(trueBranch);
  ifNode.AddSubTree(falseBranch);
  return ifNode;
}
// sin(c) is folded to a constant; otherwise a sine node with the given argument is created
private SymbolicExpressionTreeNode MakeSine(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var sineNode = sineSymbol.CreateTreeNode();
    sineNode.AddSubTree(node);
    return sineNode;
  }
  var constT = node as ConstantTreeNode;
  double folded = Math.Sin(constT.Value);
  return MakeConstant(folded);
}
// tan(c) is folded to a constant; otherwise a tangent node with the given argument is created
private SymbolicExpressionTreeNode MakeTangent(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var tanNode = tanSymbol.CreateTreeNode();
    tanNode.AddSubTree(node);
    return tanNode;
  }
  var constT = node as ConstantTreeNode;
  double folded = Math.Tan(constT.Value);
  return MakeConstant(folded);
}
// cos(c) is folded to a constant; otherwise a cosine node with the given argument is created
private SymbolicExpressionTreeNode MakeCosine(SymbolicExpressionTreeNode node) {
  if (!IsConstant(node)) {
    var cosNode = cosineSymbol.CreateTreeNode();
    cosNode.AddSubTree(node);
    return cosNode;
  }
  var constT = node as ConstantTreeNode;
  double folded = Math.Cos(constT.Value);
  return MakeConstant(folded);
}
// Builds a simplified exp(node):
//   exp(c)                 => constant
//   exp(log(x))            => x
//   exp(x0 + .. + xn)      => exp(x0) * .. * exp(xn)
//   exp(x0 - x1 - .. - xn) => exp(x0) / exp(x1) / .. / exp(xn)
//   exp(-x0)               => 1 / exp(x0)   (a subtraction with one argument is a negation)
// Anything else is wrapped in a new exponential node.
private SymbolicExpressionTreeNode MakeExp(SymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    return MakeConstant(Math.Exp(constT.Value));
  } else if (IsLog(node)) {
    return node.SubTrees[0];
  } else if (IsAddition(node)) {
    return node.SubTrees.Select(s => MakeExp(s)).Aggregate((s, t) => MakeProduct(s, t));
  } else if (IsSubtraction(node)) {
    var exponentials = node.SubTrees.Select(s => MakeExp(s)).ToList();
    // unary minus: exp(-x) = 1 / exp(x)
    if (exponentials.Count == 1) return Invert(exponentials[0]);
    // subtracting an exponent divides by its exponential
    // (the previous code multiplied by the negated exponential, which is a different value)
    return exponentials.Aggregate((s, t) => MakeFraction(s, t));
  } else {
    var expNode = expSymbol.CreateTreeNode();
    expNode.AddSubTree(node);
    return expNode;
  }
}
// Builds a simplified log(node):
//   log(c)            => constant
//   log(exp(x))       => x
//   log(x0 * .. * xn) => log(x0) + .. + log(xn)
//   log(x0 / .. / xn) => subtraction node over log(x0) .. log(xn)
// Anything else is wrapped in a new logarithm node.
private SymbolicExpressionTreeNode MakeLog(SymbolicExpressionTreeNode node) {
  if (IsConstant(node)) {
    var constT = node as ConstantTreeNode;
    double folded = Math.Log(constT.Value);
    return MakeConstant(folded);
  }

  if (IsExp(node)) return node.SubTrees[0];

  if (IsMultiplication(node)) {
    var logarithms = from factor in node.SubTrees
                     select MakeLog(factor);
    return logarithms.Aggregate((x, y) => MakeSum(x, y));
  }

  if (IsDivision(node)) {
    var difference = subSymbol.CreateTreeNode();
    foreach (var operand in node.SubTrees) {
      difference.AddSubTree(MakeLog(operand));
    }
    return difference;
  }

  var logNode = logSymbol.CreateTreeNode();
  logNode.AddSubTree(node);
  return logNode;
}
// Builds a simplified root of a with degree b. A constant degree is rounded to a whole number:
//   both constant => constant a^(1/round(b))
//   degree 1      => a
//   degree 0      => constant 1
//   degree -1     => 1 / a
//   degree < 0    => 1 / root(a, -degree)
//   otherwise     => root(a, degree)
// A non-constant degree yields a plain root node over a and b.
private SymbolicExpressionTreeNode MakeRoot(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (!IsConstant(b)) {
    var plainRoot = rootSymbol.CreateTreeNode();
    plainRoot.AddSubTree(a);
    plainRoot.AddSubTree(b);
    return plainRoot;
  }

  if (IsConstant(a)) {
    var constA = a as ConstantTreeNode;
    var constDegree = b as ConstantTreeNode;
    return MakeConstant(Math.Pow(constA.Value, 1.0 / Math.Round(constDegree.Value)));
  }

  var constB = b as ConstantTreeNode;
  var constBValue = Math.Round(constB.Value);
  if (constBValue.IsAlmost(1.0)) return a;
  if (constBValue.IsAlmost(0.0)) return MakeConstant(1.0);
  if (constBValue.IsAlmost(-1.0)) return MakeFraction(MakeConstant(1.0), a);

  bool negativeDegree = constBValue < 0;
  var rootNode = rootSymbol.CreateTreeNode();
  rootNode.AddSubTree(a);
  rootNode.AddSubTree(MakeConstant(negativeDegree ? -1.0 * constBValue : constBValue));
  if (negativeDegree) return MakeFraction(MakeConstant(1.0), rootNode);
  return rootNode;
}
// Builds a simplified power a^b. A constant exponent is rounded to a whole number:
//   both constant => constant a^round(b)
//   exponent 0    => constant 1
//   exponent 1    => a
//   exponent -1   => 1 / a
//   exponent < 0  => 1 / pow(a, -exponent)
//   otherwise     => pow(a, exponent)
// A non-constant exponent yields a plain power node over a and b.
private SymbolicExpressionTreeNode MakePower(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  if (!IsConstant(b)) {
    var plainPower = powSymbol.CreateTreeNode();
    plainPower.AddSubTree(a);
    plainPower.AddSubTree(b);
    return plainPower;
  }

  if (IsConstant(a)) {
    var constA = a as ConstantTreeNode;
    var constExponent = b as ConstantTreeNode;
    return MakeConstant(Math.Pow(constA.Value, Math.Round(constExponent.Value)));
  }

  var constB = b as ConstantTreeNode;
  double exponent = Math.Round(constB.Value);
  if (exponent.IsAlmost(0.0)) return MakeConstant(1.0);
  if (exponent.IsAlmost(1.0)) return a;
  if (exponent.IsAlmost(-1.0)) return MakeFraction(MakeConstant(1.0), a);

  bool negativeExponent = exponent < 0;
  var powNode = powSymbol.CreateTreeNode();
  powNode.AddSubTree(a);
  powNode.AddSubTree(MakeConstant(negativeExponent ? -1.0 * exponent : exponent));
  if (negativeExponent) return MakeFraction(MakeConstant(1.0), powNode);
  return powNode;
}
// MakeFraction, MakeProduct and MakeSum take two already simplified trees and create a new simplified tree
// MakeFraction builds a / b. The branches below are tried in order; the first matching rule wins.
// Note that some rules modify a in place (variable weights) and return it.
private SymbolicExpressionTreeNode MakeFraction(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
if (IsConstant(a) && IsConstant(b)) {
// fold constants
return MakeConstant(((ConstantTreeNode)a).Value / ((ConstantTreeNode)b).Value);
} if (IsConstant(a) && !((ConstantTreeNode)a).Value.IsAlmost(1.0)) {
// c / b => 1 / (b * (1 / c)), so that the numerator of the resulting fraction is 1
// (the previous branch always returns, so this "if" acts like an "else if")
return MakeFraction(MakeConstant(1.0), MakeProduct(b, Invert(a)));
} else if (IsVariable(a) && IsConstant(b)) {
// merge constant values into variable weights
var constB = ((ConstantTreeNode)b).Value;
((VariableTreeNode)a).Weight /= constB;
return a;
} else if (IsVariable(a) && IsVariable(b) && AreSameVariable(a, b)) {
// cancel variables
var aVar = a as VariableTreeNode;
var bVar = b as VariableTreeNode;
return MakeConstant(aVar.Weight / bVar.Weight);
} else if (IsAddition(a) && IsConstant(b)) {
// (x0 + .. + xn) / c => x0 / c + .. + xn / c
// each term is passed through GetSimplifiedTree again before the division
return a.SubTrees
.Select(x => GetSimplifiedTree(x))
.Select(x => MakeFraction(x, b))
.Aggregate((c, d) => MakeSum(c, d));
} else if (IsMultiplication(a) && IsConstant(b)) {
// (x0 * .. * xn) / c => (x0 * .. * xn) * (1 / c)
return MakeProduct(a, Invert(b));
} else if (IsDivision(a) && IsConstant(b)) {
// (a1 / a2) / c => (a1 / (a2 * c))
Trace.Assert(a.SubTrees.Count == 2);
return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
} else if (IsDivision(a) && IsDivision(b)) {
// (a1 / a2) / (b1 / b2) => (a1 * b2) / (a2 * b1)
Trace.Assert(a.SubTrees.Count == 2);
Trace.Assert(b.SubTrees.Count == 2);
return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[1]), MakeProduct(a.SubTrees[1], b.SubTrees[0]));
} else if (IsDivision(a)) {
// (a1 / a2) / b => (a1 / (a2 * b))
Trace.Assert(a.SubTrees.Count == 2);
return MakeFraction(a.SubTrees[0], MakeProduct(a.SubTrees[1], b));
} else if (IsDivision(b)) {
// a / (b1 / b2) => (a * b2) / b1
Trace.Assert(b.SubTrees.Count == 2);
return MakeFraction(MakeProduct(a, b.SubTrees[1]), b.SubTrees[0]);
} else {
// no rule applies: create a plain division node a / b
var div = divSymbol.CreateTreeNode();
div.AddSubTree(a);
div.AddSubTree(b);
return div;
}
}
// Builds a + b from two already simplified trees. The branches below are tried in order.
// Constants are kept as the last subtree of an addition, so only the last subtree of an existing
// addition is checked for a constant. Folding two constants modifies a in place and returns it.
private SymbolicExpressionTreeNode MakeSum(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
if (IsConstant(a) && IsConstant(b)) {
// fold constants
((ConstantTreeNode)a).Value += ((ConstantTreeNode)b).Value;
return a;
} else if (IsConstant(a)) {
// c + x => x + c
// b is not constant => make sure constant is on the right
return MakeSum(b, a);
} else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(0.0)) {
// x + 0 => x
return a;
} else if (IsAddition(a) && IsAddition(b)) {
// merge additions
var add = addSymbol.CreateTreeNode();
// copy everything except the last subtree of each addition; the last ones may be constants
for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
for (int i = 0; i < b.SubTrees.Count - 1; i++) add.AddSubTree(b.SubTrees[i]);
if (IsConstant(a.SubTrees.Last()) && IsConstant(b.SubTrees.Last())) {
// both additions end with a constant => fold them into one
add.AddSubTree(MakeSum(a.SubTrees.Last(), b.SubTrees.Last()));
} else if (IsConstant(a.SubTrees.Last())) {
// only a ends with a constant => append it after the last term of b
add.AddSubTree(b.SubTrees.Last());
add.AddSubTree(a.SubTrees.Last());
} else {
add.AddSubTree(a.SubTrees.Last());
add.AddSubTree(b.SubTrees.Last());
}
MergeVariablesInSum(add);
// merging may leave a single term, in which case the addition node is not needed
if (add.SubTrees.Count == 1) {
return add.SubTrees[0];
} else {
return add;
}
} else if (IsAddition(b)) {
// only b is an addition => swap so that the addition is the first argument
return MakeSum(b, a);
} else if (IsAddition(a) && IsConstant(b)) {
// a is an addition and b is a constant => append b to a and make sure the constants are merged
var add = addSymbol.CreateTreeNode();
for (int i = 0; i < a.SubTrees.Count - 1; i++) add.AddSubTree(a.SubTrees[i]);
if (IsConstant(a.SubTrees.Last()))
add.AddSubTree(MakeSum(a.SubTrees.Last(), b));
else {
add.AddSubTree(a.SubTrees.Last());
add.AddSubTree(b);
}
return add;
} else if (IsAddition(a)) {
// a is already an addition => append b
var add = addSymbol.CreateTreeNode();
add.AddSubTree(b);
foreach (var subTree in a.SubTrees) {
add.AddSubTree(subTree);
}
MergeVariablesInSum(add);
if (add.SubTrees.Count == 1) {
return add.SubTrees[0];
} else {
return add;
}
} else {
// neither argument is an addition or a constant that could be folded => new addition node
var add = addSymbol.CreateTreeNode();
add.AddSubTree(a);
add.AddSubTree(b);
MergeVariablesInSum(add);
if (add.SubTrees.Count == 1) {
return add.SubTrees[0];
} else {
return add;
}
}
}
// makes sure variable symbols in sums are combined
// Variable nodes with the same name and lag (0 for non-lagged variables) are replaced by the first
// node of their group, whose weight is set to the sum of all weights in the group.
// The subtrees of sum are rebuilt: merged variables first, then all other subtrees in their original order.
// possible improvement: combine sums of products where the products only reference the same variable
private void MergeVariablesInSum(SymbolicExpressionTreeNode sum) {
  var subtrees = new List<SymbolicExpressionTreeNode>(sum.SubTrees);
  while (sum.SubTrees.Count > 0) sum.RemoveSubTree(0);
  // Group by a (name, lag) pair. Concatenating name and lag into one string is ambiguous:
  // variable "x1" with lag 20 and variable "x12" with lag 0 would both yield the key "x120".
  var groupedVarNodes = from node in subtrees.OfType<VariableTreeNode>()
                        let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
                        group node by new { node.VariableName, Lag = lag } into g
                        select g;
  var unchangedSubTrees = subtrees.Where(t => !(t is VariableTreeNode));

  foreach (var variableNodeGroup in groupedVarNodes) {
    var weightSum = variableNodeGroup.Select(t => t.Weight).Sum();
    var representative = variableNodeGroup.First();
    representative.Weight = weightSum;
    sum.AddSubTree(representative);
  }
  foreach (var unchangedSubtree in unchangedSubTrees)
    sum.AddSubTree(unchangedSubtree);
}
// Builds a * b from two already simplified trees. The branches below are tried in order.
// Several rules modify a in place (constant value, variable weight, appended factor) and return it.
private SymbolicExpressionTreeNode MakeProduct(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
if (IsConstant(a) && IsConstant(b)) {
// fold constants
((ConstantTreeNode)a).Value *= ((ConstantTreeNode)b).Value;
return a;
} else if (IsConstant(a)) {
// a * $ => $ * a
return MakeProduct(b, a);
} else if (IsConstant(b) && ((ConstantTreeNode)b).Value.IsAlmost(1.0)) {
// $ * 1.0 => $
return a;
} else if (IsConstant(b) && IsVariable(a)) {
// multiply constants into variables weights
((VariableTreeNode)a).Weight *= ((ConstantTreeNode)b).Value;
return a;
} else if (IsConstant(b) && IsAddition(a)) {
// multiply constants into additions
return a.SubTrees.Select(x => MakeProduct(x, b)).Aggregate((c, d) => MakeSum(c, d));
} else if (IsDivision(a) && IsDivision(b)) {
// (a1 / a2) * (b1 / b2) => (a1 * b1) / (a2 * b2)
Trace.Assert(a.SubTrees.Count == 2);
Trace.Assert(b.SubTrees.Count == 2);
return MakeFraction(MakeProduct(a.SubTrees[0], b.SubTrees[0]), MakeProduct(a.SubTrees[1], b.SubTrees[1]));
} else if (IsDivision(a)) {
// (a1 / a2) * b => (a1 * b) / a2
Trace.Assert(a.SubTrees.Count == 2);
return MakeFraction(MakeProduct(a.SubTrees[0], b), a.SubTrees[1]);
} else if (IsDivision(b)) {
// a * (b1 / b2) => (b1 * a) / b2
Trace.Assert(b.SubTrees.Count == 2);
return MakeFraction(MakeProduct(b.SubTrees[0], a), b.SubTrees[1]);
} else if (IsMultiplication(a) && IsMultiplication(b)) {
// merge multiplications (make sure constants are merged)
var mul = mulSymbol.CreateTreeNode();
for (int i = 0; i < a.SubTrees.Count; i++) mul.AddSubTree(a.SubTrees[i]);
for (int i = 0; i < b.SubTrees.Count; i++) mul.AddSubTree(b.SubTrees[i]);
MergeVariablesAndConstantsInProduct(mul);
return mul;
} else if (IsMultiplication(b)) {
// only b is a multiplication => swap so that the multiplication is the first argument
return MakeProduct(b, a);
} else if (IsMultiplication(a)) {
// a is already an multiplication => append b
a.AddSubTree(b);
MergeVariablesAndConstantsInProduct(a);
return a;
} else {
// no rule applies: create a new multiplication node over a and b
var mul = mulSymbol.CreateTreeNode();
// NOTE(review): this is the only place that adds to SubTrees directly instead of calling AddSubTree;
// the merge call below rebuilds the subtree list via AddSubTree - confirm whether the difference is intended
mul.SubTrees.Add(a);
mul.SubTrees.Add(b);
MergeVariablesAndConstantsInProduct(mul);
return mul;
}
}
#endregion
#region helper functions
// Returns true if both nodes are variable nodes that refer to the same variable with the same lag.
// A non-lagged variable is treated as lag 0, the same convention the merge helpers
// (MergeVariablesInSum, MergeVariablesAndConstantsInProduct) use for their grouping.
// Returns false if either node is not a variable node.
private bool AreSameVariable(SymbolicExpressionTreeNode a, SymbolicExpressionTreeNode b) {
  var aLaggedVar = a as LaggedVariableTreeNode;
  var bLaggedVar = b as LaggedVariableTreeNode;
  if (aLaggedVar != null && bLaggedVar != null) {
    return aLaggedVar.VariableName == bLaggedVar.VariableName &&
      aLaggedVar.Lag == bLaggedVar.Lag;
  }
  var aVar = a as VariableTreeNode;
  var bVar = b as VariableTreeNode;
  if (aVar == null || bVar == null) return false;
  if (aVar.VariableName != bVar.VariableName) return false;
  // At most one of the two is lagged here. Comparing names only would treat a lagged variable
  // with a non-zero lag as identical to the non-lagged variable of the same name.
  if (aLaggedVar != null) return aLaggedVar.Lag == 0;
  if (bLaggedVar != null) return bLaggedVar.Lag == 0;
  return true;
}
// helper to combine the constant factors in products and to combine variables (powers of 2, 3...)
// The subtrees of prod are rebuilt in this order:
//   1. one entry per distinct variable (name and lag, lag 0 for non-lagged variables) with weight 1.0;
//      a variable that occurs k > 1 times becomes a nested multiplication of k copies
//   2. all subtrees that are neither variables nor constants, in their original order
//   3. one constant holding the product of all variable weights and constant values (omitted if it is almost 1.0)
private void MergeVariablesAndConstantsInProduct(SymbolicExpressionTreeNode prod) {
  var subtrees = new List<SymbolicExpressionTreeNode>(prod.SubTrees);
  while (prod.SubTrees.Count > 0) prod.RemoveSubTree(0);
  // Group by a (name, lag) pair. Concatenating name and lag into one string is ambiguous:
  // variable "x1" with lag 20 and variable "x12" with lag 0 would both yield the key "x120".
  var groupedVarNodes = (from node in subtrees.OfType<VariableTreeNode>()
                         let lag = (node is LaggedVariableTreeNode) ? ((LaggedVariableTreeNode)node).Lag : 0
                         group node by new { node.VariableName, Lag = lag } into g
                         orderby g.Count()
                         select g).ToList();
  // must be computed before the weights are reset to 1.0 below
  var constantProduct = (from node in subtrees.OfType<VariableTreeNode>()
                         select node.Weight)
                        .Concat(from node in subtrees.OfType<ConstantTreeNode>()
                                select node.Value)
                        .DefaultIfEmpty(1.0)
                        .Aggregate((c1, c2) => c1 * c2);

  var unchangedSubTrees = from tree in subtrees
                          where !(tree is VariableTreeNode)
                          where !(tree is ConstantTreeNode)
                          select tree;

  foreach (var variableNodeGroup in groupedVarNodes) {
    var representative = variableNodeGroup.First();
    representative.Weight = 1.0;
    int occurrences = variableNodeGroup.Count();
    if (occurrences > 1) {
      // x * x * .. * x is kept as its own multiplication node
      var poly = mulSymbol.CreateTreeNode();
      for (int p = 0; p < occurrences; p++) {
        poly.AddSubTree((SymbolicExpressionTreeNode)representative.Clone());
      }
      prod.AddSubTree(poly);
    } else {
      prod.AddSubTree(representative);
    }
  }

  foreach (var unchangedSubtree in unchangedSubTrees)
    prod.AddSubTree(unchangedSubtree);

  if (!constantProduct.IsAlmost(1.0)) {
    prod.AddSubTree(MakeConstant(constantProduct));
  }
}
/// <summary>
/// x => x * -1
/// Constants, variables, additions, multiplications and divisions are negated in place
/// (x itself is manipulated and returned); for any other node a product with -1 is built.
/// </summary>
/// <param name="x">The tree to negate.</param>
/// <returns>-x</returns>
private SymbolicExpressionTreeNode Negate(SymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    var constant = (ConstantTreeNode)x;
    constant.Value *= -1;
    return x;
  }

  if (IsVariable(x)) {
    var variable = (VariableTreeNode)x;
    variable.Weight *= -1.0;
    return x;
  }

  if (IsAddition(x)) {
    // (x0 + x1 + .. + xn) * -1 => (-x0 + -x1 + .. + -xn)
    for (int term = 0; term < x.SubTrees.Count; term++) {
      x.SubTrees[term] = Negate(x.SubTrees[term]);
    }
    return x;
  }

  if (IsMultiplication(x) || IsDivision(x)) {
    // x0 * x1 * .. * xn * -1 => x0 * x1 * .. * -xn
    // last is maybe a constant, prefer to negate the constant
    int lastIndex = x.SubTrees.Count - 1;
    x.SubTrees[lastIndex] = Negate(x.SubTrees.Last());
    return x;
  }

  // any other function
  return MakeProduct(x, MakeConstant(-1));
}
/// <summary>
/// x => 1/x
/// A constant is replaced by a new constant, a division a/b becomes b/a,
/// and any other tree becomes the fraction 1/x.
/// </summary>
/// <param name="x">The tree to invert.</param>
/// <returns>1/x</returns>
private SymbolicExpressionTreeNode Invert(SymbolicExpressionTreeNode x) {
  if (IsConstant(x)) {
    double value = ((ConstantTreeNode)x).Value;
    return MakeConstant(1.0 / value);
  }

  if (IsDivision(x)) {
    Trace.Assert(x.SubTrees.Count == 2);
    // swap numerator and denominator
    return MakeFraction(x.SubTrees[1], x.SubTrees[0]);
  }

  // any other function
  return MakeFraction(MakeConstant(1), x);
}
// creates a new constant node with the given value
private SymbolicExpressionTreeNode MakeConstant(double value) {
  var node = (ConstantTreeNode)constSymbol.CreateTreeNode();
  node.Value = value;
  return node;
}
// creates a new variable node with the given weight and variable name
private SymbolicExpressionTreeNode MakeVariable(double weight, string name) {
  var variableNode = (VariableTreeNode)varSymbol.CreateTreeNode();
  variableNode.Weight = weight;
  variableNode.VariableName = name;
  SymbolicExpressionTreeNode result = variableNode;
  return result;
}
#endregion
}
}