Changeset 17919


Ignore:
Timestamp:
03/29/21 17:49:47 (2 weeks ago)
Author:
dpiringe
Message:

#3105

  • added a static FormatTree method
  • reordered the if-else construct and some helper methods
  • refactored the header generation
  • removed all private properties
  • added a helper method for AnalyticQuotient
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3105_PythonFormatter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Formatters/SymbolicDataAnalysisExpressionPythonFormatter.cs

    r17860 r17919  
    3535  public sealed class SymbolicDataAnalysisExpressionPythonFormatter : NamedItem, ISymbolicExpressionTreeStringFormatter {
    3636
    37     private int VariableCounter { get; set; } = 0;
    38     private IDictionary<string, string> VariableMap { get; } = new Dictionary<string, string>();
    39     private int MathLibCounter { get; set; } = 0;
    40     private int StatisticLibCounter { get; set; } = 0;
    41     private int EvaluateIfCounter { get; set; } = 0;
    42 
    4337    [StorableConstructor]
    4438    private SymbolicDataAnalysisExpressionPythonFormatter(StorableConstructorFlag _) : base(_) { }
     
    5448    public string Format(ISymbolicExpressionTree symbolicExpressionTree) {
    5549      StringBuilder strBuilderModel = new StringBuilder();
     50      var header = GenerateHeader(symbolicExpressionTree);
    5651      FormatRecursively(symbolicExpressionTree.Root, strBuilderModel);
    57       return $"{GenerateHeader()}{strBuilderModel}";
    58     }
    59 
    60     private string GenerateHeader() {
    61       StringBuilder strBuilder = new StringBuilder();
    62       GenerateImports(strBuilder);
    63       GenerateIfThenElseSource(strBuilder);
    64       GenerateModelComment(strBuilder);
    65       GenerateModelEvaluationFunction(strBuilder);
    66       return strBuilder.ToString();
    67     }
    68 
    69     private void GenerateImports(StringBuilder strBuilder) {
    70       if(MathLibCounter > 0 || StatisticLibCounter > 0)
     52      return $"{header}{strBuilderModel}";
     53    }
     54
     55    public static string FormatTree(ISymbolicExpressionTree symbolicExpressionTree) {
     56      var formatter = new SymbolicDataAnalysisExpressionPythonFormatter();
     57      return formatter.Format(symbolicExpressionTree);
     58    }
     59
     60    private string GenerateHeader(ISymbolicExpressionTree symbolicExpressionTree) {
     61      StringBuilder strBuilder = new StringBuilder();
     62
     63      ISet<string> variables = new HashSet<string>();
     64      int mathLibCounter = 0;
     65      int statisticLibCounter = 0;
     66      int evaluateIfCounter = 0;
     67
     68      // iterate tree and search for necessary imports and variable names
     69      foreach (var node in symbolicExpressionTree.IterateNodesPostfix()) {
     70        ISymbol symbol = node.Symbol;
     71        if (symbol is Average) statisticLibCounter++;
     72        else if (symbol is IfThenElse) evaluateIfCounter++;
     73        else if (symbol is Cosine) mathLibCounter++;
     74        else if (symbol is Exponential) mathLibCounter++;
     75        else if (symbol is Logarithm) mathLibCounter++;
     76        else if (symbol is Sine) mathLibCounter++;
     77        else if (symbol is Tangent) mathLibCounter++;
     78        else if (symbol is HyperbolicTangent) mathLibCounter++;
     79        else if (symbol is SquareRoot) mathLibCounter++;
     80        else if (symbol is Power) mathLibCounter++;
     81        else if (symbol is AnalyticQuotient) mathLibCounter++;
     82        else if (node is VariableTreeNode) {
     83          var varNode = node as VariableTreeNode;
     84          var formattedVariable = VariableName2Identifier(varNode.VariableName);
     85          variables.Add(formattedVariable);
     86        }
     87      }
     88
     89      // generate import section (if necessary)
     90      var importSection = GenerateNecessaryImports(mathLibCounter, statisticLibCounter);
     91      strBuilder.Append(importSection);
     92
     93      // generate if-then-else helper construct (if necessary)
     94      var ifThenElseSourceSection = GenerateIfThenElseSource(evaluateIfCounter);
     95      strBuilder.Append(ifThenElseSourceSection);
     96
     97      // generate model evaluation function
     98      var modelEvaluationFunctionSection = GenerateModelEvaluationFunction(variables);
     99      strBuilder.Append(modelEvaluationFunctionSection);
     100
     101      return strBuilder.ToString();
     102    }
     103
     104    private string GenerateNecessaryImports(int mathLibCounter, int statisticLibCounter) {
     105      StringBuilder strBuilder = new StringBuilder();
     106      if (mathLibCounter > 0 || statisticLibCounter > 0) {
    71107        strBuilder.AppendLine("# imports");
    72       if(MathLibCounter > 0)
    73         strBuilder.AppendLine("import math");
    74       if(StatisticLibCounter > 0)
    75         strBuilder.AppendLine("import statistics");
    76     }
    77 
    78     private void GenerateIfThenElseSource(StringBuilder strBuilder) {
    79       if(EvaluateIfCounter > 0) {
     108        if (mathLibCounter > 0)
     109          strBuilder.AppendLine("import math");
     110        if (statisticLibCounter > 0)
     111          strBuilder.AppendLine("import statistics");
     112        strBuilder.AppendLine();
     113      }
     114      return strBuilder.ToString();
     115    }
     116
     117    private string GenerateIfThenElseSource(int evaluateIfCounter) {
     118      StringBuilder strBuilder = new StringBuilder();
     119      if (evaluateIfCounter > 0) {
    80120        strBuilder.AppendLine("# condition helper function");
    81121        strBuilder.AppendLine("def evaluate_if(condition, then_path, else_path): ");
     
    85125        strBuilder.AppendLine("\t\treturn else_path");
    86126      }
    87     }
    88 
    89     private void GenerateModelComment(StringBuilder strBuilder) {
    90       strBuilder.AppendLine("# model");
    91       strBuilder.AppendLine("\"\"\"");
    92       foreach (var kvp in VariableMap) {
    93         strBuilder.AppendLine($"{kvp.Key} = {kvp.Value}");
    94       }
    95       strBuilder.AppendLine("\"\"\"");
    96     }
    97 
    98     private void GenerateModelEvaluationFunction(StringBuilder strBuilder) {
     127      return strBuilder.ToString();
     128    }
     129
     130    private string GenerateModelEvaluationFunction(ISet<string> variables) {
     131      StringBuilder strBuilder = new StringBuilder();
    99132      strBuilder.Append("def evaluate(");
    100       foreach (var kvp in VariableMap) {
    101         strBuilder.Append($"{kvp.Value}");
    102         if (kvp.Key != VariableMap.Last().Key)
    103           strBuilder.Append(",");
     133      var orderedVariables = variables.OrderBy(n => n, new NaturalStringComparer());
     134      foreach (var variable in orderedVariables) {
     135        strBuilder.Append($"{variable}");
     136        if (variable != orderedVariables.Last())
     137          strBuilder.Append(", ");
    104138      }
    105139      strBuilder.AppendLine("):");
    106140      strBuilder.Append("\treturn ");
     141      return strBuilder.ToString();
    107142    }
    108143
    109144    private void FormatRecursively(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    110145      ISymbol symbol = node.Symbol;
    111       if (symbol is ProgramRootSymbol) {
    112         FormatRecursively(node.GetSubtree(0), strBuilder);
    113       } else if (symbol is StartSymbol) {
    114         FormatRecursively(node.GetSubtree(0), strBuilder);
    115       } else if (symbol is Addition) {
     146      if (symbol is ProgramRootSymbol)
     147        FormatRecursively(node.GetSubtree(0), strBuilder);
     148      else if (symbol is StartSymbol)
     149        FormatRecursively(node.GetSubtree(0), strBuilder);
     150      else if (symbol is Absolute)
     151        FormatNode(node, strBuilder, "abs");
     152      else if (symbol is Addition)
    116153        FormatNode(node, strBuilder, infixSymbol: " + ");
    117       } else if (symbol is And) {
     154      else if (symbol is Subtraction)
     155        FormatSubtraction(node, strBuilder);
     156      else if (symbol is Multiplication)
     157        FormatNode(node, strBuilder, infixSymbol: " * ");
     158      else if (symbol is Division)
     159        FormatDivision(node, strBuilder);
     160      else if (symbol is Average)
     161        FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
     162      else if (symbol is Sine)
     163        FormatNode(node, strBuilder, "math.sin");
     164      else if (symbol is Cosine)
     165        FormatNode(node, strBuilder, "math.cos");
     166      else if (symbol is Tangent)
     167        FormatNode(node, strBuilder, "math.tan");
     168      else if (symbol is HyperbolicTangent)
     169        FormatNode(node, strBuilder, "math.tanh");
     170      else if (symbol is Exponential)
     171        FormatNode(node, strBuilder, "math.exp");
     172      else if (symbol is Logarithm)
     173        FormatNode(node, strBuilder, "math.log");
     174      else if (symbol is Power)
     175        FormatNode(node, strBuilder, "math.pow");
     176      else if (symbol is Root)
     177        FormatRoot(node, strBuilder);
     178      else if (symbol is Square)
     179        FormatPower(node, strBuilder, "2");
     180      else if (symbol is SquareRoot)
     181        FormatNode(node, strBuilder, "math.sqrt");
     182      else if (symbol is Cube)
     183        FormatPower(node, strBuilder, "3");
     184      else if (symbol is CubeRoot)
     185        FormatNode(node, strBuilder, closingSymbol: " ** (1. / 3))");
     186      else if (symbol is AnalyticQuotient)
     187        FormatAnalyticQuotient(node, strBuilder);
     188      else if (symbol is And)
    118189        FormatNode(node, strBuilder, infixSymbol: " and ");
    119       } else if (symbol is Average) {
    120         StatisticLibCounter++;
    121         FormatNode(node, strBuilder, prefixSymbol: "statistics.mean", openingSymbol: "([", closingSymbol: "])");
    122       } else if (symbol is Cosine) {
    123         MathLibCounter++;
    124         FormatNode(node, strBuilder, "math.cos");
    125       } else if (symbol is Division) {
    126         FormatDivision(node, strBuilder);
    127       } else if (symbol is Exponential) {
    128         MathLibCounter++;
    129         FormatNode(node, strBuilder, "math.exp");
    130       } else if (symbol is GreaterThan) {
     190      else if (symbol is Or)
     191        FormatNode(node, strBuilder, infixSymbol: " or ");
     192      else if (symbol is Xor)
     193        FormatNode(node, strBuilder, infixSymbol: " ^ ");
     194      else if (symbol is Not)
     195        FormatNode(node, strBuilder, "not");
     196      else if (symbol is IfThenElse)
     197        FormatNode(node, strBuilder, "evaluate_if");
     198      else if (symbol is GreaterThan)
    131199        FormatNode(node, strBuilder, infixSymbol: " > ");
    132       } else if (symbol is IfThenElse) {
    133         EvaluateIfCounter++;
    134         FormatNode(node, strBuilder, "evaluate_if");
    135       } else if (symbol is LessThan) {
     200      else if (symbol is LessThan)
    136201        FormatNode(node, strBuilder, infixSymbol: " < ");
    137       } else if (symbol is Logarithm) {
    138         MathLibCounter++;
    139         FormatNode(node, strBuilder, "math.log");
    140       } else if (symbol is Multiplication) {
    141         FormatNode(node, strBuilder, infixSymbol: " * ");
    142       } else if (symbol is Not) {
    143         FormatNode(node, strBuilder, "not");
    144       } else if (symbol is Or) {
    145         FormatNode(node, strBuilder, infixSymbol: " or ");
    146       } else if (symbol is Xor) {
    147         FormatNode(node, strBuilder, infixSymbol: " ^ ");
    148       } else if (symbol is Sine) {
    149         MathLibCounter++;
    150         FormatNode(node, strBuilder, "math.sin");
    151       } else if (symbol is Subtraction) {
    152         FormatSubtraction(node, strBuilder);
    153       } else if (symbol is Tangent) {
    154         MathLibCounter++;
    155         FormatNode(node, strBuilder, "math.tan");
    156       } else if (symbol is HyperbolicTangent) {
    157         MathLibCounter++;
    158         FormatNode(node, strBuilder, "math.tanh");
    159       } else if (symbol is Square) {
    160         FormatPower(node, strBuilder, "2");
    161       } else if (symbol is SquareRoot) {
    162         MathLibCounter++;
    163         FormatNode(node, strBuilder, "math.sqrt");
    164       } else if (symbol is Cube) {
    165         FormatPower(node, strBuilder, "3");
    166       } else if (symbol is CubeRoot) {
    167         FormatNode(node, strBuilder, closingSymbol: " ** (1. / 3))");
    168       } else if (symbol is Power) {
    169         MathLibCounter++;
    170         FormatNode(node, strBuilder, "math.pow");
    171       } else if (symbol is Root) {
    172         FormatRoot(node, strBuilder);
    173       } else if (symbol is Absolute) {
    174         FormatNode(node, strBuilder, "abs");
    175       } else if (symbol is AnalyticQuotient) {
    176         MathLibCounter++;
    177         strBuilder.Append("(");
    178         FormatRecursively(node.GetSubtree(0), strBuilder);
    179         strBuilder.Append(" / math.sqrt(1 + math.pow(");
    180         FormatRecursively(node.GetSubtree(1), strBuilder);
    181         strBuilder.Append(" , 2) ) )");
    182       } else {
    183         if (node is VariableTreeNode) {
    184           FormatVariableTreeNode(node, strBuilder);
    185         } else if (node is ConstantTreeNode) {
    186           FormatConstantTreeNode(node, strBuilder);
    187         } else {
    188           throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
    189         } 
    190       }
    191     }
    192 
    193     private void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    194       var varNode = node as VariableTreeNode;
    195       string variable;
    196       if (!VariableMap.TryGetValue(varNode.VariableName, out variable)) {
    197         variable = $"var{VariableCounter++}";
    198         VariableMap.Add(varNode.VariableName, variable);
    199       }
    200       strBuilder.AppendFormat("{0} * {1}", variable, varNode.Weight.ToString("g17", CultureInfo.InvariantCulture));
    201     }
    202 
    203     private void FormatConstantTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    204       var constNode = node as ConstantTreeNode;
    205       strBuilder.Append(constNode.Value.ToString("g17", CultureInfo.InvariantCulture));
    206     }
    207 
    208     private void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
    209       MathLibCounter++;
    210       strBuilder.Append("math.pow(");
    211       FormatRecursively(node.GetSubtree(0), strBuilder);
    212       strBuilder.Append($", {exponent})");
    213     }
    214 
    215     private void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    216       MathLibCounter++;
    217       strBuilder.Append("math.pow(");
    218       FormatRecursively(node.GetSubtree(0), strBuilder);
    219       strBuilder.Append(", 1.0 / (");
    220       FormatRecursively(node.GetSubtree(1), strBuilder);
    221       strBuilder.Append("))");
    222     }
     202      else if (node is VariableTreeNode)
     203        FormatVariableTreeNode(node, strBuilder);
     204      else if (node is ConstantTreeNode)
     205        FormatConstantTreeNode(node, strBuilder);
     206      else
     207        throw new NotSupportedException("Formatting of symbol: " + symbol + " not supported for Python symbolic expression tree formatter.");
     208    }
     209
     210    private string VariableName2Identifier(string variableName) => variableName.Replace(" ", "_");
    223211
    224212    private void FormatNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string prefixSymbol = "", string openingSymbol = "(", string closingSymbol = ")", string infixSymbol = ",") {
     
    232220    }
    233221
     222    private void FormatVariableTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
     223      var varNode = node as VariableTreeNode;
     224      var formattedVariable = VariableName2Identifier(varNode.VariableName);
     225      var variableWeight = varNode.Weight.ToString("g17", CultureInfo.InvariantCulture);
     226      strBuilder.Append($"{formattedVariable} * {variableWeight}");
     227    }
     228
     229    private void FormatConstantTreeNode(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
     230      var constNode = node as ConstantTreeNode;
     231      strBuilder.Append(constNode.Value.ToString("g17", CultureInfo.InvariantCulture));
     232    }
     233
     234    private void FormatPower(ISymbolicExpressionTreeNode node, StringBuilder strBuilder, string exponent) {
     235      strBuilder.Append("math.pow(");
     236      FormatRecursively(node.GetSubtree(0), strBuilder);
     237      strBuilder.Append($", {exponent})");
     238    }
     239
     240    private void FormatRoot(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
     241      strBuilder.Append("math.pow(");
     242      FormatRecursively(node.GetSubtree(0), strBuilder);
     243      strBuilder.Append(", 1.0 / (");
     244      FormatRecursively(node.GetSubtree(1), strBuilder);
     245      strBuilder.Append("))");
     246    }
     247
     248    private void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
     249      if (node.SubtreeCount == 1) {
     250        strBuilder.Append("-");
     251        FormatRecursively(node.GetSubtree(0), strBuilder);
     252        return;
     253      }
     254      //Default case: more than 1 child
     255      FormatNode(node, strBuilder, infixSymbol: " - ");
     256    }
     257
    234258    private void FormatDivision(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    235259      strBuilder.Append("(");
     
    239263      } else {
    240264        FormatRecursively(node.GetSubtree(0), strBuilder);
    241         strBuilder.Append("/ (");
     265        strBuilder.Append(" / (");
    242266        for (int i = 1; i < node.SubtreeCount; i++) {
    243267          if (i > 1) strBuilder.Append(" * ");
     
    249273    }
    250274
    251     private void FormatSubtraction(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
    252       if (node.SubtreeCount == 1) {
    253         strBuilder.Append("-");
    254         FormatRecursively(node.GetSubtree(0), strBuilder);
    255         return;
    256       }
    257       //Default case: more than 1 child
    258       FormatNode(node, strBuilder, infixSymbol: " - ");
     275    private void FormatAnalyticQuotient(ISymbolicExpressionTreeNode node, StringBuilder strBuilder) {
     276      strBuilder.Append("(");
     277      FormatRecursively(node.GetSubtree(0), strBuilder);
     278      strBuilder.Append(" / math.sqrt(1 + math.pow(");
     279      FormatRecursively(node.GetSubtree(1), strBuilder);
     280      strBuilder.Append(" , 2) ) )");
    259281    }
    260282  }
    261 
    262283}
Note: See TracChangeset for help on using the changeset viewer.