Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Formatters/SymbolicDataAnalysisExpressionPythonFormatter.cs

Last change on this file was 18220, checked in by gkronber, 2 years ago

#3136: reintegrated structure-template GP branch into trunk

File size: 12.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Globalization;
25using System.Linq;
26using System.Text;
27using HEAL.Attic;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [Item("Python String Formatter", "String formatter for string representations of symbolic data analysis expressions in Python syntax.")]
34  [StorableType("37C1E1DD-437F-414B-AA96-9C6A0F6FEE46")]
35  public sealed class SymbolicDataAnalysisExpressionPythonFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {
36
37    [StorableConstructor]
38    private SymbolicDataAnalysisExpressionPythonFormatter(StorableConstructorFlag _) : base(_) { }
39    private SymbolicDataAnalysisExpressionPythonFormatter(SymbolicDataAnalysisExpressionPythonFormatter original, Cloner cloner) : base(original, cloner) { }
40    public SymbolicDataAnalysisExpressionPythonFormatter()
41      : base() {
42      Name = ItemName;
43      Description = ItemDescription;
44    }
45
46    public override IDeepCloneable Clone(Cloner cloner) => new SymbolicDataAnalysisExpressionPythonFormatter(this, cloner);
47
48    public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
49      StringBuilder strBuilderModel = new StringBuilder();
50      var header = GenerateHeader(symbolicExpressionTree);
51      FormatRecursively(symbolicExpressionTree.Root, strBuilderModel);
52      return $"{header}{strBuilderModel}";
53    }
54
55    public static string FormatTree(ISymbolicExpressionTree symbolicExpressionTree) {
56      var formatter = new SymbolicDataAnalysisExpressionPythonFormatter();
57      return formatter.Format(symbolicExpressionTree);
58    }
59
60    private static string GenerateHeader(ISymbolicExpressionTree symbolicExpressionTree) {
61      StringBuilder strBuilder = new StringBuilder();
62
63      ISet<string> variables = new HashSet<string>();
64      int mathLibCounter = 0;
65      int statisticLibCounter = 0;
66      int evaluateIfCounter = 0;
67
68      // iterate tree and search for necessary imports and variable names
69      foreach (var node in symbolicExpressionTree.IterateNodesPostfix()) {
70        ISymbol symbol = node.Symbol;
71        if (symbol is Average) statisticLibCounter++;
72        else if (symbol is IfThenElse) evaluateIfCounter++;
73        else if (symbol is Cosine) mathLibCounter++;
74        else if (symbol is Exponential) mathLibCounter++;
75        else if (symbol is Logarithm) mathLibCounter++;
76        else if (symbol is Sine) mathLibCounter++;
77        else if (symbol is Tangent) mathLibCounter++;
78        else if (symbol is HyperbolicTangent) mathLibCounter++;
79        else if (symbol is SquareRoot) mathLibCounter++;
80        else if (symbol is Power) mathLibCounter++;
81        else if (symbol is AnalyticQuotient) mathLibCounter++;
82        else if (node is VariableTreeNode) {
83          var varNode = node as VariableTreeNode;
84          var formattedVariable = VariableName2Identifier(varNode.VariableName);
85          variables.Add(formattedVariable);
86        }
87      }
88
89      // generate import section (if necessary)
90      var importSection = GenerateNecessaryImports(mathLibCounter, statisticLibCounter);
91      strBuilder.Append(importSection);
92
93      // generate if-then-else helper construct (if necessary)
94      var ifThenElseSourceSection = GenerateIfThenElseSource(evaluateIfCounter);
95      strBuilder.Append(ifThenElseSourceSection);
96
97      // generate model evaluation function
98      var modelEvaluationFunctionSection = GenerateModelEvaluationFunction(variables);
99      strBuilder.Append(modelEvaluationFunctionSection);
100
101      return strBuilder.ToString();
102    }
103
104    private static string GenerateNecessaryImports(int mathLibCounter, int statisticLibCounter) {
105      StringBuilder strBuilder = new StringBuilder();
106      if (mathLibCounter > 0 || statisticLibCounter > 0) {
107        strBuilder.AppendLine("# imports");
108        if (mathLibCounter > 0)
109          strBuilder.AppendLine("import math");
110        if (statisticLibCounter > 0)
111          strBuilder.AppendLine("import statistics");
112        strBuilder.AppendLine();
113      }
114      return strBuilder.ToString();
115    }
116
117    private static string GenerateIfThenElseSource(int evaluateIfCounter) {
118      StringBuilder strBuilder = new StringBuilder();
119      if (evaluateIfCounter > 0) {
120        strBuilder.AppendLine("# condition helper function");
121        strBuilder.AppendLine("def evaluate_if(condition, then_path, else_path): ");
122        strBuilder.AppendLine("\tif condition:");
123        strBuilder.AppendLine("\t\treturn then_path");
124        strBuilder.AppendLine("\telse:");
125        strBuilder.AppendLine("\t\treturn else_path");
126      }
127      return strBuilder.ToString();
128    }
129
130    private static string GenerateModelEvaluationFunction(ISet<string> variables) {
131      StringBuilder strBuilder = new StringBuilder();
132      strBuilder.Append("def evaluate(");
133      var orderedVariables = variables.OrderBy(n => n, new NaturalStringComparer());
134      foreach (var variable in orderedVariables) {
135        strBuilder.Append($"{variable}");
136        if (variable != orderedVariables.Last())
137          strBuilder.Append(", ");
138      }
139      strBuilder.AppendLine("):");
140      strBuilder.Append("\treturn ");
141      return strBuilder.ToString();
142    }
143
144    private static void FormatRecursively(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
145      ISymbol symbol = node.Symbol;
146      if (symbol is ProgramRootSymbol)
147        FormatRecursively(node.GetSubtree(0), strBuilder);
148      else if (symbol is StartSymbol)
149        FormatRecursively(node.GetSubtree(0), strBuilder);
150      else if (symbol is Absolute)
151        FormatNode(node, strBuilder, "abs");
152      else if (symbol is Addition)
153        FormatNode(node, strBuilder, infixSymbol: " + ");
154      else if (symbol is Subtraction)
155        FormatSubtraction(node, strBuilder);
156      else if (symbol is Multiplication)
157        FormatNode(node, strBuilder, infixSymbol: " * ");
158      else if (symbol is Division)
159        FormatDivision(node, strBuilder);
160      else if (symbol is Average)
161        FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
162      else if (symbol is Sine)
163        FormatNode(node, strBuilder, "math.sin");
164      else if (symbol is Cosine)
165        FormatNode(node, strBuilder, "math.cos");
166      else if (symbol is Tangent)
167        FormatNode(node, strBuilder, "math.tan");
168      else if (symbol is HyperbolicTangent)
169        FormatNode(node, strBuilder, "math.tanh");
170      else if (symbol is Exponential)
171        FormatNode(node, strBuilder, "math.exp");
172      else if (symbol is Logarithm)
173        FormatNode(node, strBuilder, "math.log");
174      else if (symbol is Power)
175        FormatNode(node, strBuilder, "math.pow");
176      else if (symbol is Root)
177        FormatRoot(node, strBuilder);
178      else if (symbol is Square)
179        FormatPower(node, strBuilder, "2");
180      else if (symbol is SquareRoot)
181        FormatNode(node, strBuilder, "math.sqrt");
182      else if (symbol is Cube)
183        FormatPower(node, strBuilder, "3");
184      else if (symbol is CubeRoot)
185        FormatNode(node, strBuilder, closingSymbol: " ** (1. / 3))");
186      else if (symbol is AnalyticQuotient)
187        FormatAnalyticQuotient(node, strBuilder);
188      else if (symbol is And)
189        FormatNode(node, strBuilder, infixSymbol: " and ");
190      else if (symbol is Or)
191        FormatNode(node, strBuilder, infixSymbol: " or ");
192      else if (symbol is Xor)
193        FormatNode(node, strBuilder, infixSymbol: " ^ ");
194      else if (symbol is Not)
195        FormatNode(node, strBuilder, "not");
196      else if (symbol is IfThenElse)
197        FormatNode(node, strBuilder, "evaluate_if");
198      else if (symbol is GreaterThan)
199        FormatNode(node, strBuilder, infixSymbol: " > ");
200      else if (symbol is LessThan)
201        FormatNode(node, strBuilder, infixSymbol: " < ");
202      else if (node is VariableTreeNode)
203        FormatVariableTreeNode(node, strBuilder);
204      else if (node is INumericTreeNode)
205        FormatNumericTreeNode(node, strBuilder);
206      else if (symbol is SubFunctionSymbol)
207        FormatRecursively(node.GetSubtree(0), strBuilder);
208      else
209        throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
210    }
211
212    private static string VariableName2Identifier(string variableName) => variableName.Replace(" ", "_");
213
214    private static void FormatNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string prefixSymbol = "", string openingSymbol = "(", string closingSymbol = ")", string infixSymbol = ",") {
215      strBuilder.Append($"{prefixSymbol}{openingSymbol}");
216      foreach (var child in node.Subtrees) {
217        FormatRecursively(child, strBuilder);
218        if (child != node.Subtrees.Last())
219          strBuilder.Append(infixSymbol);
220      }
221      strBuilder.Append(closingSymbol);
222    }
223
224    private static void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
225      var varNode = node as VariableTreeNode;
226      var formattedVariable = VariableName2Identifier(varNode.VariableName);
227      var variableWeight = varNode.Weight.ToString("g17", CultureInfo.InvariantCulture);
228      strBuilder.Append($"{formattedVariable} * {variableWeight}");
229    }
230
231    private static void FormatNumericTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
232      var numNode = node as INumericTreeNode;
233      strBuilder.Append(numNode.Value.ToString("g17", CultureInfo.InvariantCulture));
234    }
235
236    private static void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
237      strBuilder.Append("math.pow(");
238      FormatRecursively(node.GetSubtree(0), strBuilder);
239      strBuilder.Append($", {exponent})");
240    }
241
242    private static void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
243      strBuilder.Append("math.pow(");
244      FormatRecursively(node.GetSubtree(0), strBuilder);
245      strBuilder.Append(", 1.0 / (");
246      FormatRecursively(node.GetSubtree(1), strBuilder);
247      strBuilder.Append("))");
248    }
249
250    private static void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
251      if (node.SubtreeCount == 1) {
252        strBuilder.Append("-");
253        FormatRecursively(node.GetSubtree(0), strBuilder);
254        return;
255      }
256      //Default case: more than 1 child
257      FormatNode(node, strBuilder, infixSymbol: " - ");
258    }
259
260    private static void FormatDivision(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
261      strBuilder.Append("(");
262      if (node.SubtreeCount == 1) {
263        strBuilder.Append("1.0 / ");
264        FormatRecursively(node.GetSubtree(0), strBuilder);
265      } else {
266        FormatRecursively(node.GetSubtree(0), strBuilder);
267        strBuilder.Append(" / (");
268        for (int i = 1; i < node.SubtreeCount; i++) {
269          if (i > 1) strBuilder.Append(" * ");
270          FormatRecursively(node.GetSubtree(i), strBuilder);
271        }
272        strBuilder.Append(")");
273      }
274      strBuilder.Append(")");
275    }
276
277    private static void FormatAnalyticQuotient(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
278      strBuilder.Append("(");
279      FormatRecursively(node.GetSubtree(0), strBuilder);
280      strBuilder.Append(" / math.sqrt(1 + math.pow(");
281      FormatRecursively(node.GetSubtree(1), strBuilder);
282      strBuilder.Append(" , 2) ) )");
283    }
284  }
285}
Note: See TracBrowser for help on using the repository browser.