Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/11/20 15:02:34 (4 years ago)
Author:
pfleck
Message:

#3040 Started working on the TF constant opt evaluator.

File:
1 copied

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs

    r17455 r17474  
    2626using AutoDiff;
    2727using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     28using Tensorflow;
     29using static Tensorflow.Binding;
    2830
    2931namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    30   public class TreeToAutoDiffTermConverter {
    31     public delegate double ParametricFunction(double[] vars, double[] @params);
    32 
    33     public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
     32  public class TreeToTensorConverter {
    3433
    3534    #region helper class
     
    3736      public readonly string variableName;
    3837      public readonly string variableValue; // for factor vars
    39       public readonly int lag;
    40 
    41       public DataForVariable(string varName, string varValue, int lag) {
     38
     39      public DataForVariable(string varName, string varValue) {
    4240        this.variableName = varName;
    4341        this.variableValue = varValue;
    44         this.lag = lag;
    4542      }
    4643
     
    4946        if (other == null) return false;
    5047        return other.variableName.Equals(this.variableName) &&
    51                other.variableValue.Equals(this.variableValue) &&
    52                other.lag == this.lag;
     48               other.variableValue.Equals(this.variableValue);
    5349      }
    5450
    5551      public override int GetHashCode() {
    56         return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
     52        return variableName.GetHashCode() ^ variableValue.GetHashCode();
    5753      }
    5854    }
    5955    #endregion
    6056
    61     #region derivations of functions
    62     // create function factory for arctangent
    63     private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    64       eval: Math.Atan,
    65       diff: x => 1 / (1 + x * x));
    66 
    67     private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    68       eval: Math.Sin,
    69       diff: Math.Cos);
    70 
    71     private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    72       eval: Math.Cos,
    73       diff: x => -Math.Sin(x));
    74 
    75     private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    76       eval: Math.Tan,
    77       diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    78     private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
    79       eval: Math.Tanh,
    80       diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));
    81     private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    82       eval: alglib.errorfunction,
    83       diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    84 
    85     private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    86       eval: alglib.normaldistribution,
    87       diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    88 
    89     private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
    90       eval: Math.Abs,
    91       diff: x => Math.Sign(x)
    92       );
    93 
    94     private static readonly Func<Term, UnaryFunc> cbrt = UnaryFunc.Factory(
    95       eval: x => x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
    96       diff: x => { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
    97       );
    98 
    99 
    100 
    101     #endregion
    102 
    103     public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
    104       out List<DataForVariable> parameters, out double[] initialConstants,
    105       out ParametricFunction func,
    106       out ParametricFunctionGradient func_grad) {
    107 
    108       // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    109       var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms);
    110       AutoDiff.Term term;
     57    public static bool TryConvert(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
     58      out Tensor graph, out Dictionary<DataForVariable, Tensor> variables/*, out double[] initialConstants*/) {
     59
    11160      try {
    112         term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
    113         var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
    114         var compiledTerm = term.Compile(transformator.variables.ToArray(),
    115           parameterEntries.Select(kvp => kvp.Value).ToArray());
    116         parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
    117         initialConstants = transformator.initialConstants.ToArray();
    118         func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
    119         func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
     61        var converter = new TreeToTensorConverter(makeVariableWeightsVariable, addLinearScalingTerms);
     62        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
     63
     64        //var parametersEntries = converter.parameters.ToList(); // guarantee same order for keys and values
     65        variables = converter.parameters; // parametersEntries.Select(kvp => kvp.Value).ToList();
     66        //initialConstants = converter.initialConstants.ToArray();
    12067        return true;
    121       } catch (ConversionException) {
    122         func = null;
    123         func_grad = null;
    124         parameters = null;
    125         initialConstants = null;
    126       }
    127       return false;
    128     }
    129 
    130     // state for recursive transformation of trees
    131     private readonly
    132     List<double> initialConstants;
    133     private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
    134     private readonly List<AutoDiff.Variable> variables;
     68      } catch (NotSupportedException) {
     69        graph = null;
     70        variables = null;
     71        //initialConstants = null;
     72        return false;
     73      }
     74    }
     75
    13576    private readonly bool makeVariableWeightsVariable;
    13677    private readonly bool addLinearScalingTerms;
    13778
    138     private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
     79    //private readonly List<double> initialConstants = new List<double>();
     80    private readonly Dictionary<DataForVariable, Tensor> parameters = new Dictionary<DataForVariable, Tensor>();
     81    private readonly List<Tensor> variables = new List<Tensor>();
     82
     83    private TreeToTensorConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
    13984      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    14085      this.addLinearScalingTerms = addLinearScalingTerms;
    141       this.initialConstants = new List<double>();
    142       this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    143       this.variables = new List<AutoDiff.Variable>();
    144     }
    145 
    146     private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
     86    }
     87
     88
     89    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
    14790      if (node.Symbol is Constant) {
    148         initialConstants.Add(((ConstantTreeNode)node).Value);
    149         var var = new AutoDiff.Variable();
     91        var value = ((ConstantTreeNode)node).Value;
     92        //initialConstants.Add(value);
     93        var var = tf.Variable(value);
    15094        variables.Add(var);
    15195        return var;
    15296      }
     97
    15398      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
    15499        var varNode = node as VariableTreeNodeBase;
     
    159104
    160105        if (makeVariableWeightsVariable) {
    161           initialConstants.Add(varNode.Weight);
    162           var w = new AutoDiff.Variable();
     106          //initialConstants.Add(varNode.Weight);
     107          var w = tf.Variable(varNode.Weight);
    163108          variables.Add(w);
    164           return AutoDiff.TermBuilder.Product(w, par);
     109          return w * par;
    165110        } else {
    166111          return varNode.Weight * par;
    167112        }
    168113      }
     114
    169115      if (node.Symbol is FactorVariable) {
    170116        var factorVarNode = node as FactorVariableTreeNode;
    171         var products = new List<Term>();
     117        var products = new List<Tensor>();
    172118        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    173119          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    174120
    175           initialConstants.Add(factorVarNode.GetValue(variableValue));
    176           var wVar = new AutoDiff.Variable();
     121          var value = factorVarNode.GetValue(variableValue);
     122          //initialConstants.Add(value);
     123          var wVar = tf.Variable(value);
    177124          variables.Add(wVar);
    178125
    179           products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    180         }
    181         return AutoDiff.TermBuilder.Sum(products);
    182       }
    183       if (node.Symbol is LaggedVariable) {
    184         var varNode = node as LaggedVariableTreeNode;
    185         var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    186 
    187         if (makeVariableWeightsVariable) {
    188           initialConstants.Add(varNode.Weight);
    189           var w = new AutoDiff.Variable();
    190           variables.Add(w);
    191           return AutoDiff.TermBuilder.Product(w, par);
    192         } else {
    193           return varNode.Weight * par;
    194         }
    195       }
     126          products.add(wVar * par);
     127        }
     128
     129        return tf.add_n(products.ToArray());
     130      }
     131
    196132      if (node.Symbol is Addition) {
    197         List<AutoDiff.Term> terms = new List<Term>();
     133        var terms = new List<Tensor>();
    198134        foreach (var subTree in node.Subtrees) {
    199           terms.Add(ConvertToAutoDiff(subTree));
    200         }
    201         return AutoDiff.TermBuilder.Sum(terms);
    202       }
     135          terms.Add(ConvertNode(subTree));
     136        }
     137
     138        return tf.add_n(terms.ToArray());
     139      }
     140
    203141      if (node.Symbol is Subtraction) {
    204         List<AutoDiff.Term> terms = new List<Term>();
     142        var terms = new List<Tensor>();
    205143        for (int i = 0; i < node.SubtreeCount; i++) {
    206           AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
     144          var t = ConvertNode(node.GetSubtree(i));
    207145          if (i > 0) t = -t;
    208146          terms.Add(t);
    209147        }
     148
    210149        if (terms.Count == 1) return -terms[0];
    211         else return AutoDiff.TermBuilder.Sum(terms);
    212       }
     150        else return tf.add_n(terms.ToArray());
     151      }
     152
    213153      if (node.Symbol is Multiplication) {
    214         List<AutoDiff.Term> terms = new List<Term>();
     154        var terms = new List<Tensor>();
    215155        foreach (var subTree in node.Subtrees) {
    216           terms.Add(ConvertToAutoDiff(subTree));
    217         }
     156          terms.Add(ConvertNode(subTree));
     157        }
     158
    218159        if (terms.Count == 1) return terms[0];
    219         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    220       }
     160        else return terms.Aggregate((a, b) => a * b);
     161      }
     162
    221163      if (node.Symbol is Division) {
    222         List<AutoDiff.Term> terms = new List<Term>();
     164        var terms = new List<Tensor>();
    223165        foreach (var subTree in node.Subtrees) {
    224           terms.Add(ConvertToAutoDiff(subTree));
    225         }
     166          terms.Add(ConvertNode(subTree));
     167        }
     168
    226169        if (terms.Count == 1) return 1.0 / terms[0];
    227         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    228       }
     170        else return terms.Aggregate((a, b) => a * (1.0 / b));
     171      }
     172
    229173      if (node.Symbol is Absolute) {
    230         var x1 = ConvertToAutoDiff(node.GetSubtree(0));
    231         return abs(x1);
    232       }
     174        var x1 = ConvertNode(node.GetSubtree(0));
     175        return tf.abs(x1);
     176      }
     177
    233178      if (node.Symbol is AnalyticQuotient) {
    234         var x1 = ConvertToAutoDiff(node.GetSubtree(0));
    235         var x2 = ConvertToAutoDiff(node.GetSubtree(1));
    236         return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
    237       }
     179        var x1 = ConvertNode(node.GetSubtree(0));
     180        var x2 = ConvertNode(node.GetSubtree(1));
     181        return x1 / tf.pow(1 + x2 * x2, 0.5);
     182      }
     183
    238184      if (node.Symbol is Logarithm) {
    239         return AutoDiff.TermBuilder.Log(
    240           ConvertToAutoDiff(node.GetSubtree(0)));
    241       }
     185        return math_ops.log(
     186          ConvertNode(node.GetSubtree(0)));
     187      }
     188
    242189      if (node.Symbol is Exponential) {
    243         return AutoDiff.TermBuilder.Exp(
    244           ConvertToAutoDiff(node.GetSubtree(0)));
    245       }
     190        return math_ops.pow(
     191          Math.E,
     192          ConvertNode(node.GetSubtree(0)));
     193      }
     194
    246195      if (node.Symbol is Square) {
    247         return AutoDiff.TermBuilder.Power(
    248           ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
    249       }
     196        return tf.square(
     197          ConvertNode(node.GetSubtree(0)));
     198      }
     199
    250200      if (node.Symbol is SquareRoot) {
    251         return AutoDiff.TermBuilder.Power(
    252           ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
    253       }
     201        return math_ops.sqrt(
     202          ConvertNode(node.GetSubtree(0)));
     203      }
     204
    254205      if (node.Symbol is Cube) {
    255         return AutoDiff.TermBuilder.Power(
    256           ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
    257       }
     206        return math_ops.pow(
     207          ConvertNode(node.GetSubtree(0)), 3.0);
     208      }
     209
    258210      if (node.Symbol is CubeRoot) {
    259         return cbrt(ConvertToAutoDiff(node.GetSubtree(0)));
    260       }
     211        return math_ops.pow(
     212          ConvertNode(node.GetSubtree(0)), 1.0 / 3.0);
     213        // TODO
     214        // f: x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
     215        // g:  { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
     216      }
     217
    261218      if (node.Symbol is Sine) {
    262         return sin(
    263           ConvertToAutoDiff(node.GetSubtree(0)));
    264       }
     219        return tf.sin(
     220          ConvertNode(node.GetSubtree(0)));
     221      }
     222
    265223      if (node.Symbol is Cosine) {
    266         return cos(
    267           ConvertToAutoDiff(node.GetSubtree(0)));
    268       }
     224        return tf.cos(
     225          ConvertNode(node.GetSubtree(0)));
     226      }
     227
    269228      if (node.Symbol is Tangent) {
    270         return tan(
    271           ConvertToAutoDiff(node.GetSubtree(0)));
    272       }
    273       if (node.Symbol is HyperbolicTangent) {
    274         return tanh(
    275           ConvertToAutoDiff(node.GetSubtree(0)));
    276       }
    277       if (node.Symbol is Erf) {
    278         return erf(
    279           ConvertToAutoDiff(node.GetSubtree(0)));
    280       }
    281       if (node.Symbol is Norm) {
    282         return norm(
    283           ConvertToAutoDiff(node.GetSubtree(0)));
    284       }
     229        return tf.tan(
     230          ConvertNode(node.GetSubtree(0)));
     231      }
     232
     233      if (node.Symbol is Mean) {
     234        return tf.reduce_mean(
     235          ConvertNode(node.GetSubtree(0)));
     236      }
     237
     238      //if (node.Symbol is StandardDeviation) {
     239      //  return tf.reduce_std(
     240      //    ConvertNode(node.GetSubtree(0)));
     241      //}
     242
     243      if (node.Symbol is Sum) {
     244        return tf.reduce_sum(
     245          ConvertNode(node.GetSubtree(0)));
     246      }
     247
    285248      if (node.Symbol is StartSymbol) {
    286249        if (addLinearScalingTerms) {
    287250          // scaling variables α, β are given at the beginning of the parameter vector
    288           var alpha = new AutoDiff.Variable();
    289           var beta = new AutoDiff.Variable();
     251          var alpha = tf.Variable(1.0);
     252          var beta = tf.Variable(0.0);
    290253          variables.Add(beta);
    291254          variables.Add(alpha);
    292           var t = ConvertToAutoDiff(node.GetSubtree(0));
     255          var t = ConvertNode(node.GetSubtree(0));
    293256          return t * alpha + beta;
    294         } else return ConvertToAutoDiff(node.GetSubtree(0));
    295       }
    296       throw new ConversionException();
    297     }
    298 
     257        } else return ConvertNode(node.GetSubtree(0));
     258      }
     259
     260      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
     261    }
    299262
    300263    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    301264    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    302     private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    303       string varName, string varValue = "", int lag = 0) {
    304       var data = new DataForVariable(varName, varValue, lag);
    305 
    306       AutoDiff.Variable par = null;
    307       if (!parameters.TryGetValue(data, out par)) {
     265    private static Tensor FindOrCreateParameter(Dictionary<DataForVariable, Tensor> parameters, string varName, string varValue = "") {
     266      var data = new DataForVariable(varName, varValue);
     267
     268      if (!parameters.TryGetValue(data, out var par)) {
    308269        // not found -> create new parameter and entries in names and values lists
    309         par = new AutoDiff.Variable();
     270        par = tf.placeholder(tf.float64, name: varName);
    310271        parameters.Add(data, par);
    311272      }
     
    320281          !(n.Symbol is BinaryFactorVariable) &&
    321282          !(n.Symbol is FactorVariable) &&
    322           !(n.Symbol is LaggedVariable) &&
    323283          !(n.Symbol is Constant) &&
    324284          !(n.Symbol is Addition) &&
     
    340300          !(n.Symbol is AnalyticQuotient) &&
    341301          !(n.Symbol is Cube) &&
    342           !(n.Symbol is CubeRoot)
     302          !(n.Symbol is CubeRoot) &&
     303          !(n.Symbol is Mean) &&
     304          //!(n.Symbol is StandardDeviation) &&
     305          !(n.Symbol is Sum)
    343306        select n).Any();
    344307      return !containsUnknownSymbol;
    345308    }
    346     #region exception class
    347     [Serializable]
    348     public class ConversionException : Exception {
    349 
    350       public ConversionException() {
    351       }
    352 
    353       public ConversionException(string message) : base(message) {
    354       }
    355 
    356       public ConversionException(string message, Exception inner) : base(message, inner) {
    357       }
    358 
    359       protected ConversionException(
    360         SerializationInfo info,
    361         StreamingContext context) : base(info, context) {
    362       }
    363     }
    364     #endregion
    365309  }
    366310}
Note: See TracChangeset for help on using the changeset viewer.