21 


22  using System;


23  using System.Collections.Generic;


24  using System.Linq;


25  using AutoDiff;


26  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


27 


namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Converts a symbolic expression tree into an AutoDiff term and compiles it into a parametric
  /// function and a function returning its gradient.
  /// </summary>
  /// <remarks>
  /// Terminology follows the call <c>term.Compile(variables, parameters)</c> used below:
  /// the compiled "variables" are the tunable numeric values of the tree (numeric constants, factor
  /// weights and - optionally - variable weights); the compiled "parameters" are the data inputs,
  /// one per distinct <see cref="DataForVariable"/>.
  /// </remarks>
  public class TreeToAutoDiffTermTransformator {
    /// <summary>
    /// Evaluates the compiled tree for the tunable values <paramref name="vars"/> and the data
    /// inputs <paramref name="params"/>.
    /// </summary>
    public delegate double ParametricFunction(double[] vars, double[] @params);

    /// <summary>
    /// Evaluates the compiled tree together with its gradient. Presumably (AutoDiff convention -
    /// verify against the AutoDiff version in use) Item1 is the gradient with respect to
    /// <paramref name="vars"/> and Item2 is the function value.
    /// </summary>
    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);

    #region helper class
    /// <summary>
    /// Identifies one data input of the compiled function: a variable name, a factor value
    /// (empty string for non-factor variables) and a time lag (0 for non-lagged variables).
    /// Used as a dictionary key, so equality and hashing are value based.
    /// </summary>
    public class DataForVariable {
      public readonly string variableName;
      public readonly string variableValue; // for factor vars
      public readonly int lag;

      public DataForVariable(string varName, string varValue, int lag) {
        this.variableName = varName;
        this.variableValue = varValue;
        this.lag = lag;
      }

      public override bool Equals(object obj) {
        var other = obj as DataForVariable;
        if (other == null) return false;
        // static string.Equals is ordinal and tolerates null names/values
        return string.Equals(other.variableName, this.variableName) &&
               string.Equals(other.variableValue, this.variableValue) &&
               other.lag == this.lag;
      }

      public override int GetHashCode() {
        unchecked {
          // multiply before combining so that e.g. name == value does not cancel out (as plain XOR would)
          int hash = variableName == null ? 0 : variableName.GetHashCode();
          hash = (hash * 397) ^ (variableValue == null ? 0 : variableValue.GetHashCode());
          hash = (hash * 397) ^ lag;
          return hash;
        }
      }
    }
    #endregion

    #region derivatives of functions
    // Each factory pairs a function (eval) with its first derivative (diff).

    // arctangent: d/dx atan(x) = 1 / (1 + x^2); currently not referenced by TryTransformToAutoDiff
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    // d/dx sin(x) = cos(x)
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    // d/dx cos(x) = -sin(x)
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    // d/dx tan(x) = 1 + tan(x)^2
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    // d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // Assuming alglib.normaldistribution is the standard normal CDF (per ALGLIB docs), its derivative
    // is the standard normal pdf exp(-x^2/2) / sqrt(2 pi). Written with a single exp so it does not
    // overflow to NaN for large |x|.
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));
    #endregion

    /// <summary>
    /// Tries to convert <paramref name="tree"/> (starting at <c>tree.Root.GetSubtree(0)</c>) into
    /// a compiled AutoDiff function and gradient.
    /// </summary>
    /// <param name="tree">The tree to convert.</param>
    /// <param name="makeVariableWeightsVariable">
    /// If true, the weights of variable, binary factor variable and lagged variable nodes become
    /// tunable values; otherwise they are baked into the term as fixed coefficients. Factor variable
    /// weights and numeric constants are always tunable.
    /// </param>
    /// <param name="parameters">Data inputs, in the order expected by the <c>@params</c> argument of the compiled functions.</param>
    /// <param name="initialConstants">
    /// Initial values of the tunable values, in tree traversal order. If the converted node is a start
    /// symbol, two additional tunable values (offset, then scale) are created before all others and
    /// have NO entry here - callers have to supply values for them.
    /// </param>
    /// <param name="func">Evaluates the compiled tree.</param>
    /// <param name="func_grad">Evaluates the compiled tree and its gradient.</param>
    /// <returns>True if every symbol could be converted; otherwise false and all out values are null.</returns>
    public static bool TryTransformToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
      out List<DataForVariable> parameters, out double[] initialConstants,
      out ParametricFunction func,
      out ParametricFunctionGradient func_grad) {

      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
      var transformator = new TreeToAutoDiffTermTransformator(makeVariableWeightsVariable);
      AutoDiff.Term term;
      var success = transformator.TryTransformToAutoDiff(tree.Root.GetSubtree(0), out term);
      if (success) {
        // materialize once so that keys (returned) and values (compiled) share the same order
        var parameterEntries = transformator.parameters.ToArray();
        var compiledTerm = term.Compile(transformator.variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
        initialConstants = transformator.initialConstants.ToArray();
        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
      } else {
        func = null;
        func_grad = null;
        parameters = null;
        initialConstants = null;
      }
      return success;
    }

    // state for recursive transformation of trees
    private readonly List<double> initialConstants;                           // initial value per entry of 'variables' (except the start symbol's offset/scale)
    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters; // data inputs, one per distinct DataForVariable
    private readonly List<AutoDiff.Variable> variables;                       // tunable values, in creation order
    private readonly bool makeVariableWeightsVariable;

    private TreeToAutoDiffTermTransformator(bool makeVariableWeightsVariable) {
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.initialConstants = new List<double>();
      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
      this.variables = new List<AutoDiff.Variable>();
    }

    // Recursively converts node into an AutoDiff term. Subtrees are converted left to right, which
    // fixes the order of 'variables' and 'initialConstants'. Returns false for unsupported symbols.
    private bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        // each numeric constant becomes a tunable value initialised with its current value
        initialConstants.Add(((ConstantTreeNode)node).Value);
        var constantVar = new AutoDiff.Variable();
        variables.Add(constantVar);
        term = constantVar;
        return true;
      }
      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
        var varNode = node as VariableTreeNode;
        var factorVarNode = node as BinaryFactorVariableTreeNode;
        string varName;
        string varValue;
        double weight;
        if (factorVarNode != null) {
          // factor variable values are only 0 or 1 and set in x accordingly;
          // read from the factor node itself so this does not depend on it also being a VariableTreeNode
          varName = factorVarNode.VariableName;
          varValue = factorVarNode.VariableValue;
          weight = factorVarNode.Weight;
        } else {
          varName = varNode.VariableName;
          varValue = string.Empty;
          weight = varNode.Weight;
        }
        var par = FindOrCreateParameter(parameters, varName, varValue);
        term = CreateWeightedTerm(weight, par);
        return true;
      }
      if (node.Symbol is FactorVariable) {
        // one indicator input and one tunable weight per possible value of the factor variable
        var factorVarNode = node as FactorVariableTreeNode;
        var products = new List<Term>();
        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);

          initialConstants.Add(factorVarNode.GetValue(variableValue));
          var wVar = new AutoDiff.Variable();
          variables.Add(wVar);

          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
        }
        term = SumOf(products);
        return true;
      }
      if (node.Symbol is LaggedVariable) {
        var varNode = node as LaggedVariableTreeNode;
        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
        term = CreateWeightedTerm(varNode.Weight, par);
        return true;
      }
      if (node.Symbol is Addition) {
        List<Term> terms;
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        term = SumOf(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<Term> terms;
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) {
          // unary subtraction is negation (just as unary division below is the reciprocal)
          term = -terms[0];
        } else {
          // a - b - c == a + (-b) + (-c)
          for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
          term = AutoDiff.TermBuilder.Sum(terms);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        List<Term> terms;
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) term = terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
        return true;
      }
      if (node.Symbol is Division) {
        List<Term> terms;
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        // unary division is the reciprocal; otherwise a / b / c == a * (1/b) * (1/c)
        if (terms.Count == 1) term = 1.0 / terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
        return true;
      }
      if (node.Symbol is Logarithm) return TryTransformUnary(node, arg => AutoDiff.TermBuilder.Log(arg), out term);
      if (node.Symbol is Exponential) return TryTransformUnary(node, arg => AutoDiff.TermBuilder.Exp(arg), out term);
      if (node.Symbol is Square) return TryTransformUnary(node, arg => AutoDiff.TermBuilder.Power(arg, 2.0), out term);
      if (node.Symbol is SquareRoot) return TryTransformUnary(node, arg => AutoDiff.TermBuilder.Power(arg, 0.5), out term);
      if (node.Symbol is Sine) return TryTransformUnary(node, arg => sin(arg), out term);
      if (node.Symbol is Cosine) return TryTransformUnary(node, arg => cos(arg), out term);
      if (node.Symbol is Tangent) return TryTransformUnary(node, arg => tan(arg), out term);
      if (node.Symbol is Erf) return TryTransformUnary(node, arg => erf(arg), out term);
      if (node.Symbol is Norm) return TryTransformUnary(node, arg => norm(arg), out term);
      if (node.Symbol is StartSymbol) {
        // linear scaling: branch * alpha + beta. beta (offset) and alpha (scale) are registered as
        // tunable values before the branch's own values, and get no entry in initialConstants.
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        }
        term = null;
        return false;
      }
      term = null;
      return false;
    }

    // Converts all subtrees of node from left to right; stops at the first subtree that fails.
    private bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, out List<Term> terms) {
      terms = new List<Term>();
      foreach (var subTree in node.Subtrees) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(subTree, out t)) {
          terms = null;
          return false;
        }
        terms.Add(t);
      }
      return true;
    }

    // Converts the single argument (subtree 0) of node and applies createTerm to it.
    private bool TryTransformUnary(ISymbolicExpressionTreeNode node, Func<Term, Term> createTerm, out AutoDiff.Term term) {
      AutoDiff.Term argument;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), out argument)) {
        term = null;
        return false;
      }
      term = createTerm(argument);
      return true;
    }

    // Returns weight * input; the weight is a new tunable value if makeVariableWeightsVariable is set,
    // otherwise a fixed coefficient.
    private Term CreateWeightedTerm(double weight, Term input) {
      if (makeVariableWeightsVariable) {
        initialConstants.Add(weight);
        var w = new AutoDiff.Variable();
        variables.Add(w);
        return AutoDiff.TermBuilder.Product(w, input);
      }
      return weight * input;
    }

    // Sums terms; a single term is returned as is rather than wrapped in a one-element sum.
    private static Term SumOf(List<Term> terms) {
      return terms.Count == 1 ? terms[0] : AutoDiff.TermBuilder.Sum(terms);
    }

    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
      string varName, string varValue = "", int lag = 0) {
      var data = new DataForVariable(varName, varValue, lag);

      AutoDiff.Variable par;
      if (!parameters.TryGetValue(data, out par)) {
        // not found -> create a new data input for this (name, value, lag) combination
        par = new AutoDiff.Variable();
        parameters.Add(data, par);
      }
      return par;
    }

    /// <summary>
    /// Returns true if every node below the root of <paramref name="tree"/> uses a symbol that
    /// <see cref="TryTransformToAutoDiff(ISymbolicExpressionTree, bool, out List{DataForVariable}, out double[], out ParametricFunction, out ParametricFunctionGradient)"/>
    /// can convert.
    /// </summary>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      // keep this list in sync with the symbols handled by the private TryTransformToAutoDiff
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is BinaryFactorVariable) &&
          !(n.Symbol is FactorVariable) &&
          !(n.Symbol is LaggedVariable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is Erf) &&
          !(n.Symbol is Norm) &&
          !(n.Symbol is StartSymbol)
        select n).Any();
      return !containsUnknownSymbol;
    }
  }
}

