Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs @ 18238

Last change on this file since 18238 was 18238, checked in by pfleck, 19 months ago

#3040 Print MSE progress for constant opt in simplifier.

File size: 11.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22#define EXPLICIT_SHAPE
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using NumSharp;
29using Tensorflow;
30using static Tensorflow.Binding;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  public class TreeToTensorConverter {
34
35    private static readonly TF_DataType DataType = tf.float32;
36
37    public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths,
38      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
39      out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {
40
41      try {
42        var converter = new TreeToTensorConverter(numRows, variableLengths, makeVariableWeightsVariable, addLinearScalingTerms);
43        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
44
45        parameters = converter.parameters;
46        variables = converter.variables;
47        return true;
48      } catch (NotSupportedException) {
49        graph = null;
50        parameters = null;
51        variables = null;
52        return false;
53      }
54    }
55
56    private readonly int numRows;
57    private readonly Dictionary<string, int> variableLengths;
58    private readonly bool makeVariableWeightsVariable;
59    private readonly bool addLinearScalingTerms;
60
61    private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
62    private readonly List<Tensor> variables = new List<Tensor>();
63
64    private TreeToTensorConverter(int numRows, Dictionary<string, int> variableLengths, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
65      this.numRows = numRows;
66      this.variableLengths = variableLengths;
67      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
68      this.addLinearScalingTerms = addLinearScalingTerms;
69    }
70
71
72    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
73      if (node.Symbol is Constant) {
74        var value = (float)((ConstantTreeNode)node).Value;
75        var value_arr = np.array(value).reshape(1, 1);
76        var var = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: DataType);
77        variables.Add(var);
78        return var;
79      }
80
81      if (node.Symbol is Variable/* || node.Symbol is BinaryFactorVariable*/) {
82        var varNode = node as VariableTreeNodeBase;
83        //var factorVarNode = node as BinaryFactorVariableTreeNode;
84        // factor variable values are only 0 or 1 and set in x accordingly
85        //var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
86        //var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
87        var par = tf.placeholder(DataType, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
88        parameters.Add(par, varNode.VariableName);
89
90        if (makeVariableWeightsVariable) {
91          var w_arr = np.array((float)varNode.Weight).reshape(1, 1);
92          var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: DataType);
93          variables.Add(w);
94          return w * par;
95        } else {
96          return varNode.Weight * par;
97        }
98      }
99
100      //if (node.Symbol is FactorVariable) {
101      //  var factorVarNode = node as FactorVariableTreeNode;
102      //  var products = new List<Tensor>();
103      //  foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
104      //    //var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
105      //    var par = tf.placeholder(DataType, new TensorShape(numRows, 1), name: factorVarNode.VariableName);
106      //    parameters.Add(par, factorVarNode.VariableName);
107
108      //    var value = factorVarNode.GetValue(variableValue);
109      //    //initialConstants.Add(value);
110      //    var wVar = (RefVariable)tf.VariableV1(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: DataType, shape: new[] { 1, 1 });
111      //    //var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}"/*, shape: new[] { 1, 1 }*/);
112      //    variables.Add(wVar);
113
114      //    products.add(wVar * par);
115      //  }
116
117      //  return products.Aggregate((a, b) => a + b);
118      //}
119
120      if (node.Symbol is Addition) {
121        var terms = node.Subtrees.Select(ConvertNode).ToList();
122        if (terms.Count == 1) return terms[0];
123        return terms.Aggregate((a, b) => a + b);
124      }
125
126      if (node.Symbol is Subtraction) {
127        var terms = node.Subtrees.Select(ConvertNode).ToList();
128        if (terms.Count == 1) return -terms[0];
129        return terms.Aggregate((a, b) => a - b);
130      }
131
132      if (node.Symbol is Multiplication) {
133        var terms = node.Subtrees.Select(ConvertNode).ToList();
134        if (terms.Count == 1) return terms[0];
135        return terms.Aggregate((a, b) => a * b);
136      }
137
138      if (node.Symbol is Division) {
139        var terms = node.Subtrees.Select(ConvertNode).ToList();
140        if (terms.Count == 1) return 1.0f / terms[0];
141        return terms.Aggregate((a, b) => a / b);
142      }
143
144      if (node.Symbol is Absolute) {
145        var x1 = ConvertNode(node.GetSubtree(0));
146        return tf.abs(x1);
147      }
148
149      if (node.Symbol is AnalyticQuotient) {
150        var x1 = ConvertNode(node.GetSubtree(0));
151        var x2 = ConvertNode(node.GetSubtree(1));
152        return x1 / tf.pow(1.0f + x2 * x2, 0.5f);
153      }
154
155      if (node.Symbol is Logarithm) {
156        return tf.log(
157          ConvertNode(node.GetSubtree(0)));
158      }
159
160      if (node.Symbol is Exponential) {
161        return tf.pow(
162          (float)Math.E,
163          ConvertNode(node.GetSubtree(0)));
164      }
165
166      if (node.Symbol is Square) {
167        return tf.square(
168          ConvertNode(node.GetSubtree(0)));
169      }
170
171      if (node.Symbol is SquareRoot) {
172        return tf.sqrt(
173          ConvertNode(node.GetSubtree(0)));
174      }
175
176      if (node.Symbol is Cube) {
177        return tf.pow(
178          ConvertNode(node.GetSubtree(0)), 3.0f);
179      }
180
181      if (node.Symbol is CubeRoot) {
182        return tf.pow(
183          ConvertNode(node.GetSubtree(0)), 1.0f / 3.0f);
184        // TODO
185        // f: x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
186        // g:  { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
187      }
188
189      if (node.Symbol is Sine) {
190        return tf.sin(
191          ConvertNode(node.GetSubtree(0)));
192      }
193
194      if (node.Symbol is Cosine) {
195        return tf.cos(
196          ConvertNode(node.GetSubtree(0)));
197      }
198
199      if (node.Symbol is Tangent) {
200        return tf.tan(
201          ConvertNode(node.GetSubtree(0)));
202      }
203
204      if (node.Symbol is Mean) {
205        return tf.reduce_mean(
206          ConvertNode(node.GetSubtree(0)),
207          axis: new[] { 1 },
208          keepdims: true);
209      }
210
211      if (node.Symbol is StandardDeviation) {
212        return reduce_std(
213          ConvertNode(node.GetSubtree(0)),
214          axis: new[] { 1 },
215          keepdims: true
216       );
217      }
218
219      if (node.Symbol is Variance) {
220        return reduce_var(
221          ConvertNode(node.GetSubtree(0)),
222          axis: new[] { 1 } ,
223          keepdims: true
224        );
225      }
226
227      if (node.Symbol is Sum) {
228        return tf.reduce_sum(
229          ConvertNode(node.GetSubtree(0)),
230          axis: new[] { 1 },
231          keepdims: true);
232      }
233
234      if (node.Symbol is SubVector) {
235        var tensor = ConvertNode(node.GetSubtree(0));
236        int rows = tensor.shape[0], vectorLength = tensor.shape[1];
237        var windowedNode = (IWindowedSymbolTreeNode)node;
238        int startIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Offset, vectorLength);
239        int endIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Length, vectorLength);
240        var slices = SymbolicDataAnalysisExpressionTreeVectorInterpreter.GetVectorSlices(startIdx, endIdx, vectorLength);
241
242        var segments = new List<Tensor>();
243        foreach (var (start, count) in slices) {
244          segments.Add(tensor[new Slice(), new Slice(start, start + count)]);
245        }
246        return tf.concat(segments.ToArray(), axis: 1);
247      }
248
249
250      if (node.Symbol is StartSymbol) {
251        Tensor prediction;
252        if (addLinearScalingTerms) {
253          // scaling variables α, β are given at the beginning of the parameter vector
254          var alpha_arr = np.array(1.0f).reshape(1, 1);
255          var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: DataType);
256          var beta_arr = np.array(0.0f).reshape(1, 1);
257          var beta = tf.Variable(beta_arr, name: "beta", dtype: DataType);
258          variables.Add(beta);
259          variables.Add(alpha);
260          var t = ConvertNode(node.GetSubtree(0));
261          prediction = t * alpha + beta;
262        } else {
263          prediction = ConvertNode(node.GetSubtree(0));
264        }
265
266        return tf.reshape(prediction, shape: new[] { -1 });
267      }
268
269      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
270    }
271
272    private static Tensor reduce_var(Tensor input_tensor,  int[] axis = null, bool keepdims = false) {
273      var means = tf.reduce_mean(input_tensor, axis, true);
274      var squared_deviation = tf.square(input_tensor - means);
275      return tf.reduce_mean(squared_deviation, axis, keepdims);
276    }
277    private static Tensor reduce_std(Tensor input_tensor, int[] axis = null, bool keepdims = false) {
278      return tf.sqrt(reduce_var(input_tensor, axis, keepdims));
279    }
280
281    public static bool IsCompatible(ISymbolicExpressionTree tree) {
282      var containsUnknownSymbol = (
283        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
284        where
285          !(n.Symbol is Variable) &&
286          //!(n.Symbol is BinaryFactorVariable) &&
287          //!(n.Symbol is FactorVariable) &&
288          !(n.Symbol is Constant) &&
289          !(n.Symbol is Addition) &&
290          !(n.Symbol is Subtraction) &&
291          !(n.Symbol is Multiplication) &&
292          !(n.Symbol is Division) &&
293          !(n.Symbol is Logarithm) &&
294          !(n.Symbol is Exponential) &&
295          !(n.Symbol is SquareRoot) &&
296          !(n.Symbol is Square) &&
297          !(n.Symbol is Sine) &&
298          !(n.Symbol is Cosine) &&
299          !(n.Symbol is Tangent) &&
300          !(n.Symbol is HyperbolicTangent) &&
301          !(n.Symbol is Erf) &&
302          !(n.Symbol is Norm) &&
303          !(n.Symbol is StartSymbol) &&
304          !(n.Symbol is Absolute) &&
305          !(n.Symbol is AnalyticQuotient) &&
306          !(n.Symbol is Cube) &&
307          !(n.Symbol is CubeRoot) &&
308          !(n.Symbol is Mean) &&
309          !(n.Symbol is StandardDeviation) &&
310          !(n.Symbol is Variance) &&
311          !(n.Symbol is Sum) &&
312          !(n.Symbol is SubVector)
313        select n).Any();
314      return !containsUnknownSymbol;
315    }
316  }
317}
Note: See TracBrowser for help on using the repository browser.