Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Formatters/SymbolicDataAnalysisExpressionPythonFormatter.cs @ 18095

Last change on this file since 18095 was 18016, checked in by gkronber, 3 years ago

#3105: merged r17922, r17929 from trunk to stable

File size: 12.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Globalization;
25using System.Linq;
26using System.Text;
27using HEAL.Attic;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Formats a symbolic expression tree as a self-contained Python snippet consisting of
  /// the required imports, an optional <c>evaluate_if</c> helper, and a function
  /// <c>evaluate(...)</c> whose parameters are the tree's variables (as Python identifiers)
  /// and whose body returns the formatted expression.
  /// </summary>
  [Item("Python String Formatter", "String formatter for string representations of symbolic data analysis expressions in Python syntax.")]
  [StorableType("37C1E1DD-437F-414B-AA96-9C6A0F6FEE46")]
  public sealed class SymbolicDataAnalysisExpressionPythonFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionPythonFormatter(StorableConstructorFlag _) : base(_) { }
    private SymbolicDataAnalysisExpressionPythonFormatter(SymbolicDataAnalysisExpressionPythonFormatter original, Cloner cloner) : base(original, cloner) { }
    public SymbolicDataAnalysisExpressionPythonFormatter()
      : base() {
      Name = ItemName;
      Description = ItemDescription;
    }

    public override IDeepCloneable Clone(Cloner cloner) => new SymbolicDataAnalysisExpressionPythonFormatter(this, cloner);

    /// <summary>
    /// Formats <paramref name="symbolicExpressionTree"/> as Python source code.
    /// </summary>
    /// <param name="symbolicExpressionTree">The tree to format.</param>
    /// <returns>The header (imports, helpers, <c>def evaluate(...):</c>) followed by the returned expression.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="symbolicExpressionTree"/> is null.</exception>
    /// <exception cref="NotSupportedException">
    /// The tree contains a symbol this formatter cannot translate, or two variable names map to the same Python identifier.
    /// </exception>
    public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
      if (symbolicExpressionTree == null) throw new ArgumentNullException(nameof(symbolicExpressionTree));
      var header = GenerateHeader(symbolicExpressionTree);
      var strBuilderModel = new StringBuilder();
      FormatRecursively(symbolicExpressionTree.Root, strBuilderModel);
      return $"{header}{strBuilderModel}";
    }

    /// <summary>Convenience wrapper: formats <paramref name="symbolicExpressionTree"/> with a fresh formatter instance.</summary>
    public static string FormatTree(ISymbolicExpressionTree symbolicExpressionTree) {
      var formatter = new SymbolicDataAnalysisExpressionPythonFormatter();
      return formatter.Format(symbolicExpressionTree);
    }

    /// <summary>
    /// Scans the tree once to decide which imports/helpers are needed and to collect the
    /// variable identifiers, then emits the import section, the optional <c>evaluate_if</c>
    /// helper and the opening of the <c>evaluate</c> function (up to and including <c>return </c>).
    /// </summary>
    private static string GenerateHeader(ISymbolicExpressionTree symbolicExpressionTree) {
      // Python identifier -> original variable name; used to detect sanitization collisions.
      var identifiers = new Dictionary<string, string>();
      bool needsMath = false;
      bool needsStatistics = false;
      bool needsIfThenElse = false;

      foreach (var node in symbolicExpressionTree.IterateNodesPostfix()) {
        ISymbol symbol = node.Symbol;
        if (symbol is Average) needsStatistics = true;
        else if (symbol is IfThenElse) needsIfThenElse = true;
        else if (UsesMathModule(symbol)) needsMath = true;
        else if (node is VariableTreeNode) {
          var variableName = ((VariableTreeNode)node).VariableName;
          var identifier = VariableName2Identifier(variableName);
          string existing;
          if (identifiers.TryGetValue(identifier, out existing)) {
            // Two distinct variables must not silently become the same Python parameter.
            if (existing != variableName)
              throw new NotSupportedException($"Variables '{existing}' and '{variableName}' both map to the Python identifier '{identifier}'.");
          } else {
            identifiers.Add(identifier, variableName);
          }
        }
      }

      var strBuilder = new StringBuilder();
      strBuilder.Append(GenerateNecessaryImports(needsMath, needsStatistics));
      strBuilder.Append(GenerateIfThenElseSource(needsIfThenElse));
      strBuilder.Append(GenerateModelEvaluationFunction(identifiers.Keys));
      return strBuilder.ToString();
    }

    /// <summary>
    /// True for every symbol whose formatted output references the Python <c>math</c> module
    /// (must stay in sync with <see cref="FormatRecursively"/> and its helpers).
    /// </summary>
    private static bool UsesMathModule(ISymbol symbol) =>
      symbol is Sine || symbol is Cosine || symbol is Tangent || symbol is HyperbolicTangent
      || symbol is Exponential || symbol is Logarithm
      || symbol is Power || symbol is Root
      || symbol is Square || symbol is SquareRoot
      || symbol is Cube || symbol is CubeRoot
      || symbol is AnalyticQuotient;

    /// <summary>Emits the <c># imports</c> section, or an empty string when no module is needed.</summary>
    private static string GenerateNecessaryImports(bool needsMath, bool needsStatistics) {
      if (!needsMath && !needsStatistics) return string.Empty;
      var strBuilder = new StringBuilder();
      strBuilder.AppendLine("# imports");
      if (needsMath) strBuilder.AppendLine("import math");
      if (needsStatistics) strBuilder.AppendLine("import statistics");
      strBuilder.AppendLine();
      return strBuilder.ToString();
    }

    /// <summary>
    /// Emits the <c>evaluate_if</c> helper used for IfThenElse nodes, or an empty string when unused.
    /// Both branches are passed as arguments, so Python evaluates them before the call.
    /// NOTE(review): the helper uses Python truthiness (any non-zero value is true) — confirm this
    /// matches the condition semantics of HeuristicLab's tree interpreters.
    /// </summary>
    private static string GenerateIfThenElseSource(bool needsIfThenElse) {
      if (!needsIfThenElse) return string.Empty;
      var strBuilder = new StringBuilder();
      strBuilder.AppendLine("# condition helper function");
      strBuilder.AppendLine("def evaluate_if(condition, then_path, else_path):");
      strBuilder.AppendLine("\tif condition:");
      strBuilder.AppendLine("\t\treturn then_path");
      strBuilder.AppendLine("\telse:");
      strBuilder.AppendLine("\t\treturn else_path");
      // Separate the helper from the following evaluate() definition.
      strBuilder.AppendLine();
      return strBuilder.ToString();
    }

    /// <summary>
    /// Emits <c>def evaluate(a, b, ...):</c> with the identifiers in natural sort order,
    /// followed by the indented <c>return </c> that the formatted expression completes.
    /// </summary>
    private static string GenerateModelEvaluationFunction(IEnumerable<string> variables) {
      // Sort once and join, instead of re-enumerating a deferred OrderBy per element.
      var parameters = string.Join(", ", variables.OrderBy(n => n, new NaturalStringComparer()));
      var strBuilder = new StringBuilder();
      strBuilder.AppendLine($"def evaluate({parameters}):");
      strBuilder.Append("\treturn ");
      return strBuilder.ToString();
    }

    /// <summary>
    /// Appends the Python expression for <paramref name="node"/> and its subtrees.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's symbol has no Python translation.</exception>
    private static void FormatRecursively(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      ISymbol symbol = node.Symbol;
      if (symbol is ProgramRootSymbol || symbol is StartSymbol)
        FormatRecursively(node.GetSubtree(0), strBuilder);
      else if (symbol is Absolute)
        FormatNode(node, strBuilder, "abs");
      else if (symbol is Addition)
        FormatNode(node, strBuilder, infixSymbol: " + ");
      else if (symbol is Subtraction)
        FormatSubtraction(node, strBuilder);
      else if (symbol is Multiplication)
        FormatNode(node, strBuilder, infixSymbol: " * ");
      else if (symbol is Division)
        FormatDivision(node, strBuilder);
      else if (symbol is Average)
        FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
      else if (symbol is Sine)
        FormatNode(node, strBuilder, "math.sin");
      else if (symbol is Cosine)
        FormatNode(node, strBuilder, "math.cos");
      else if (symbol is Tangent)
        FormatNode(node, strBuilder, "math.tan");
      else if (symbol is HyperbolicTangent)
        FormatNode(node, strBuilder, "math.tanh");
      else if (symbol is Exponential)
        FormatNode(node, strBuilder, "math.exp");
      else if (symbol is Logarithm)
        FormatNode(node, strBuilder, "math.log");
      else if (symbol is Power)
        FormatNode(node, strBuilder, "math.pow");
      else if (symbol is Root)
        FormatRoot(node, strBuilder);
      else if (symbol is Square)
        FormatPower(node, strBuilder, "2");
      else if (symbol is SquareRoot)
        FormatNode(node, strBuilder, "math.sqrt");
      else if (symbol is Cube)
        FormatPower(node, strBuilder, "3");
      else if (symbol is CubeRoot)
        FormatCubeRoot(node, strBuilder);
      else if (symbol is AnalyticQuotient)
        FormatAnalyticQuotient(node, strBuilder);
      else if (symbol is And)
        FormatNode(node, strBuilder, infixSymbol: " and ");
      else if (symbol is Or)
        FormatNode(node, strBuilder, infixSymbol: " or ");
      else if (symbol is Xor)
        FormatNode(node, strBuilder, infixSymbol: " ^ ");
      else if (symbol is Not)
        // 'not' has the lowest precedence in Python; parenthesize so "-not x" or "a * not x"
        // cannot occur (SyntaxError) and "not x * a" cannot be misparsed as "not (x * a)".
        FormatNode(node, strBuilder, openingSymbol: "(not (", closingSymbol: "))");
      else if (symbol is IfThenElse)
        FormatNode(node, strBuilder, "evaluate_if");
      else if (symbol is GreaterThan)
        FormatNode(node, strBuilder, infixSymbol: " > ");
      else if (symbol is LessThan)
        FormatNode(node, strBuilder, infixSymbol: " < ");
      else if (node is VariableTreeNode)
        FormatVariableTreeNode(node, strBuilder);
      else if (node is ConstantTreeNode)
        FormatConstantTreeNode(node, strBuilder);
      else
        throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
    }

    /// <summary>
    /// Maps a variable name to a Python identifier: every character that is not a letter,
    /// digit or underscore becomes '_', and a leading digit (or an empty name) gets a '_' prefix.
    /// Distinct names can map to the same identifier (e.g. "a b" and "a_b"); callers must check.
    /// NOTE(review): Python keywords (e.g. "class", "in") are not escaped.
    /// </summary>
    private static string VariableName2Identifier(string variableName) {
      var identifier = new StringBuilder(variableName.Length + 1);
      foreach (var c in variableName)
        identifier.Append(char.IsLetterOrDigit(c) || c == '_' ? c : '_');
      if (identifier.Length == 0 || char.IsDigit(identifier[0]))
        identifier.Insert(0, '_');
      return identifier.ToString();
    }

    /// <summary>
    /// Appends <c>prefix + opening</c>, the subtrees separated by <paramref name="infixSymbol"/>, then <c>closing</c>.
    /// </summary>
    private static void FormatNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string prefixSymbol = "", string openingSymbol = "(", string closingSymbol = ")", string infixSymbol = ",") {
      strBuilder.Append(prefixSymbol).Append(openingSymbol);
      bool first = true;
      foreach (var child in node.Subtrees) {
        if (!first) strBuilder.Append(infixSymbol);
        FormatRecursively(child, strBuilder);
        first = false;
      }
      strBuilder.Append(closingSymbol);
    }

    /// <summary>
    /// Appends <c>identifier * weight</c>. The result is NOT parenthesized, so every caller that
    /// places it next to a higher-precedence operator must add parentheses itself.
    /// </summary>
    private static void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      var varNode = (VariableTreeNode)node;
      var formattedVariable = VariableName2Identifier(varNode.VariableName);
      strBuilder.Append($"{formattedVariable} * {FormatDouble(varNode.Weight)}");
    }

    /// <summary>Appends the constant's value as a Python float literal.</summary>
    private static void FormatConstantTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      var constNode = (ConstantTreeNode)node;
      strBuilder.Append(FormatDouble(constNode.Value));
    }

    /// <summary>
    /// Round-trippable ("g17", invariant culture) Python literal for <paramref name="value"/>.
    /// NaN and infinities have no literal form in Python, so they are emitted as float(...) calls.
    /// </summary>
    private static string FormatDouble(double value) {
      if (double.IsNaN(value)) return "float('nan')";
      if (double.IsPositiveInfinity(value)) return "float('inf')";
      if (double.IsNegativeInfinity(value)) return "float('-inf')";
      return value.ToString("g17", CultureInfo.InvariantCulture);
    }

    /// <summary>Appends <c>math.pow(child0, exponent)</c>; used for Square and Cube.</summary>
    private static void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
      strBuilder.Append("math.pow(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append($", {exponent})");
    }

    /// <summary>Appends <c>math.pow(child0, 1.0 / (child1))</c>.</summary>
    private static void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("math.pow(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(", 1.0 / (");
      FormatRecursively(node.GetSubtree(1), strBuilder);
      strBuilder.Append("))");
    }

    /// <summary>
    /// Appends the real cube root of the single subtree as
    /// <c>math.copysign(abs(x) ** (1. / 3), x)</c>.
    /// A plain <c>x ** (1. / 3)</c> returns a complex number in Python 3 for negative x, and
    /// '**' binds tighter than the '*' inside an unparenthesized "x * weight" operand.
    /// </summary>
    private static void FormatCubeRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      var argument = node.GetSubtree(0);
      strBuilder.Append("math.copysign(abs(");
      FormatRecursively(argument, strBuilder);
      strBuilder.Append(") ** (1. / 3), ");
      FormatRecursively(argument, strBuilder);
      strBuilder.Append(")");
    }

    /// <summary>Appends unary negation for one subtree, otherwise <c>(a - b - ...)</c>.</summary>
    private static void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      if (node.SubtreeCount == 1) {
        strBuilder.Append("-");
        FormatRecursively(node.GetSubtree(0), strBuilder);
        return;
      }
      // Default case: more than 1 child
      FormatNode(node, strBuilder, infixSymbol: " - ");
    }

    /// <summary>
    /// Appends <c>(1.0 / (x))</c> for one subtree, otherwise <c>(a / (b * c * ...))</c>.
    /// The denominator is always parenthesized because operands such as "x * weight" are not.
    /// </summary>
    private static void FormatDivision(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      if (node.SubtreeCount == 1) {
        strBuilder.Append("1.0 / (");
        FormatRecursively(node.GetSubtree(0), strBuilder);
        strBuilder.Append(")");
      } else {
        FormatRecursively(node.GetSubtree(0), strBuilder);
        strBuilder.Append(" / (");
        for (int i = 1; i < node.SubtreeCount; i++) {
          if (i > 1) strBuilder.Append(" * ");
          FormatRecursively(node.GetSubtree(i), strBuilder);
        }
        strBuilder.Append(")");
      }
      strBuilder.Append(")");
    }

    /// <summary>Appends the analytic quotient <c>a / sqrt(1 + b^2)</c>.</summary>
    private static void FormatAnalyticQuotient(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
      strBuilder.Append("(");
      FormatRecursively(node.GetSubtree(0), strBuilder);
      strBuilder.Append(" / math.sqrt(1 + math.pow(");
      FormatRecursively(node.GetSubtree(1), strBuilder);
      strBuilder.Append(" , 2) ) )");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.