Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Formatters/SymbolicDataAnalysisExpressionPythonFormatter.cs @ 17929

Last change on this file since 17929 was 17929, checked in by mkommend, 3 years ago

#3105: Added the static keyword to private methods to indicate that they are (most likely) side-effect free.

File size: 12.3 KB
RevLine 
[17855]1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Globalization;
25using System.Linq;
26using System.Text;
27using HEAL.Attic;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Translates a symbolic data analysis expression tree into Python 3 source code: the imports the
  /// expression needs, optional helper functions, and a function <c>evaluate(...)</c> whose parameters
  /// are the tree's variables (as sanitized Python identifiers) and which returns the expression value.
  /// </summary>
  /// <remarks>
  /// Boolean-valued symbols (And, Or, Not, Xor and the IfThenElse condition) are translated using Python
  /// truthiness, i.e. any non-zero value counts as true.
  /// NOTE(review): confirm this matches the tree interpreters' notion of "true".
  /// </remarks>
  [Item("Python String Formatter", "String formatter for string representations of symbolic data analysis expressions in Python syntax.")]
  [StorableType("37C1E1DD-437F-414B-AA96-9C6A0F6FEE46")]
  public sealed class SymbolicDataAnalysisExpressionPythonFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {

    // Name of the helper function emitted into the generated code when the tree contains a CubeRoot node.
    private const string CubeRootHelperName = "cube_root";

    // Identifiers a model variable must not be translated to: Python keywords (a syntax error as a
    // parameter name) and the modules/builtins/helpers the generated expression refers to (a parameter
    // with one of these names would shadow them inside evaluate()).
    private static readonly HashSet<string> ReservedPythonNames = new HashSet<string>(StringComparer.Ordinal) {
      "False", "None", "True", "and", "as", "assert", "async", "await", "break", "class", "continue",
      "def", "del", "elif", "else", "except", "finally", "for", "from", "global", "if", "import", "in",
      "is", "lambda", "nonlocal", "not", "or", "pass", "raise", "return", "try", "while", "with", "yield",
      "math", "statistics", "abs", "bool", "float", CubeRootHelperName
    };

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionPythonFormatter(StorableConstructorFlag _) : base(_) { }
    private SymbolicDataAnalysisExpressionPythonFormatter(SymbolicDataAnalysisExpressionPythonFormatter original, Cloner cloner) : base(original, cloner) { }
    public SymbolicDataAnalysisExpressionPythonFormatter()
      : base() {
      Name = ItemName;
      Description = ItemDescription;
    }

    public override IDeepCloneable Clone(Cloner cloner) => new SymbolicDataAnalysisExpressionPythonFormatter(this, cloner);

    /// <summary>
    /// Formats <paramref name="symbolicExpressionTree"/> as Python source code.
    /// </summary>
    /// <param name="symbolicExpressionTree">The tree to format, starting at its root node.</param>
    /// <returns>
    /// Python source consisting of the required imports/helpers followed by an <c>evaluate(...)</c>
    /// function that returns the formatted expression.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="symbolicExpressionTree"/> is null.</exception>
    /// <exception cref="NotSupportedException">The tree contains a symbol this formatter cannot translate.</exception>
    public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
      if (symbolicExpressionTree == null) throw new ArgumentNullException(nameof(symbolicExpressionTree));

      var header = GenerateHeader(symbolicExpressionTree);
      var strBuilderModel = new StringBuilder();
      FormatRecursively(symbolicExpressionTree.Root, strBuilderModel);
      return $"{header}{strBuilderModel}";
    }

    /// <summary>
    /// Convenience wrapper: formats <paramref name="symbolicExpressionTree"/> with a fresh formatter instance.
    /// </summary>
    public static string FormatTree(ISymbolicExpressionTree symbolicExpressionTree) {
      var formatter = new SymbolicDataAnalysisExpressionPythonFormatter();
      return formatter.Format(symbolicExpressionTree);
    }

    // Builds everything that precedes the expression: imports, helper functions and the
    // "def evaluate(...):" line (ending in "\treturn " so the expression can be appended directly).
    private static string GenerateHeader(ISymbolicExpressionTree symbolicExpressionTree) {
      var variables = new HashSet<string>();
      bool usesMath = false;
      bool usesStatistics = false;
      bool usesCubeRoot = false;

      // iterate the tree and collect the necessary imports, helpers and variable names
      foreach (var node in symbolicExpressionTree.IterateNodesPostfix()) {
        ISymbol symbol = node.Symbol;
        usesMath |= RequiresMathModule(symbol);
        usesStatistics |= symbol is Average;
        usesCubeRoot |= symbol is CubeRoot;

        var varNode = node as VariableTreeNode;
        if (varNode != null)
          variables.Add(VariableName2Identifier(varNode.VariableName));
      }

      var strBuilder = new StringBuilder();
      strBuilder.Append(GenerateNecessaryImports(usesMath, usesStatistics));
      strBuilder.Append(GenerateHelperFunctions(usesCubeRoot));
      strBuilder.Append(GenerateModelEvaluationFunction(variables));
      return strBuilder.ToString();
    }

    // True for every symbol whose translation in FormatRecursively (directly or via a helper) refers
    // to the Python math module. Must be kept in sync with FormatRecursively / GenerateHelperFunctions.
    private static bool RequiresMathModule(ISymbol symbol) {
      return symbol is Sine || symbol is Cosine || symbol is Tangent || symbol is HyperbolicTangent
        || symbol is Exponential || symbol is Logarithm || symbol is SquareRoot
        || symbol is Power || symbol is Root          // math.pow
        || symbol is Square || symbol is Cube         // math.pow via FormatPower
        || symbol is CubeRoot                         // cube_root helper uses math.copysign
        || symbol is AnalyticQuotient;                // math.sqrt / math.pow
    }

    // Emits the import section, or an empty string when no module is needed.
    private static string GenerateNecessaryImports(bool usesMath, bool usesStatistics) {
      if (!usesMath && !usesStatistics) return string.Empty;

      var strBuilder = new StringBuilder();
      strBuilder.AppendLine("# imports");
      if (usesMath)
        strBuilder.AppendLine("import math");
      if (usesStatistics)
        strBuilder.AppendLine("import statistics");
      strBuilder.AppendLine();
      return strBuilder.ToString();
    }

    // Emits helper functions required by the expression, or an empty string when none are needed.
    private static string GenerateHelperFunctions(bool usesCubeRoot) {
      if (!usesCubeRoot) return string.Empty;

      var strBuilder = new StringBuilder();
      // In Python a negative base raised to a fractional power yields a complex number, so the real
      // cube root is computed on the magnitude and the sign is restored afterwards.
      // NOTE(review): confirm this matches the interpreters' CubeRoot semantics for negative arguments.
      strBuilder.AppendLine("# cube root helper function");
      strBuilder.AppendLine($"def {CubeRootHelperName}(x):");
      strBuilder.AppendLine("\treturn math.copysign(abs(x) ** (1. / 3), x)");
      strBuilder.AppendLine();
      return strBuilder.ToString();
    }

    // Emits "def evaluate(<params>):" followed by "\treturn " (no newline); parameters are sorted
    // with NaturalStringComparer and joined once instead of re-sorting per parameter.
    private static string GenerateModelEvaluationFunction(IEnumerable<string> variables) {
      var parameters = string.Join(", ", variables.OrderBy(n => n, new NaturalStringComparer()));
      var strBuilder = new StringBuilder();
      strBuilder.Append("def evaluate(");
      strBuilder.Append(parameters);
      strBuilder.AppendLine("):");
      strBuilder.Append("\treturn ");
      return strBuilder.ToString();
    }

    /// <summary>
    /// Appends the Python translation of <paramref name="node"/> and its subtrees to <paramref name="strBuilder"/>.
    /// Every compound construct is emitted parenthesized (or as a call) so it can be embedded in any context.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's symbol has no Python translation.</exception>
    private static void FormatRecursively(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      ISymbol symbol = node.Symbol;
      if (symbol is ProgramRootSymbol)
        FormatRecursively(node.GetSubtree(0), strBuilder);
      else if (symbol is StartSymbol)
        FormatRecursively(node.GetSubtree(0), strBuilder);
      else if (symbol is Absolute)
        FormatNode(node, strBuilder, "abs");
      else if (symbol is Addition)
        FormatNode(node, strBuilder, infixSymbol: " + ");
      else if (symbol is Subtraction)
        FormatSubtraction(node, strBuilder);
      else if (symbol is Multiplication)
        FormatNode(node, strBuilder, infixSymbol: " * ");
      else if (symbol is Division)
        FormatDivision(node, strBuilder);
      else if (symbol is Average)
        FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
      else if (symbol is Sine)
        FormatNode(node, strBuilder, "math.sin");
      else if (symbol is Cosine)
        FormatNode(node, strBuilder, "math.cos");
      else if (symbol is Tangent)
        FormatNode(node, strBuilder, "math.tan");
      else if (symbol is HyperbolicTangent)
        FormatNode(node, strBuilder, "math.tanh");
      else if (symbol is Exponential)
        FormatNode(node, strBuilder, "math.exp");
      else if (symbol is Logarithm)
        FormatNode(node, strBuilder, "math.log");
      else if (symbol is Power)
        FormatNode(node, strBuilder, "math.pow");
      else if (symbol is Root)
        FormatRoot(node, strBuilder);
      else if (symbol is Square)
        FormatPower(node, strBuilder, "2");
      else if (symbol is SquareRoot)
        FormatNode(node, strBuilder, "math.sqrt");
      else if (symbol is Cube)
        FormatPower(node, strBuilder, "3");
      else if (symbol is CubeRoot)
        // call form: avoids "name * weight ** (1/3)" precedence errors and complex results
        FormatNode(node, strBuilder, CubeRootHelperName);
      else if (symbol is AnalyticQuotient)
        FormatAnalyticQuotient(node, strBuilder);
      else if (symbol is And)
        FormatNode(node, strBuilder, infixSymbol: " and ");
      else if (symbol is Or)
        FormatNode(node, strBuilder, infixSymbol: " or ");
      else if (symbol is Xor)
        FormatXor(node, strBuilder);
      else if (symbol is Not)
        // "not" is an operator binding looser than arithmetic/comparisons, so the whole
        // negation must be parenthesized ("not(a) + b" would mean "not (a + b)")
        FormatNode(node, strBuilder, openingSymbol: "(not ", closingSymbol: ")");
      else if (symbol is IfThenElse)
        FormatIfThenElse(node, strBuilder);
      else if (symbol is GreaterThan)
        FormatNode(node, strBuilder, infixSymbol: " > ");
      else if (symbol is LessThan)
        FormatNode(node, strBuilder, infixSymbol: " < ");
      else if (node is VariableTreeNode)
        FormatVariableTreeNode(node, strBuilder);
      else if (node is ConstantTreeNode)
        FormatConstantTreeNode(node, strBuilder);
      else
        throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
    }

    // Maps a dataset variable name to a valid, non-reserved Python identifier: characters other than
    // letters, digits and '_' become '_', a leading digit (or an empty name) gets a '_' prefix, and
    // reserved names get a '_' suffix.
    // NOTE(review): distinct variable names can map to the same identifier (e.g. "a b" and "a_b");
    // they then share a single evaluate() parameter.
    private static string VariableName2Identifier(string variableName) {
      var identifier = new StringBuilder(variableName.Length + 1);
      foreach (char c in variableName)
        identifier.Append((char.IsLetterOrDigit(c) || c == '_') ? c : '_');

      if (identifier.Length == 0 || char.IsDigit(identifier[0]))
        identifier.Insert(0, '_');

      var result = identifier.ToString();
      return ReservedPythonNames.Contains(result) ? result + "_" : result;
    }

    // Formats a double as a Python expression; non-finite values have no literal form in Python
    // ("NaN"/"Infinity" would be undefined names), so they are emitted as float(...) calls.
    private static string FormatDouble(double value) {
      if (double.IsNaN(value)) return "float('nan')";
      if (double.IsPositiveInfinity(value)) return "float('inf')";
      if (double.IsNegativeInfinity(value)) return "float('-inf')";
      return value.ToString("g17", CultureInfo.InvariantCulture);
    }

    // Generic formatting: prefix + opening + children separated by infixSymbol + closing.
    private static void FormatNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string prefixSymbol = "", string openingSymbol = "(", string closingSymbol = ")", string infixSymbol = ",") {
      strBuilder.Append(prefixSymbol).Append(openingSymbol);
      bool first = true;
      foreach (var child in node.Subtrees) {
        if (!first) strBuilder.Append(infixSymbol);
        FormatRecursively(child, strBuilder);
        first = false;
      }
      strBuilder.Append(closingSymbol);
    }

    // A weighted variable is emitted as "identifier * weight".
    private static void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      var varNode = node as VariableTreeNode;
      var formattedVariable = VariableName2Identifier(varNode.VariableName);
      var variableWeight = FormatDouble(varNode.Weight);
      strBuilder.Append($"{formattedVariable} * {variableWeight}");
    }

    private static void FormatConstantTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      var constNode = node as ConstantTreeNode;
      strBuilder.Append(FormatDouble(constNode.Value));
    }

    // math.pow(child, exponent) for a fixed exponent (used by Square and Cube).
    private static void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
      strBuilder.Append("math.pow(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append($", {exponent})");
    }

    // math.pow(base, 1.0 / (n)) for the n-th root.
    private static void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("math.pow(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(", 1.0 / (");
      FormatRecursively(node.GetSubtree(1), strBuilder);
      strBuilder.Append("))");
    }

    // One child: unary negation; more children: left-to-right subtraction.
    private static void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      if (node.SubtreeCount == 1) {
        strBuilder.Append("-");
        FormatRecursively(node.GetSubtree(0), strBuilder);
        return;
      }
      //Default case: more than 1 child
      FormatNode(node, strBuilder, infixSymbol: " - ");
    }

    // One child: reciprocal; more children: first child divided by the product of the others.
    private static void FormatDivision(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      if (node.SubtreeCount == 1) {
        // the divisor must be parenthesized: a variable renders as "name * weight", and
        // "1.0 / name * weight" would parse as "(1.0 / name) * weight"
        strBuilder.Append("1.0 / (");
        FormatRecursively(node.GetSubtree(0), strBuilder);
        strBuilder.Append(")");
      } else {
        FormatRecursively(node.GetSubtree(0), strBuilder);
        strBuilder.Append(" / (");
        for (int i = 1; i < node.SubtreeCount; i++) {
          if (i > 1) strBuilder.Append(" * ");
          FormatRecursively(node.GetSubtree(i), strBuilder);
        }
        strBuilder.Append(")");
      }
      strBuilder.Append(")");
    }

    // (a / sqrt(1 + b^2))
    private static void FormatAnalyticQuotient(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(" / math.sqrt(1 + math.pow(");
      FormatRecursively(node.GetSubtree(1), strBuilder);
      strBuilder.Append(" , 2) ) )");
    }

    // Python's "^" is undefined for floats (TypeError), so each operand is converted to bool first.
    private static void FormatXor(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      for (int i = 0; i < node.SubtreeCount; i++) {
        if (i > 0) strBuilder.Append(" ^ ");
        strBuilder.Append("bool(");
        FormatRecursively(node.GetSubtree(i), strBuilder);
        strBuilder.Append(")");
      }
      strBuilder.Append(")");
    }

    // Subtrees are (condition, then, else). A Python conditional expression only evaluates the taken
    // branch, so e.g. a math.log/sqrt/division guarded by the condition cannot raise in the other branch.
    private static void FormatIfThenElse(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      FormatRecursively(node.GetSubtree(1), strBuilder);
      strBuilder.Append(" if ");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(" else ");
      FormatRecursively(node.GetSubtree(2), strBuilder);
      strBuilder.Append(")");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.