Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs @ 17502

Last change on this file since 17502 was 17502, checked in by pfleck, 4 years ago

#3040

  • Switched whole TF-graph to float (Adam optimizer won't work with double).
  • Added progress and cancellation support for TF-const opt.
  • Added optional logging with console and/or file for later plotting.
File size: 10.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22#define EXPLICIT_SHAPE
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using NumSharp;
29using Tensorflow;
30using static Tensorflow.Binding;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Converts a symbolic expression tree into a TensorFlow graph.
/// Constants (and optionally variable weights and linear scaling terms) become
/// TensorFlow variables, input variables become placeholders.
/// </summary>
public class TreeToTensorConverter {

  // Element type used for every variable and placeholder created by this converter.
  private static readonly TF_DataType DataType = tf.float32;

  /// <summary>
  /// Tries to build a tensor graph for <paramref name="tree"/>.
  /// </summary>
  /// <param name="tree">Tree to convert; conversion starts at the first subtree of the root.</param>
  /// <param name="numRows">First dimension of every created input placeholder.</param>
  /// <param name="variableLengths">Second dimension of the placeholder, looked up by variable name.</param>
  /// <param name="makeVariableWeightsVariable">If true, each variable weight becomes a TensorFlow variable; otherwise it is multiplied in as a fixed factor.</param>
  /// <param name="addLinearScalingTerms">If true, the start symbol adds the variables "alpha" and "beta" and returns <c>t * alpha + beta</c>.</param>
  /// <param name="graph">The resulting tensor, or null if the conversion failed.</param>
  /// <param name="parameters">Maps each created placeholder to its variable name, or null if the conversion failed.</param>
  /// <param name="variables">All created TensorFlow variables in creation order, or null if the conversion failed.</param>
  /// <returns>
  /// True on success; false if a symbol of the tree is not supported.
  /// Only <see cref="NotSupportedException"/> is caught, any other exception propagates to the caller.
  /// </returns>
  public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths,
    bool makeVariableWeightsVariable, bool addLinearScalingTerms,
    out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {

    try {
      var converter = new TreeToTensorConverter(numRows, variableLengths, makeVariableWeightsVariable, addLinearScalingTerms);
      graph = converter.ConvertNode(tree.Root.GetSubtree(0));

      parameters = converter.parameters;
      variables = converter.variables;
      return true;
    } catch (NotSupportedException) {
      graph = null;
      parameters = null;
      variables = null;
      return false;
    }
  }

  private readonly int numRows;
  private readonly Dictionary<string, int> variableLengths;
  private readonly bool makeVariableWeightsVariable;
  private readonly bool addLinearScalingTerms;

  // Filled while converting: placeholder -> variable name, and all trainable variables in creation order.
  private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
  private readonly List<Tensor> variables = new List<Tensor>();

  // Instances only exist for the duration of a single TryConvert call.
  private TreeToTensorConverter(int numRows, Dictionary<string, int> variableLengths, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
    this.numRows = numRows;
    this.variableLengths = variableLengths;
    this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    this.addLinearScalingTerms = addLinearScalingTerms;
  }


  /// <summary>
  /// Recursively converts <paramref name="node"/> and its subtrees into a tensor expression.
  /// </summary>
  /// <exception cref="NotSupportedException">The symbol of the node has no conversion.</exception>
  private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
    if (node.Symbol is Constant) {
      // Every constant becomes its own 1x1 variable, named by its index in the variable list.
      var value = (float)((ConstantTreeNode)node).Value;
      var value_arr = np.array(value).reshape(1, 1);
      var constant = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: DataType);
      variables.Add(constant);
      return constant;
    }

    if (node.Symbol is Variable/* || node.Symbol is BinaryFactorVariable*/) {
      // The symbol check above guarantees the node type, so cast directly instead of using 'as'.
      var varNode = (VariableTreeNodeBase)node;
      //var factorVarNode = node as BinaryFactorVariableTreeNode;
      // factor variable values are only 0 or 1 and set in x accordingly
      //var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
      //var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);

      // A new placeholder is created for every variable node, i.e. a variable that occurs
      // several times in the tree gets one placeholder per occurrence.
      var par = tf.placeholder(DataType, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
      parameters.Add(par, varNode.VariableName);

      if (makeVariableWeightsVariable) {
        var w_arr = np.array((float)varNode.Weight).reshape(1, 1);
        var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: DataType);
        variables.Add(w);
        return w * par;
      } else {
        // Cast to float like the trainable weight above, so the factor matches the float32 graph.
        return (float)varNode.Weight * par;
      }
    }

    //if (node.Symbol is FactorVariable) {
    //  var factorVarNode = node as FactorVariableTreeNode;
    //  var products = new List<Tensor>();
    //  foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    //    //var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    //    var par = tf.placeholder(DataType, new TensorShape(numRows, 1), name: factorVarNode.VariableName);
    //    parameters.Add(par, factorVarNode.VariableName);

    //    var value = factorVarNode.GetValue(variableValue);
    //    //initialConstants.Add(value);
    //    var wVar = (RefVariable)tf.VariableV1(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: DataType, shape: new[] { 1, 1 });
    //    //var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}"/*, shape: new[] { 1, 1 }*/);
    //    variables.Add(wVar);

    //    products.add(wVar * par);
    //  }

    //  return products.Aggregate((a, b) => a + b);
    //}

    if (node.Symbol is Addition) {
      var terms = node.Subtrees.Select(ConvertNode).ToList();
      return terms.Aggregate((a, b) => a + b);
    }

    if (node.Symbol is Subtraction) {
      var terms = node.Subtrees.Select(ConvertNode).ToList();
      // A single argument is negated.
      if (terms.Count == 1) return -terms[0];
      return terms.Aggregate((a, b) => a - b);
    }

    if (node.Symbol is Multiplication) {
      var terms = node.Subtrees.Select(ConvertNode).ToList();
      return terms.Aggregate((a, b) => a * b);
    }

    if (node.Symbol is Division) {
      var terms = node.Subtrees.Select(ConvertNode).ToList();
      // A single argument is inverted.
      if (terms.Count == 1) return 1.0f / terms[0];
      return terms.Aggregate((a, b) => a / b);
    }

    if (node.Symbol is Absolute) {
      var x1 = ConvertNode(node.GetSubtree(0));
      return tf.abs(x1);
    }

    if (node.Symbol is AnalyticQuotient) {
      // x1 / sqrt(1 + x2^2)
      var x1 = ConvertNode(node.GetSubtree(0));
      var x2 = ConvertNode(node.GetSubtree(1));
      return x1 / tf.pow(1.0f + x2 * x2, 0.5f);
    }

    if (node.Symbol is Logarithm) {
      return math_ops.log(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Exponential) {
      // The base is cast to float because the whole graph uses float32 (see DataType);
      // passing the double Math.E would mix element types.
      // NOTE(review): how the library resolves a double base with a float32 exponent was not verified.
      return math_ops.pow(
        (float)Math.E,
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Square) {
      return tf.square(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is SquareRoot) {
      return math_ops.sqrt(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Cube) {
      return math_ops.pow(
        ConvertNode(node.GetSubtree(0)), 3.0f);
    }

    if (node.Symbol is CubeRoot) {
      return math_ops.pow(
        ConvertNode(node.GetSubtree(0)), 1.0f / 3.0f);
      // TODO: negative arguments are not handled like in the reference definition below
      // f: x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
      // g:  { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
    }

    if (node.Symbol is Sine) {
      return tf.sin(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Cosine) {
      return tf.cos(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Tangent) {
      return tf.tan(
        ConvertNode(node.GetSubtree(0)));
    }

    if (node.Symbol is Mean) {
      // Reduce along axis 1 and keep the reduced dimension.
      return tf.reduce_mean(
        ConvertNode(node.GetSubtree(0)),
        axis: new[] { 1 },
        keepdims: true);
    }

    //if (node.Symbol is StandardDeviation) {
    //  return tf.reduce_std(
    //    ConvertNode(node.GetSubtree(0)),
    //    axis: new [] { 1 }
    // );
    //}

    if (node.Symbol is Sum) {
      // Reduce along axis 1 and keep the reduced dimension.
      return tf.reduce_sum(
        ConvertNode(node.GetSubtree(0)),
        axis: new[] { 1 },
        keepdims: true);
    }

    if (node.Symbol is StartSymbol) {
      Tensor prediction;
      if (addLinearScalingTerms) {
        // The scaling variables alpha and beta are added before the subtree is converted,
        // so they come first in the variable list.
        var alpha_arr = np.array(1.0f).reshape(1, 1);
        var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: DataType);
        var beta_arr = np.array(0.0f).reshape(1, 1);
        var beta = tf.Variable(beta_arr, name: "beta", dtype: DataType);
        variables.Add(alpha);
        variables.Add(beta);
        var t = ConvertNode(node.GetSubtree(0));
        prediction = t * alpha + beta;
      } else {
        prediction = ConvertNode(node.GetSubtree(0));
      }

      // The final result is summed along axis 1 without keeping that dimension.
      return tf.reduce_sum(prediction, axis: new[] { 1 });
    }

    throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
  }

  /// <summary>
  /// Checks whether every symbol of <paramref name="tree"/> (starting at the first subtree of the root)
  /// can be converted by this converter.
  /// </summary>
  /// <remarks>
  /// The accepted symbols must match the symbols handled during conversion. HyperbolicTangent, Erf and
  /// Norm are not accepted because no conversion exists for them; accepting them would report trees as
  /// compatible for which <see cref="TryConvert"/> returns false.
  /// </remarks>
  public static bool IsCompatible(ISymbolicExpressionTree tree) {
    var containsUnknownSymbol = tree.Root.GetSubtree(0).IterateNodesPrefix().Any(n =>
      !(n.Symbol is Variable) &&
      //!(n.Symbol is BinaryFactorVariable) &&
      //!(n.Symbol is FactorVariable) &&
      !(n.Symbol is Constant) &&
      !(n.Symbol is Addition) &&
      !(n.Symbol is Subtraction) &&
      !(n.Symbol is Multiplication) &&
      !(n.Symbol is Division) &&
      !(n.Symbol is Logarithm) &&
      !(n.Symbol is Exponential) &&
      !(n.Symbol is SquareRoot) &&
      !(n.Symbol is Square) &&
      !(n.Symbol is Sine) &&
      !(n.Symbol is Cosine) &&
      !(n.Symbol is Tangent) &&
      !(n.Symbol is StartSymbol) &&
      !(n.Symbol is Absolute) &&
      !(n.Symbol is AnalyticQuotient) &&
      !(n.Symbol is Cube) &&
      !(n.Symbol is CubeRoot) &&
      !(n.Symbol is Mean) &&
      //!(n.Symbol is StandardDeviation) &&
      !(n.Symbol is Sum));
    return !containsUnknownSymbol;
  }
}
279}
Note: See TracBrowser for help on using the repository browser.