Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs @ 17489

Last change on this file since 17489 was 17489, checked in by pfleck, 4 years ago

#3040 Added version with explicit array shapes for explicit broadcasting.

File size: 12.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22#define EXPLICIT_SHAPE
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using NumSharp;
29using Tensorflow;
30using static Tensorflow.Binding;
31using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
#if !EXPLICIT_SHAPE
#error TreeToTensorConverter only implements explicit [rows x length] tensor shapes; EXPLICIT_SHAPE must be defined.
#endif

  /// <summary>
  /// Converts a symbolic expression tree into a TensorFlow (TensorFlow.NET) graph.
  /// Every occurrence of a variable becomes a float64 placeholder of shape [numRows, variableLength];
  /// constants, (optionally) variable weights and (optionally) linear scaling terms become trainable
  /// 1x1 TensorFlow variables so that they broadcast explicitly against the data placeholders.
  /// </summary>
  public class TreeToTensorConverter {

    #region helper class
    /// <summary>
    /// Value-equality key for a (variable name, variable value) pair; the value is meant for factor variables.
    /// Not used by <see cref="ConvertNode"/> at present (placeholders are not shared between occurrences).
    /// </summary>
    public class DataForVariable {
      public readonly string variableName;
      public readonly string variableValue; // for factor vars

      public DataForVariable(string varName, string varValue) {
        this.variableName = varName;
        this.variableValue = varValue;
      }

      public override bool Equals(object obj) {
        var other = obj as DataForVariable;
        if (other == null) return false;
        // static string.Equals is null-safe; calling Equals on a null field would throw
        return string.Equals(other.variableName, this.variableName) &&
               string.Equals(other.variableValue, this.variableValue);
      }

      public override int GetHashCode() {
        var nameHash = variableName == null ? 0 : variableName.GetHashCode();
        var valueHash = variableValue == null ? 0 : variableValue.GetHashCode();
        return nameHash ^ valueHash;
      }
    }
    #endregion

    /// <summary>
    /// Tries to convert <paramref name="tree"/>, starting at the first subtree of its root, into a TensorFlow graph.
    /// </summary>
    /// <param name="tree">The tree to convert.</param>
    /// <param name="numRows">First dimension of every variable placeholder.</param>
    /// <param name="variableLengths">Second placeholder dimension per variable name; every variable used in the tree must have an entry.</param>
    /// <param name="makeVariableWeightsVariable">If true, each variable weight becomes a trainable 1x1 variable; otherwise the weight is used as a fixed scalar factor.</param>
    /// <param name="addLinearScalingTerms">If true, the output of the start node is transformed as <c>t * alpha + beta</c> with trainable alpha (initially 1) and beta (initially 0).</param>
    /// <param name="graph">The tensor computing the tree's output, or null if conversion failed.</param>
    /// <param name="parameters">Maps each created placeholder to its variable name (one placeholder per variable occurrence), or null if conversion failed.</param>
    /// <param name="variables">All trainable TensorFlow variables in creation order, or null if conversion failed.</param>
    /// <returns>True if the tree was converted; false if it contains an unsupported symbol.</returns>
    /// <exception cref="KeyNotFoundException">A variable in the tree has no entry in <paramref name="variableLengths"/>.</exception>
    public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths,
      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
      out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {

      try {
        var converter = new TreeToTensorConverter(numRows, variableLengths, makeVariableWeightsVariable, addLinearScalingTerms);
        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
        parameters = converter.parameters;
        variables = converter.variables;
        return true;
      } catch (NotSupportedException) {
        // NOTE(review): ops created before the unsupported node was reached are not removed from the TensorFlow graph.
        graph = null;
        parameters = null;
        variables = null;
        return false;
      }
    }

    private readonly int numRows;
    private readonly Dictionary<string, int> variableLengths;
    private readonly bool makeVariableWeightsVariable;
    private readonly bool addLinearScalingTerms;

    // placeholder tensor -> variable name, one entry per variable occurrence
    private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
    // trainable variables in the order they are created during conversion
    private readonly List<Tensor> variables = new List<Tensor>();

    private TreeToTensorConverter(int numRows, Dictionary<string, int> variableLengths, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
      this.numRows = numRows;
      this.variableLengths = variableLengths;
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.addLinearScalingTerms = addLinearScalingTerms;
    }

    /// <summary>
    /// Recursively converts <paramref name="node"/> and its subtrees into a tensor expression.
    /// Trainable variables are appended to <see cref="variables"/> and placeholders registered in
    /// <see cref="parameters"/> in the order the nodes are visited.
    /// The set of handled symbols must stay in sync with <see cref="IsCompatible"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The node or one of its descendants has an unsupported symbol.</exception>
    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Constant) {
        var value = ((ConstantTreeNode)node).Value;
        // 1x1 so the constant broadcasts explicitly against [numRows, length] data
        var value_arr = np.array(value).reshape(1, 1);
        var constantVar = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: tf.float64);
        variables.Add(constantVar);
        return constantVar;
      }

      if (node.Symbol is Variable) {
        var varNode = (VariableTreeNodeBase)node;
        // a fresh placeholder is created for every occurrence of the variable (no sharing between occurrences);
        // the dictionary indexer throws KeyNotFoundException if the variable has no registered length
        var par = tf.placeholder(tf.float64, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
        parameters.Add(par, varNode.VariableName);

        if (makeVariableWeightsVariable) {
          var w_arr = np.array(varNode.Weight).reshape(1, 1);
          var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: tf.float64);
          variables.Add(w);
          return w * par;
        } else {
          return varNode.Weight * par;
        }
      }

      // TODO: FactorVariable / BinaryFactorVariable support
      // (one placeholder and one trainable 1x1 weight per factor value, summed up).

      if (node.Symbol is Addition) {
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        return terms.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Subtraction) {
        var terms = new List<Tensor>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          var t = ConvertNode(node.GetSubtree(i));
          if (i > 0) t = -t;
          terms.Add(t);
        }

        // a single argument means unary negation
        if (terms.Count == 1) return -terms[0];
        else return terms.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Multiplication) {
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        if (terms.Count == 1) return terms[0];
        else return terms.Aggregate((a, b) => a * b);
      }

      if (node.Symbol is Division) {
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        // a single argument means the reciprocal; otherwise divide left to right
        if (terms.Count == 1) return 1.0 / terms[0];
        else return terms.Aggregate((a, b) => a / b);
      }

      if (node.Symbol is Absolute) {
        var x1 = ConvertNode(node.GetSubtree(0));
        return tf.abs(x1);
      }

      if (node.Symbol is AnalyticQuotient) {
        // AQ(x1, x2) = x1 / sqrt(1 + x2^2)
        var x1 = ConvertNode(node.GetSubtree(0));
        var x2 = ConvertNode(node.GetSubtree(1));
        return x1 / tf.pow(1 + x2 * x2, 0.5);
      }

      if (node.Symbol is Logarithm) {
        return math_ops.log(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Exponential) {
        return math_ops.pow(
          Math.E,
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Square) {
        return tf.square(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is SquareRoot) {
        return math_ops.sqrt(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cube) {
        return math_ops.pow(
          ConvertNode(node.GetSubtree(0)), 3.0);
      }

      if (node.Symbol is CubeRoot) {
        // NOTE(review): pow(x, 1/3) is expected to yield NaN for negative x, unlike the real cube root
        // f(x) = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3).
        // TODO: implement the sign-preserving variant (and its derivative 1 / (3 * cbrt(x)^2)).
        return math_ops.pow(
          ConvertNode(node.GetSubtree(0)), 1.0 / 3.0);
      }

      if (node.Symbol is Sine) {
        return tf.sin(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cosine) {
        return tf.cos(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Tangent) {
        return tf.tan(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Mean) {
        // reduce along axis 1 and keep it, so the result stays 2-D ([rows, 1])
        return tf.reduce_mean(
          ConvertNode(node.GetSubtree(0)),
          axis: new[] { 1 },
          keepdims: true);
      }

      // TODO: StandardDeviation (reduce_std along axis 1).

      if (node.Symbol is Sum) {
        // reduce along axis 1 and keep it, so the result stays 2-D ([rows, 1])
        return tf.reduce_sum(
          ConvertNode(node.GetSubtree(0)),
          axis: new[] { 1 },
          keepdims: true);
      }

      if (node.Symbol is StartSymbol) {
        if (!addLinearScalingTerms) return ConvertNode(node.GetSubtree(0));

        // The scaling variables alpha and beta are created before the subtree is converted,
        // so they precede all of the subtree's variables in the variable list.
        // They start as the identity transform (alpha = 1, beta = 0) so that scaling
        // does not alter the model output before optimization.
        var alpha_arr = np.array(1.0).reshape(1, 1);
        var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: tf.float64);
        var beta_arr = np.array(0.0).reshape(1, 1);
        var beta = tf.Variable(beta_arr, name: "beta", dtype: tf.float64);
        variables.Add(alpha);
        variables.Add(beta);
        var t = ConvertNode(node.GetSubtree(0));
        return t * alpha + beta;
      }

      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
    }

    /// <summary>
    /// Returns true if every node below the root of <paramref name="tree"/> uses a symbol that
    /// <see cref="ConvertNode"/> can translate, i.e. <see cref="TryConvert"/> will not fail due to an unsupported symbol.
    /// This list must mirror the symbols handled in <see cref="ConvertNode"/>.
    /// </summary>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Mean) &&
          !(n.Symbol is Sum)
        select n).Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.