#region License Information
/* HeuristicLab
* Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System.Globalization;
using System.Linq;
using System.Text;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Formats a symbolic data analysis expression tree as MATLAB source code.
  /// The output starts with a function <c>fct</c> whose arguments are the tree's distinct, sorted variable
  /// names. It is followed by the helper functions the generated expression may call: <c>log_</c>,
  /// <c>fivePoint</c>, and one <c>switch_NAME</c> per factor variable occurring in the tree.
  /// </summary>
  [Item("MATLAB Function Formatter", "String formatter for string representations of symbolic data analysis expressions in MATLAB function syntax.")]
  [StorableClass]
  public sealed class SymbolicDataAnalysisExpressionMATLABFunctionFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {
    // Offset added to every variable reference. TimeLag and Derivative nodes shift it while their
    // subtree is formatted and restore it afterwards.
    private int currentLag;
    // Nesting depth of index variables (i0, i1, ...). Each Integral node raises it by one for its subtree.
    private int currentIndexNumber;

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionMATLABFunctionFormatter(bool deserializing) : base(deserializing) { }
    private SymbolicDataAnalysisExpressionMATLABFunctionFormatter(SymbolicDataAnalysisExpressionMATLABFunctionFormatter original, Cloner cloner) : base(original, cloner) { }
    public SymbolicDataAnalysisExpressionMATLABFunctionFormatter()
      : base() {
      Name = ItemName;
      Description = ItemDescription;
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionMATLABFunctionFormatter(this, cloner);
    }

    /// <summary>
    /// Name of the MATLAB index variable for the current nesting depth ("i0", "i1", ...).
    /// </summary>
    public string CurrentIndexVariable {
      get { return "i" + currentIndexNumber.ToString(CultureInfo.InvariantCulture); }
    }

    // Leaves the scope of the innermost index variable.
    private void ReleaseIndexVariable() {
      currentIndexNumber--;
    }

    // Enters a new nesting level and returns the name of its index variable.
    private string AllocateIndexVariable() {
      currentIndexNumber++;
      return CurrentIndexVariable;
    }

    // Collects the names of all variable nodes in the tree. Duplicates are removed, the names are sorted,
    // and the result is joined with ", " for use as the argument list of the generated function.
    private string GetVariableNames(ISymbolicExpressionTree symbolicExpressionTree) {
      var variableNames = symbolicExpressionTree.IterateNodesPostfix()
        .Where(x => x.Symbol is IVariableSymbol)
        .Select(x => (x as IVariableTreeNode).VariableName)
        .Distinct().ToList();
      variableNames.Sort();
      return string.Join(", ", variableNames);
    }

    /// <summary>
    /// Formats <paramref name="symbolicExpressionTree"/> as a MATLAB function file.
    /// </summary>
    /// <param name="symbolicExpressionTree">The tree to format. Its root's first subtree is the expression.</param>
    /// <returns>MATLAB source containing <c>fct</c> and all helper functions it references.</returns>
    public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
      currentLag = 0;
      currentIndexNumber = 0;
      var stringBuilder = new StringBuilder();
      stringBuilder.AppendLine("%%");
      stringBuilder.Append("function retval = fct(");
      stringBuilder.Append(GetVariableNames(symbolicExpressionTree));
      stringBuilder.AppendLine(")");
      stringBuilder.AppendLine(" retval = " + FormatRecursively(symbolicExpressionTree.Root.GetSubtree(0)) + ";");
      stringBuilder.AppendLine("end");
      stringBuilder.AppendLine();
      AppendHelperFunctions(stringBuilder);
      AppendFactorSwitchFunctions(stringBuilder, symbolicExpressionTree);
      return stringBuilder.ToString();
    }

    // Appends the fixed helper functions: log_ (NaN for non-positive input) and the fivePoint
    // derivative used by Derivative nodes.
    private static void AppendHelperFunctions(StringBuilder stringBuilder) {
      stringBuilder.AppendLine("%%");
      stringBuilder.AppendLine("function y = log_(x)");
      stringBuilder.AppendLine(" if(x <= 0) y = NaN;");
      stringBuilder.AppendLine(" else y = log(x);");
      stringBuilder.AppendLine(" end");
      stringBuilder.AppendLine("end");
      stringBuilder.AppendLine();
      stringBuilder.AppendLine("function y = fivePoint(f0, f1, f3, f4)");
      stringBuilder.AppendLine(" y = (f0 + 2*f1 - 2*f3 - f4) / 8;");
      stringBuilder.AppendLine("end");
    }

    // Appends one function switch_NAME(val, v) per factor variable in the tree.
    // The function returns the weight v(k) that belongs to the k-th value of the factor.
    private void AppendFactorSwitchFunctions(StringBuilder stringBuilder, ISymbolicExpressionTree symbolicExpressionTree) {
      // Grouping walks the tree once. GroupBy keeps the first-occurrence order of the names, like Distinct.
      var firstNodePerFactor = symbolicExpressionTree.IterateNodesPostfix()
        .OfType<FactorVariableTreeNode>()
        .GroupBy(n => n.VariableName)
        .Select(g => g.First());
      foreach (var factorNode in firstNodePerFactor) {
        var factorVarName = factorNode.VariableName;
        var values = factorNode.Symbol.GetVariableValues(factorVarName).ToArray();
        stringBuilder.AppendFormat("function y = switch_{0}(val, v)", factorVarName).AppendLine();
        stringBuilder.AppendLine(" switch val");
        for (int i = 0; i < values.Length; i++) {
          // MATLAB arrays are 1-based, so the i-th (0-based) value selects v(i+1).
          stringBuilder.AppendFormat(CultureInfo.InvariantCulture, "  case \"{0}\", y = v({1});", EscapeMatlabString(values[i]), i + 1).AppendLine();
        }
        stringBuilder.AppendLine(" end");   // closes the switch
        stringBuilder.AppendLine("end");    // closes the function, like the other helper functions
        stringBuilder.AppendLine();
      }
    }

    // Translates one tree node (and its subtrees) into a MATLAB expression.
    private string FormatRecursively(ISymbolicExpressionTreeNode node) {
      ISymbol symbol = node.Symbol;
      StringBuilder stringBuilder = new StringBuilder();
      if (symbol is ProgramRootSymbol) {
        stringBuilder.AppendLine(FormatRecursively(node.GetSubtree(0)));
      } else if (symbol is StartSymbol) {
        return FormatRecursively(node.GetSubtree(0));
      } else if (symbol is Addition) {
        stringBuilder.Append("(");
        for (int i = 0; i < node.SubtreeCount; i++) {
          if (i > 0) stringBuilder.Append("+");
          stringBuilder.Append(FormatRecursively(node.GetSubtree(i)));
        }
        stringBuilder.Append(")");
      } else if (symbol is And) {
        stringBuilder.Append(ToSignedBoolean(FormatConditionList(node, "&")));
      } else if (symbol is Average) {
        stringBuilder.Append("(1/");
        stringBuilder.Append(node.SubtreeCount.ToString(CultureInfo.InvariantCulture));
        stringBuilder.Append(")*(");
        for (int i = 0; i < node.SubtreeCount; i++) {
          if (i > 0) stringBuilder.Append("+");
          stringBuilder.Append("(").Append(FormatRecursively(node.GetSubtree(i))).Append(")");
        }
        stringBuilder.Append(")");
      } else if (symbol is Constant) {
        stringBuilder.Append(FormatNumber(((ConstantTreeNode)node).Value));
      } else if (symbol is Cosine) {
        stringBuilder.Append(WrapFirstSubtree("cos(", node, ")"));
      } else if (symbol is Division) {
        if (node.SubtreeCount == 1) {
          // The operand must be parenthesized: e.g. "2*x" would otherwise parse as (1/2)*x.
          stringBuilder.Append(WrapFirstSubtree("1/(", node, ")"));
        } else {
          stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
          stringBuilder.Append("/(");
          for (int i = 1; i < node.SubtreeCount; i++) {
            if (i > 1) stringBuilder.Append("*");
            stringBuilder.Append(FormatRecursively(node.GetSubtree(i)));
          }
          stringBuilder.Append(")");
        }
      } else if (symbol is Exponential) {
        stringBuilder.Append(WrapFirstSubtree("exp(", node, ")"));
      } else if (symbol is Square) {
        stringBuilder.Append(WrapFirstSubtree("(", node, ").^2"));
      } else if (symbol is SquareRoot) {
        stringBuilder.Append(WrapFirstSubtree("sqrt(", node, ")"));
      } else if (symbol is GreaterThan) {
        stringBuilder.Append(ToSignedBoolean(FormatRecursively(node.GetSubtree(0)) + ">" + FormatRecursively(node.GetSubtree(1))));
      } else if (symbol is IfThenElse) {
        // Branches and the whole sum are parenthesized so that the result can be embedded in
        // products, quotients or other conditionals without changing operator precedence.
        string condition = FormatRecursively(node.GetSubtree(0));
        stringBuilder.Append("((").Append(condition).Append(">0)*(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(1)));
        stringBuilder.Append(")+(").Append(condition).Append("<=0)*(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(2)));
        stringBuilder.Append("))");
      } else if (symbol is LaggedVariable) {
        // NOTE(review): kept ahead of the Variable branch, presumably because LaggedVariable might
        // otherwise be matched by a more general variable check - confirm against the symbol hierarchy.
        var laggedVariableTreeNode = (LaggedVariableTreeNode)node;
        stringBuilder.Append(FormatNumber(laggedVariableTreeNode.Weight));
        stringBuilder.Append("*");
        stringBuilder.Append(laggedVariableTreeNode.VariableName +
                             LagToString(currentLag + laggedVariableTreeNode.Lag));
      } else if (symbol is LessThan) {
        stringBuilder.Append(ToSignedBoolean(FormatRecursively(node.GetSubtree(0)) + "<" + FormatRecursively(node.GetSubtree(1))));
      } else if (symbol is Logarithm) {
        stringBuilder.Append(WrapFirstSubtree("log_(", node, ")"));
      } else if (symbol is Multiplication) {
        for (int i = 0; i < node.SubtreeCount; i++) {
          if (i > 0) stringBuilder.Append("*");
          stringBuilder.Append(FormatRecursively(node.GetSubtree(i)));
        }
      } else if (symbol is Not) {
        // Mapped to -1/+1 like And, Or, GreaterThan and LessThan (plain ~ would yield 0/1).
        stringBuilder.Append(ToSignedBoolean(WrapFirstSubtree("~(", node, " > 0 )")));
      } else if (symbol is Or) {
        stringBuilder.Append(ToSignedBoolean(FormatConditionList(node, "|")));
      } else if (symbol is Sine) {
        stringBuilder.Append(WrapFirstSubtree("sin(", node, ")"));
      } else if (symbol is Subtraction) {
        stringBuilder.Append("(");
        if (node.SubtreeCount == 1) {
          stringBuilder.Append("-");
          stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
        } else {
          stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
          for (int i = 1; i < node.SubtreeCount; i++) {
            stringBuilder.Append("-");
            stringBuilder.Append(FormatRecursively(node.GetSubtree(i)));
          }
        }
        stringBuilder.Append(")");
      } else if (symbol is Tangent) {
        stringBuilder.Append(WrapFirstSubtree("tan(", node, ")"));
      } else if (symbol is AiryA) {
        stringBuilder.Append(WrapFirstSubtree("airy(", node, ")"));
      } else if (symbol is AiryB) {
        stringBuilder.Append(WrapFirstSubtree("airy(2, ", node, ")"));
      } else if (symbol is Bessel) {
        stringBuilder.Append(WrapFirstSubtree("besseli(0.0,", node, ")"));
      } else if (symbol is CosineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("cosint(", node, ")"));
      } else if (symbol is Dawson) {
        stringBuilder.Append(WrapFirstSubtree("dawson(", node, ")"));
      } else if (symbol is Erf) {
        stringBuilder.Append(WrapFirstSubtree("erf(", node, ")"));
      } else if (symbol is ExponentialIntegralEi) {
        // MATLAB's expint is E1, not Ei. For real x, Ei(x) = -real(E1(-x)).
        stringBuilder.Append(WrapFirstSubtree("-real(expint(-(", node, ")))"));
      } else if (symbol is FresnelCosineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("fresnelc(", node, ")"));
      } else if (symbol is FresnelSineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("fresnels(", node, ")"));
      } else if (symbol is Gamma) {
        stringBuilder.Append(WrapFirstSubtree("gamma(", node, ")"));
      } else if (symbol is HyperbolicCosineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("coshint(", node, ")"));
      } else if (symbol is HyperbolicSineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("sinhint(", node, ")"));
      } else if (symbol is Norm) {
        stringBuilder.Append(WrapFirstSubtree("normpdf(", node, ")"));
      } else if (symbol is Psi) {
        stringBuilder.Append(WrapFirstSubtree("psi(", node, ")"));
      } else if (symbol is SineIntegral) {
        stringBuilder.Append(WrapFirstSubtree("sinint(", node, ")"));
      } else if (symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable) {
        var variableTreeNode = (VariableTreeNode)node;
        stringBuilder.Append(FormatNumber(variableTreeNode.Weight));
        stringBuilder.Append("*");
        stringBuilder.Append(variableTreeNode.VariableName + LagToString(currentLag));
      } else if (symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.FactorVariable) {
        var factorNode = node as FactorVariableTreeNode;
        var weights = string.Join(" ", factorNode.Weights.Select(w => w.ToString("G17", CultureInfo.InvariantCulture)));
        // Pass the variable itself (as the BinaryFactorVariable branch does), not its name as a string
        // literal. No line break here: this text is part of the surrounding expression.
        stringBuilder.AppendFormat("switch_{0}({0},[{1}])", factorNode.VariableName, weights);
      } else if (symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.BinaryFactorVariable) {
        var factorNode = node as BinaryFactorVariableTreeNode;
        stringBuilder.AppendFormat(CultureInfo.InvariantCulture,
          "((strcmp({0},\"{1}\")==1) * {2:G17})", factorNode.VariableName, EscapeMatlabString(factorNode.VariableValue), factorNode.Weight);
      } else if (symbol is Power) {
        stringBuilder.Append("(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
        stringBuilder.Append(")^round(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(1)));
        stringBuilder.Append(")");
      } else if (symbol is Root) {
        stringBuilder.Append("(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
        stringBuilder.Append(")^(1 / round(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(1)));
        stringBuilder.Append("))");
      } else if (symbol is Derivative) {
        // The subtree is evaluated at lag offsets 0, -1, -3 and -4 (f0, f1, f3, f4 of fivePoint).
        // currentLag is restored to its previous value afterwards.
        stringBuilder.Append("fivePoint(");
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));   // f0
        stringBuilder.Append(", ");
        currentLag--;
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));   // f1
        stringBuilder.Append(", ");
        currentLag -= 2;
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));   // f3
        stringBuilder.Append(", ");
        currentLag--;
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));   // f4
        stringBuilder.Append(")");
        currentLag += 4;
      } else if (symbol is Integral) {
        var laggedNode = node as LaggedTreeNode;
        string prevCounterVariable = CurrentIndexVariable;
        string counterVariable = AllocateIndexVariable();
        // MATLAB has no map() for function handles. arrayfun evaluates the anonymous function once per
        // index in (prev+lag):prev, and the results are summed.
        stringBuilder.Append("sum(arrayfun(@(" + counterVariable + ") " + FormatRecursively(node.GetSubtree(0)) +
                             ", (" + prevCounterVariable + "+" + laggedNode.Lag.ToString(CultureInfo.InvariantCulture) + "):" +
                             prevCounterVariable + "))");
        ReleaseIndexVariable();
      } else if (symbol is TimeLag) {
        var laggedNode = node as LaggedTreeNode;
        currentLag += laggedNode.Lag;
        stringBuilder.Append(FormatRecursively(node.GetSubtree(0)));
        currentLag -= laggedNode.Lag;
      } else if (symbol is VariableCondition) {
        // NOTE(review): this emits an if/else statement block (with "then", which is not MATLAB syntax)
        // inside the "retval = ...;" expression, so the output is presumably not valid MATLAB - confirm.
        stringBuilder.AppendLine(FormatRandomForestRecursively(node, 1));
      } else {
        stringBuilder.Append("ERROR");
      }
      return stringBuilder.ToString();
    }

    // Returns prefix + formatted first subtree + suffix. This is the shape of all unary function calls.
    private string WrapFirstSubtree(string prefix, ISymbolicExpressionTreeNode node, string suffix) {
      return prefix + FormatRecursively(node.GetSubtree(0)) + suffix;
    }

    // Formats each subtree as "((subtree)>0)" and joins them with the given MATLAB logical operator.
    private string FormatConditionList(ISymbolicExpressionTreeNode node, string logicalOperator) {
      var stringBuilder = new StringBuilder();
      for (int i = 0; i < node.SubtreeCount; i++) {
        if (i > 0) stringBuilder.Append(logicalOperator);
        stringBuilder.Append("((").Append(FormatRecursively(node.GetSubtree(i))).Append(")>0)");
      }
      return stringBuilder.ToString();
    }

    // MATLAB maps false and true to 0 and 1, resp. We map this result to -1.0 and +1.0, resp.
    private static string ToSignedBoolean(string matlabCondition) {
      return "((" + matlabCondition + ")-0.5)*2";
    }

    // Formats a double so that it round-trips ("r"), independent of the current culture.
    private static string FormatNumber(double value) {
      return value.ToString("r", CultureInfo.InvariantCulture);
    }

    // Escapes text for a double-quoted MATLAB string literal, in which a quote is written as "".
    private static string EscapeMatlabString(string value) {
      return value == null ? string.Empty : value.Replace("\"", "\"\"");
    }

    // Formats the VariableCondition subtree as nested if/else blocks. Constant leaves become
    // "retval = <value>;" lines. Other node types produce no output.
    private string FormatRandomForestRecursively(ISymbolicExpressionTreeNode node, int indent) {
      ISymbol symbol = node.Symbol;
      StringBuilder stringBuilder = new StringBuilder();
      string indentStr = GetIndent(indent);
      if (symbol is VariableCondition) {
        stringBuilder.AppendLine();
        stringBuilder.Append(indentStr);
        stringBuilder.Append("if (");
        stringBuilder.Append(node.ToString());
        stringBuilder.Append(") then");
        stringBuilder.Append(FormatRandomForestRecursively(node.GetSubtree(0), indent + 1));
        stringBuilder.AppendLine();
        stringBuilder.Append(indentStr);
        stringBuilder.Append("else");
        stringBuilder.Append(FormatRandomForestRecursively(node.GetSubtree(1), indent + 1));
        stringBuilder.AppendLine();
        stringBuilder.Append(indentStr);
        stringBuilder.Append("end");
      } else if (symbol is Constant) {
        var constantTreeNode = (ConstantTreeNode)node;
        stringBuilder.AppendLine();
        stringBuilder.Append(indentStr);
        stringBuilder.Append("retval = ");
        stringBuilder.Append(FormatNumber(constantTreeNode.Value));
        stringBuilder.Append(";");
      }
      // Any other node type inside a VariableCondition tree is silently skipped (unchanged behavior).
      return stringBuilder.ToString();
    }

    // Returns one space per indentation level. A negative level gives an empty string.
    private static string GetIndent(int indent) {
      return indent > 0 ? new string(' ', indent) : string.Empty;
    }

    /// <summary>
    /// Returns the index suffix for a lagged variable reference, e.g. "(i0-2)" or "(i0+1)".
    /// </summary>
    /// <param name="lag">Offset relative to the current index variable.</param>
    /// <returns>An empty string for lag 0, otherwise the parenthesized index expression.</returns>
    private string LagToString(int lag) {
      if (lag == 0) return string.Empty;
      // A negative lag already carries its minus sign; a positive lag needs an explicit '+'.
      string sign = lag > 0 ? "+" : string.Empty;
      return "(" + CurrentIndexVariable + sign + lag.ToString(CultureInfo.InvariantCulture) + ")";
    }
  }
}