1  #region License Information


2  /* HeuristicLab


3  * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System;


23  using System.Collections.Generic;


24  using System.Linq;


25  using AutoDiff;


26  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


27 


namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Converts a symbolic expression tree into an AutoDiff term. The tree's coefficients (constants,
  /// optionally the variable weights, and the linear-scaling terms introduced for a StartSymbol) become
  /// AutoDiff variables, so the tree's value and its gradient with respect to those coefficients can be computed.
  /// </summary>
  public class TreeToAutoDiffTermConverter {
    /// <summary>
    /// Evaluates a converted tree for coefficient values <paramref name="vars"/> and input values <paramref name="params"/>.
    /// </summary>
    public delegate double ParametricFunction(double[] vars, double[] @params);

    /// <summary>
    /// Evaluates a converted tree for coefficient values <paramref name="vars"/> and input values <paramref name="params"/>.
    /// It returns the gradient with respect to the coefficients (Item1) and the function value (Item2).
    /// </summary>
    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);

    #region derivatives of functions
    // Each factory pairs an evaluation function with its first derivative, so that AutoDiff can
    // differentiate through functions it does not provide natively.

    // d/dx atan(x) = 1 / (1 + x^2). Not referenced by any symbol branch below.
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    // d/dx sin(x) = cos(x)
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    // d/dx cos(x) = -sin(x)
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    // d/dx tan(x) = 1 + tan(x)^2
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    // d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // alglib.normaldistribution is the standard normal CDF, so its derivative is the standard normal
    // density exp(-x^2 / 2) / sqrt(2 pi). This form also never evaluates exp(x^2), which overflows
    // to infinity (and would turn the product into NaN) for |x| > ~26.6.
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));
    #endregion

    /// <summary>
    /// Tries to convert <paramref name="tree"/> (starting at the first subtree of its root) into a compiled AutoDiff term.
    /// </summary>
    /// <param name="tree">The tree to convert.</param>
    /// <param name="makeVariableWeightsVariable">
    /// If true, each variable weight becomes an optimizable coefficient, and its current value is appended to
    /// <paramref name="initialConstants"/>. Otherwise the weight is a fixed factor of the term.
    /// </param>
    /// <param name="variableNames">Names of the input variables, in the order of the input-value array passed to <paramref name="func"/>.</param>
    /// <param name="lags">Lag of each input variable (0 for non-lagged variables), parallel to <paramref name="variableNames"/>.</param>
    /// <param name="initialConstants">
    /// Current values of the tree's constants (and of the variable weights if enabled), in the order they were visited.
    /// NOTE(review): when the converted node is a StartSymbol, two scaling coefficients (beta, then alpha) are registered
    /// before all others but get no entry here. Callers presumably prepend starting values for them; confirm against callers.
    /// </param>
    /// <param name="func">Evaluates the term for given coefficient values and input values.</param>
    /// <param name="func_grad">Evaluates the gradient with respect to the coefficients together with the value.</param>
    /// <returns>True if every node could be converted; otherwise false, and all out values are null.</returns>
    public static bool TryTransformToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
      out string[] variableNames, out int[] lags, out double[] initialConstants,
      out ParametricFunction func,
      out ParametricFunctionGradient func_grad) {

      // A converter instance holds the state (variable list, parameter list, ...) for the recursive transformation.
      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
      AutoDiff.Term term;
      var success = transformator.TryTransformToAutoDiff(tree.Root.GetSubtree(0), out term);
      if (success) {
        // Coefficients (constants/weights/scaling) are the differentiable variables; data inputs are parameters.
        var compiledTerm = term.Compile(transformator.variables.ToArray(), transformator.parameters.ToArray());
        variableNames = transformator.variableNames.ToArray();
        lags = transformator.lags.ToArray();
        initialConstants = transformator.initialConstants.ToArray();
        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
      } else {
        func = null;
        func_grad = null;
        variableNames = null;
        lags = null;
        initialConstants = null;
      }
      return success;
    }

    // State for the recursive transformation of trees.
    private readonly List<string> variableNames;               // input variable names, parallel to 'parameters'
    private readonly List<int> lags;                           // input variable lags, parallel to 'parameters'
    private readonly List<double> initialConstants;            // initial values of constants / weights
    private readonly List<AutoDiff.Variable> parameters;       // AutoDiff placeholders for data inputs
    private readonly List<AutoDiff.Variable> variables;        // AutoDiff placeholders for optimizable coefficients
    private readonly bool makeVariableWeightsVariable;

    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.variableNames = new List<string>();
      this.lags = new List<int>();
      this.initialConstants = new List<double>();
      this.parameters = new List<AutoDiff.Variable>();
      this.variables = new List<AutoDiff.Variable>();
    }

    // Recursively converts 'node'. Returns false (term == null) if any node in the subtree has an unsupported symbol.
    private bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        // Every constant becomes an optimizable coefficient initialised with its current value.
        initialConstants.Add(((ConstantTreeNode)node).Value);
        var coefficient = new AutoDiff.Variable();
        variables.Add(coefficient);
        term = coefficient;
        return true;
      }
      // LaggedVariable is tested before Variable, so that a lagged-variable symbol deriving from Variable
      // cannot be captured by the Variable branch and lose its lag.
      // NOTE(review): confirm the symbol hierarchy; the ordering is harmless if the types are unrelated.
      if (node.Symbol is LaggedVariable) {
        var laggedNode = (LaggedVariableTreeNode)node;
        term = CreateVariableTerm(laggedNode.VariableName, laggedNode.Lag, laggedNode.Weight);
        return true;
      }
      if (node.Symbol is Variable) {
        var varNode = (VariableTreeNode)node;
        term = CreateVariableTerm(varNode.VariableName, 0, varNode.Weight);
        return true;
      }

      List<Term> terms;
      if (node.Symbol is Addition) {
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        // A sum with a single operand is just that operand.
        term = terms.Count == 1 ? terms[0] : AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) {
          // A single-argument subtraction is a negation: -a.
          term = -terms[0];
        } else {
          // a - b - c == a + (-b) + (-c)
          for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
          term = AutoDiff.TermBuilder.Sum(terms);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        term = terms.Count == 1 ? terms[0] : terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
        return true;
      }
      if (node.Symbol is Division) {
        if (!TryTransformSubtrees(node, out terms)) {
          term = null;
          return false;
        }
        // A single-argument division is the reciprocal 1 / a; otherwise a / b / c == a * (1/b) * (1/c).
        term = terms.Count == 1
          ? 1.0 / terms[0]
          : terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
        return true;
      }

      if (node.Symbol is Logarithm) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Log(t), out term);
      if (node.Symbol is Exponential) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Exp(t), out term);
      if (node.Symbol is Square) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Power(t, 2.0), out term);
      if (node.Symbol is SquareRoot) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Power(t, 0.5), out term);
      if (node.Symbol is Sine) return TryTransformUnary(node, t => sin(t), out term);
      if (node.Symbol is Cosine) return TryTransformUnary(node, t => cos(t), out term);
      if (node.Symbol is Tangent) return TryTransformUnary(node, t => tan(t), out term);
      if (node.Symbol is Erf) return TryTransformUnary(node, t => erf(t), out term);
      if (node.Symbol is Norm) return TryTransformUnary(node, t => norm(t), out term);

      if (node.Symbol is StartSymbol) {
        // Linear scaling: the expression f becomes f * alpha + beta. Both are registered as coefficients
        // (beta first, then alpha) before the subtree is converted.
        // NOTE(review): no initial values for beta/alpha are added to initialConstants; callers
        // presumably supply them. TODO confirm against callers.
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), out branchTerm)) {
          term = null;
          return false;
        }
        term = branchTerm * alpha + beta;
        return true;
      }

      // Unsupported symbol.
      term = null;
      return false;
    }

    // Registers an input (name, lag) and returns weight * input. The weight is either a fixed factor or,
    // when makeVariableWeightsVariable is set, a new coefficient initialised with the weight.
    private Term CreateVariableTerm(string variableName, int lag, double weight) {
      var par = new AutoDiff.Variable();
      parameters.Add(par);
      variableNames.Add(variableName);
      lags.Add(lag);

      if (!makeVariableWeightsVariable) return weight * par;

      initialConstants.Add(weight);
      var w = new AutoDiff.Variable();
      variables.Add(w);
      return AutoDiff.TermBuilder.Product(w, par);
    }

    // Converts all subtrees of 'node' in order. Returns false (terms == null) as soon as one subtree fails.
    private bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, out List<Term> terms) {
      var converted = new List<Term>();
      foreach (var subtree in node.Subtrees) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(subtree, out t)) {
          terms = null;
          return false;
        }
        converted.Add(t);
      }
      terms = converted;
      return true;
    }

    // Converts the first subtree of 'node' and applies the unary function 'f' to the result.
    private bool TryTransformUnary(ISymbolicExpressionTreeNode node, Func<Term, Term> f, out Term term) {
      AutoDiff.Term argument;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), out argument)) {
        term = null;
        return false;
      }
      term = f(argument);
      return true;
    }

    /// <summary>
    /// Returns true if every node below the root of <paramref name="tree"/> has a symbol that
    /// <see cref="TreeToAutoDiffTermConverter"/> can convert.
    /// </summary>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(IsSupported);
    }

    // Must accept exactly the symbols handled by the instance TryTransformToAutoDiff.
    private static bool IsSupported(ISymbolicExpressionTreeNode n) {
      var s = n.Symbol;
      return s is Variable ||
             s is LaggedVariable ||
             s is Constant ||
             s is Addition ||
             s is Subtraction ||
             s is Multiplication ||
             s is Division ||
             s is Logarithm ||
             s is Exponential ||
             s is SquareRoot ||
             s is Square ||
             s is Sine ||
             s is Cosine ||
             s is Tangent ||
             s is Erf ||
             s is Norm ||
             s is StartSymbol;
    }
  }
}

