#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using HEAL.Attic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Formats a symbolic expression tree as Python source code.
/// The generated text consists of an optional import section (<c>math</c>, <c>statistics</c>),
/// an optional <c>evaluate_if</c> helper function, and a function <c>evaluate</c> whose parameters
/// are the identifiers of the variables used in the tree and whose body returns the expression.
/// </summary>
[Item("Python String Formatter", "String formatter for string representations of symbolic data analysis expressions in Python syntax.")]
[StorableType("37C1E1DD-437F-414B-AA96-9C6A0F6FEE46")]
public sealed class SymbolicDataAnalysisExpressionPythonFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {
  [StorableConstructor]
  private SymbolicDataAnalysisExpressionPythonFormatter(StorableConstructorFlag _) : base(_) { }

  private SymbolicDataAnalysisExpressionPythonFormatter(SymbolicDataAnalysisExpressionPythonFormatter original, Cloner cloner) : base(original, cloner) { }

  /// <summary>
  /// Creates a new formatter whose name and description are taken from the item attributes.
  /// </summary>
  public SymbolicDataAnalysisExpressionPythonFormatter()
    : base() {
    Name = ItemName;
    Description = ItemDescription;
  }

  public override IDeepCloneable Clone(Cloner cloner) => new SymbolicDataAnalysisExpressionPythonFormatter(this, cloner);

  /// <summary>
  /// Formats the given tree as Python source: header (imports, helper, function signature)
  /// followed by the expression that forms the <c>return</c> statement.
  /// </summary>
  /// <param name="symbolicExpressionTree">The tree to format.</param>
  /// <returns>The generated Python source code.</returns>
  /// <exception cref="NotSupportedException">The tree contains a symbol this formatter cannot translate.</exception>
  public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
    // The header ends with "\treturn ", so the expression is appended directly behind it.
    var pythonSource = new StringBuilder(GenerateHeader(symbolicExpressionTree));
    FormatRecursively(symbolicExpressionTree.Root, pythonSource);
    return pythonSource.ToString();
  }

  /// <summary>
  /// Convenience wrapper that formats the tree with a freshly created formatter instance.
  /// </summary>
  /// <param name="symbolicExpressionTree">The tree to format.</param>
  /// <returns>The generated Python source code.</returns>
  public static string FormatTree(ISymbolicExpressionTree symbolicExpressionTree) {
    var formatter = new SymbolicDataAnalysisExpressionPythonFormatter();
    return formatter.Format(symbolicExpressionTree);
  }

  /// <summary>
  /// Scans the tree once to find out which imports, helpers and parameters the generated
  /// function needs, and returns everything that precedes the expression itself.
  /// </summary>
  private static string GenerateHeader(ISymbolicExpressionTree symbolicExpressionTree) {
    var strBuilder = new StringBuilder();
    ISet<string> variables = new HashSet<string>();
    int mathLibCounter = 0;
    int statisticLibCounter = 0;
    int evaluateIfCounter = 0;

    // iterate tree and search for necessary imports and variable names
    foreach (var node in symbolicExpressionTree.IterateNodesPostfix()) {
      ISymbol symbol = node.Symbol;
      if (symbol is Average) {
        statisticLibCounter++;
      } else if (symbol is IfThenElse) {
        evaluateIfCounter++;
      } else if (RequiresMathImport(symbol)) {
        mathLibCounter++;
      } else if (node is VariableTreeNode) {
        var varNode = (VariableTreeNode)node;
        variables.Add(VariableName2Identifier(varNode.VariableName));
      }
    }

    // generate import section (if necessary)
    strBuilder.Append(GenerateNecessaryImports(mathLibCounter, statisticLibCounter));
    // generate if-then-else helper construct (if necessary)
    strBuilder.Append(GenerateIfThenElseSource(evaluateIfCounter));
    // generate model evaluation function
    strBuilder.Append(GenerateModelEvaluationFunction(variables));
    return strBuilder.ToString();
  }

  /// <summary>
  /// Returns true for every symbol that <see cref="FormatRecursively"/> translates into a call of
  /// the Python <c>math</c> module. Root, Square and Cube are included because they are emitted
  /// as <c>math.pow(...)</c>; without the import the generated code fails with a NameError.
  /// </summary>
  private static bool RequiresMathImport(ISymbol symbol) {
    return symbol is Cosine
        || symbol is Exponential
        || symbol is Logarithm
        || symbol is Sine
        || symbol is Tangent
        || symbol is HyperbolicTangent
        || symbol is SquareRoot
        || symbol is Power
        || symbol is Root
        || symbol is Square
        || symbol is Cube
        || symbol is AnalyticQuotient;
  }

  /// <summary>
  /// Generates the import section. Returns an empty string when neither library is used.
  /// </summary>
  private static string GenerateNecessaryImports(int mathLibCounter, int statisticLibCounter) {
    var strBuilder = new StringBuilder();
    if (mathLibCounter > 0 || statisticLibCounter > 0) {
      strBuilder.AppendLine("# imports");
      if (mathLibCounter > 0) {
        strBuilder.AppendLine("import math");
      }
      if (statisticLibCounter > 0) {
        strBuilder.AppendLine("import statistics");
      }
      strBuilder.AppendLine();
    }
    return strBuilder.ToString();
  }

  /// <summary>
  /// Generates the <c>evaluate_if</c> helper used for IfThenElse symbols.
  /// Returns an empty string when the tree contains no such symbol.
  /// </summary>
  private static string GenerateIfThenElseSource(int evaluateIfCounter) {
    var strBuilder = new StringBuilder();
    if (evaluateIfCounter > 0) {
      strBuilder.AppendLine("# condition helper function");
      strBuilder.AppendLine("def evaluate_if(condition, then_path, else_path): ");
      strBuilder.AppendLine("\tif condition:");
      strBuilder.AppendLine("\t\treturn then_path");
      strBuilder.AppendLine("\telse:");
      strBuilder.AppendLine("\t\treturn else_path");
    }
    return strBuilder.ToString();
  }

  /// <summary>
  /// Generates the signature of the <c>evaluate</c> function and the start of its return statement.
  /// The parameters are the given identifiers, ordered with a <see cref="NaturalStringComparer"/>.
  /// </summary>
  private static string GenerateModelEvaluationFunction(ISet<string> variables) {
    var orderedVariables = variables.OrderBy(n => n, new NaturalStringComparer());
    var strBuilder = new StringBuilder();
    strBuilder.Append("def evaluate(");
    // Join enumerates (and therefore sorts) the sequence exactly once.
    strBuilder.Append(string.Join(", ", orderedVariables));
    strBuilder.AppendLine("):");
    strBuilder.Append("\treturn ");
    return strBuilder.ToString();
  }

  /// <summary>
  /// Appends the Python expression for <paramref name="node"/> and all of its subtrees.
  /// </summary>
  /// <exception cref="NotSupportedException">The node's symbol has no Python translation.</exception>
  private static void FormatRecursively(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    ISymbol symbol = node.Symbol;
    if (symbol is ProgramRootSymbol)
      FormatRecursively(node.GetSubtree(0), strBuilder);
    else if (symbol is StartSymbol)
      FormatRecursively(node.GetSubtree(0), strBuilder);
    else if (symbol is Absolute)
      FormatNode(node, strBuilder, "abs");
    else if (symbol is Addition)
      FormatNode(node, strBuilder, infixSymbol: " + ");
    else if (symbol is Subtraction)
      FormatSubtraction(node, strBuilder);
    else if (symbol is Multiplication)
      FormatNode(node, strBuilder, infixSymbol: " * ");
    else if (symbol is Division)
      FormatDivision(node, strBuilder);
    else if (symbol is Average)
      FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
    else if (symbol is Sine)
      FormatNode(node, strBuilder, "math.sin");
    else if (symbol is Cosine)
      FormatNode(node, strBuilder, "math.cos");
    else if (symbol is Tangent)
      FormatNode(node, strBuilder, "math.tan");
    else if (symbol is HyperbolicTangent)
      FormatNode(node, strBuilder, "math.tanh");
    else if (symbol is Exponential)
      FormatNode(node, strBuilder, "math.exp");
    else if (symbol is Logarithm)
      FormatNode(node, strBuilder, "math.log");
    else if (symbol is Power)
      FormatNode(node, strBuilder, "math.pow");
    else if (symbol is Root)
      FormatRoot(node, strBuilder);
    else if (symbol is Square)
      FormatPower(node, strBuilder, "2");
    else if (symbol is SquareRoot)
      FormatNode(node, strBuilder, "math.sqrt");
    else if (symbol is Cube)
      FormatPower(node, strBuilder, "3");
    else if (symbol is CubeRoot)
      // The operand gets its own parentheses: "**" binds tighter than "*" and unary "-",
      // so "x * w ** (1. / 3)" would raise only the weight to the power.
      // NOTE(review): Python returns a complex number for a negative base with a fractional
      // exponent - confirm whether a sign-preserving cube root is expected here.
      FormatNode(node, strBuilder, openingSymbol: "((", closingSymbol: ") ** (1. / 3))");
    else if (symbol is AnalyticQuotient)
      FormatAnalyticQuotient(node, strBuilder);
    else if (symbol is And)
      FormatNode(node, strBuilder, infixSymbol: " and ");
    else if (symbol is Or)
      FormatNode(node, strBuilder, infixSymbol: " or ");
    else if (symbol is Xor)
      FormatNode(node, strBuilder, infixSymbol: " ^ ");
    else if (symbol is Not)
      // "not" has lower precedence than comparison and arithmetic operators, so the whole
      // negation is wrapped: "not(a) > b" would otherwise be parsed as "not ((a) > b)".
      FormatNode(node, strBuilder, prefixSymbol: "(not ", closingSymbol: "))");
    else if (symbol is IfThenElse)
      FormatNode(node, strBuilder, "evaluate_if");
    else if (symbol is GreaterThan)
      FormatNode(node, strBuilder, infixSymbol: " > ");
    else if (symbol is LessThan)
      FormatNode(node, strBuilder, infixSymbol: " < ");
    else if (node is VariableTreeNode)
      FormatVariableTreeNode(node, strBuilder);
    else if (node is ConstantTreeNode)
      FormatConstantTreeNode(node, strBuilder);
    else
      throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
  }

  /// <summary>
  /// Turns a variable name into the identifier used in the generated code by replacing spaces with underscores.
  /// </summary>
  private static string VariableName2Identifier(string variableName) => variableName.Replace(" ", "_");

  /// <summary>
  /// Appends <paramref name="prefixSymbol"/> and <paramref name="openingSymbol"/>, then all subtrees of
  /// <paramref name="node"/> separated by <paramref name="infixSymbol"/>, then <paramref name="closingSymbol"/>.
  /// With the defaults this yields a function-call style argument list: <c>prefix(a,b,c)</c>.
  /// </summary>
  private static void FormatNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string prefixSymbol = "", string openingSymbol = "(", string closingSymbol = ")", string infixSymbol = ",") {
    strBuilder.Append(prefixSymbol).Append(openingSymbol);
    // Emit the separator in front of every subtree but the first; this needs a single
    // pass over the subtrees and does not depend on comparing node references.
    bool isFirst = true;
    foreach (var child in node.Subtrees) {
      if (!isFirst) {
        strBuilder.Append(infixSymbol);
      }
      FormatRecursively(child, strBuilder);
      isFirst = false;
    }
    strBuilder.Append(closingSymbol);
  }

  /// <summary>
  /// Appends a variable node as <c>identifier * weight</c> (weight in round-trip "g17" format, invariant culture).
  /// </summary>
  private static void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    var varNode = (VariableTreeNode)node;
    var formattedVariable = VariableName2Identifier(varNode.VariableName);
    var variableWeight = varNode.Weight.ToString("g17", CultureInfo.InvariantCulture);
    strBuilder.Append($"{formattedVariable} * {variableWeight}");
  }

  /// <summary>
  /// Appends the value of a constant node in "g17" format using the invariant culture.
  /// </summary>
  private static void FormatConstantTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    var constNode = (ConstantTreeNode)node;
    strBuilder.Append(constNode.Value.ToString("g17", CultureInfo.InvariantCulture));
  }

  /// <summary>
  /// Appends <c>math.pow(operand, exponent)</c> for the first subtree of <paramref name="node"/>.
  /// </summary>
  private static void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
    strBuilder.Append("math.pow(");
    FormatRecursively(node.GetSubtree(0), strBuilder);
    strBuilder.Append($", {exponent})");
  }

  /// <summary>
  /// Appends <c>math.pow(a, 1.0 / (b))</c> where a and b are the first and second subtree.
  /// </summary>
  private static void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    strBuilder.Append("math.pow(");
    FormatRecursively(node.GetSubtree(0), strBuilder);
    strBuilder.Append(", 1.0 / (");
    FormatRecursively(node.GetSubtree(1), strBuilder);
    strBuilder.Append("))");
  }

  /// <summary>
  /// Appends a subtraction. A single operand is emitted as a negation (<c>-a</c>);
  /// otherwise the operands are joined with <c> - </c> inside parentheses.
  /// </summary>
  private static void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    if (node.SubtreeCount == 1) {
      strBuilder.Append("-");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      return;
    }
    // any other number of operands
    FormatNode(node, strBuilder, infixSymbol: " - ");
  }

  /// <summary>
  /// Appends a division. A single operand is emitted as its reciprocal; with several operands
  /// the first one is divided by the product of the remaining ones.
  /// </summary>
  private static void FormatDivision(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    strBuilder.Append("(");
    if (node.SubtreeCount == 1) {
      // The operand must be parenthesized: a variable is emitted as "x * w", and
      // "1.0 / x * w" would be evaluated as "(1.0 / x) * w" in Python.
      strBuilder.Append("1.0 / (");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(")");
    } else {
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(" / (");
      for (int i = 1; i < node.SubtreeCount; i++) {
        if (i > 1) {
          strBuilder.Append(" * ");
        }
        FormatRecursively(node.GetSubtree(i), strBuilder);
      }
      strBuilder.Append(")");
    }
    strBuilder.Append(")");
  }

  /// <summary>
  /// Appends the analytic quotient <c>a / sqrt(1 + b^2)</c> of the first two subtrees.
  /// </summary>
  private static void FormatAnalyticQuotient(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    strBuilder.Append("(");
    FormatRecursively(node.GetSubtree(0), strBuilder);
    strBuilder.Append(" / math.sqrt(1 + math.pow(");
    FormatRecursively(node.GetSubtree(1), strBuilder);
    strBuilder.Append(" , 2) ) )");
  }
}
}