21 


22  #define EXPLICIT_SHAPE


23 


24  using System;


25  using System.Collections.Generic;


26  using System.Linq;


27  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


28  using NumSharp;


29  using Tensorflow;


30  using static Tensorflow.Binding;


31 


namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Translates a symbolic expression tree into a TensorFlow graph.
  /// Every occurrence of an input variable becomes its own placeholder; numeric constants
  /// (and, optionally, variable weights and linear scaling terms) become TensorFlow variables.
  /// </summary>
  public class TreeToTensorConverter {

    private static readonly TF_DataType DataType = tf.float32;

    /// <summary>
    /// Tries to convert <paramref name="tree"/> into a TensorFlow graph.
    /// </summary>
    /// <param name="tree">Tree to convert; conversion starts at the first subtree of its root.</param>
    /// <param name="numRows">First dimension of every input placeholder.</param>
    /// <param name="variableLengths">
    /// Second dimension of the placeholder for each input variable name. Must contain every variable
    /// name used in the tree; a missing name raises <see cref="KeyNotFoundException"/>, which is not caught here.
    /// </param>
    /// <param name="makeVariableWeightsVariable">
    /// If true, the weight of each variable node becomes a TensorFlow variable; otherwise it is a fixed factor.
    /// </param>
    /// <param name="addLinearScalingTerms">If true, a StartSymbol's output is computed as <c>f * alpha + beta</c>.</param>
    /// <param name="graph">Output tensor of the graph, or null when conversion fails.</param>
    /// <param name="parameters">Maps each created placeholder to its variable name; null when conversion fails.</param>
    /// <param name="variables">All created TensorFlow variables in creation order; null when conversion fails.</param>
    /// <returns>True on success; false if the tree contains a symbol without a translation.</returns>
    public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths,
      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
      out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {

      try {
        var converter = new TreeToTensorConverter(numRows, variableLengths, makeVariableWeightsVariable, addLinearScalingTerms);
        graph = converter.ConvertNode(tree.Root.GetSubtree(0));

        parameters = converter.parameters;
        variables = converter.variables;
        return true;
      } catch (NotSupportedException) {
        // ConvertNode throws NotSupportedException for symbols it cannot translate
        graph = null;
        parameters = null;
        variables = null;
        return false;
      }
    }

    private readonly int numRows;
    private readonly Dictionary<string, int> variableLengths;
    private readonly bool makeVariableWeightsVariable;
    private readonly bool addLinearScalingTerms;

    // each placeholder created for a variable node -> that node's variable name
    private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
    // every TensorFlow variable created during conversion, in creation order
    private readonly List<Tensor> variables = new List<Tensor>();

    private TreeToTensorConverter(int numRows, Dictionary<string, int> variableLengths, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
      this.numRows = numRows;
      this.variableLengths = variableLengths;
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.addLinearScalingTerms = addLinearScalingTerms;
    }

    /// <summary>
    /// Recursively converts <paramref name="node"/> and its subtrees into a tensor expression.
    /// </summary>
    /// <exception cref="NotSupportedException">The node's symbol has no TensorFlow translation.</exception>
    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Constant) {
        var value = (float)((ConstantTreeNode)node).Value;
        var valueArr = np.array(value).reshape(1, 1);
        var constant = tf.Variable(valueArr, name: $"c_{variables.Count}", dtype: DataType);
        variables.Add(constant);
        return constant;
      }

      // TODO: BinaryFactorVariable and FactorVariable are not translated yet.
      if (node.Symbol is Variable) {
        var varNode = (VariableTreeNodeBase)node;
        var par = tf.placeholder(DataType, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
        parameters.Add(par, varNode.VariableName);

        if (makeVariableWeightsVariable) {
          var wArr = np.array((float)varNode.Weight).reshape(1, 1);
          var w = tf.Variable(wArr, name: $"w_{varNode.VariableName}", dtype: DataType);
          variables.Add(w);
          return w * par;
        } else {
          return varNode.Weight * par;
        }
      }

      if (node.Symbol is Addition) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        if (terms.Count == 1) return terms[0];
        return terms.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Subtraction) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        // a unary subtraction is a negation
        if (terms.Count == 1) return -terms[0];
        return terms.Aggregate((a, b) => a - b);
      }

      if (node.Symbol is Multiplication) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        if (terms.Count == 1) return terms[0];
        return terms.Aggregate((a, b) => a * b);
      }

      if (node.Symbol is Division) {
        var terms = node.Subtrees.Select(ConvertNode).ToList();
        // a unary division is the reciprocal
        if (terms.Count == 1) return 1.0f / terms[0];
        return terms.Aggregate((a, b) => a / b);
      }

      if (node.Symbol is Absolute) {
        return tf.abs(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is AnalyticQuotient) {
        // AQ(x1, x2) = x1 / sqrt(1 + x2^2)
        var x1 = ConvertNode(node.GetSubtree(0));
        var x2 = ConvertNode(node.GetSubtree(1));
        return x1 / tf.sqrt(1.0f + x2 * x2);
      }

      if (node.Symbol is Logarithm) {
        return tf.log(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Exponential) {
        return tf.exp(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Square) {
        return tf.square(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is SquareRoot) {
        return tf.sqrt(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cube) {
        return tf.pow(ConvertNode(node.GetSubtree(0)), 3.0f);
      }

      if (node.Symbol is CubeRoot) {
        // Real cube root, also defined for negative arguments:
        //   cbrt(x) = x < 0 ? -(-x)^(1/3) : x^(1/3) = sign(x) * |x|^(1/3)
        // tf.pow(x, 1/3) on its own yields NaN for x < 0. Autodiff of this form gives
        // d/dx cbrt(x) = 1 / (3 * cbrt(x)^2) for x != 0.
        var x = ConvertNode(node.GetSubtree(0));
        return tf.sign(x) * tf.pow(tf.abs(x), 1.0f / 3.0f);
      }

      if (node.Symbol is Sine) {
        return tf.sin(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cosine) {
        return tf.cos(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Tangent) {
        return tf.tan(ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is HyperbolicTangent) {
        return tf.tanh(ConvertNode(node.GetSubtree(0)));
      }

      // Aggregations reduce along axis 1 and keep that axis (size 1).
      if (node.Symbol is Mean) {
        return tf.reduce_mean(ConvertNode(node.GetSubtree(0)), axis: new[] { 1 }, keepdims: true);
      }

      if (node.Symbol is StandardDeviation) {
        return reduce_std(ConvertNode(node.GetSubtree(0)), axis: new[] { 1 }, keepdims: true);
      }

      if (node.Symbol is Variance) {
        return reduce_var(ConvertNode(node.GetSubtree(0)), axis: new[] { 1 }, keepdims: true);
      }

      if (node.Symbol is Sum) {
        return tf.reduce_sum(ConvertNode(node.GetSubtree(0)), axis: new[] { 1 }, keepdims: true);
      }

      if (node.Symbol is SubVector) {
        var tensor = ConvertNode(node.GetSubtree(0));
        int vectorLength = tensor.shape[1];
        var windowedNode = (IWindowedSymbolTreeNode)node;
        // index resolution and slicing are delegated to the vector interpreter so both agree;
        // the selection may consist of several column ranges, which are concatenated along axis 1
        int startIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Offset, vectorLength);
        int endIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Length, vectorLength);
        var slices = SymbolicDataAnalysisExpressionTreeVectorInterpreter.GetVectorSlices(startIdx, endIdx, vectorLength);

        var segments = new List<Tensor>();
        foreach (var (start, count) in slices) {
          segments.Add(tensor[new Slice(), new Slice(start, start + count)]);
        }
        return tf.concat(segments.ToArray(), axis: 1);
      }

      if (node.Symbol is StartSymbol) {
        Tensor prediction;
        if (addLinearScalingTerms) {
          // beta and alpha are registered before the subtree is converted, so they come
          // before the subtree's variables in the variable list (beta first, then alpha)
          var alphaArr = np.array(1.0f).reshape(1, 1);
          var alpha = tf.Variable(alphaArr, name: "alpha", dtype: DataType);
          var betaArr = np.array(0.0f).reshape(1, 1);
          var beta = tf.Variable(betaArr, name: "beta", dtype: DataType);
          variables.Add(beta);
          variables.Add(alpha);
          var t = ConvertNode(node.GetSubtree(0));
          prediction = t * alpha + beta;
        } else {
          prediction = ConvertNode(node.GetSubtree(0));
        }

        // flatten to a 1-D tensor; -1 lets TensorFlow infer the length
        return tf.reshape(prediction, shape: new[] { -1 });
      }

      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
    }

    /// <summary>Population variance: mean of squared deviations from the mean along <paramref name="axis"/>.</summary>
    private static Tensor reduce_var(Tensor input_tensor, int[] axis = null, bool keepdims = false) {
      var means = tf.reduce_mean(input_tensor, axis, true);
      var squared_deviation = tf.square(input_tensor - means);
      return tf.reduce_mean(squared_deviation, axis, keepdims);
    }

    /// <summary>Population standard deviation: square root of <see cref="reduce_var"/>.</summary>
    private static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false) {
      return tf.sqrt(reduce_var(input_tensor, axis, keepdims));
    }

    /// <summary>
    /// Returns true if every node below the tree's root has a symbol that <see cref="ConvertNode"/> can translate.
    /// Keep this list in sync with the branches of ConvertNode.
    /// </summary>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      // Erf, Norm and factor variables have no translation in ConvertNode and are therefore not listed.
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n =>
        n.Symbol is Variable ||
        n.Symbol is Constant ||
        n.Symbol is Addition ||
        n.Symbol is Subtraction ||
        n.Symbol is Multiplication ||
        n.Symbol is Division ||
        n.Symbol is Logarithm ||
        n.Symbol is Exponential ||
        n.Symbol is SquareRoot ||
        n.Symbol is Square ||
        n.Symbol is Sine ||
        n.Symbol is Cosine ||
        n.Symbol is Tangent ||
        n.Symbol is HyperbolicTangent ||
        n.Symbol is StartSymbol ||
        n.Symbol is Absolute ||
        n.Symbol is AnalyticQuotient ||
        n.Symbol is Cube ||
        n.Symbol is CubeRoot ||
        n.Symbol is Mean ||
        n.Symbol is StandardDeviation ||
        n.Symbol is Variance ||
        n.Symbol is Sum ||
        n.Symbol is SubVector);
    }
  }
}

