Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs @ 17476

Last change on this file since 17476 was 17476, checked in by pfleck, 4 years ago

#3040 Worked on TF-based constant optimization.

File size: 11.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
26using Tensorflow;
27using static Tensorflow.Binding;
28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Converts a symbolic expression tree into a TensorFlow graph.
  /// Input variables become float64 placeholders. Constants (and optionally variable weights
  /// and linear-scaling terms) become trainable float64 tf variables.
  /// </summary>
  public class TreeToTensorConverter {

    #region helper class
    /// <summary>
    /// Value-equality key made of a variable name and, for factor variables, one of its values.
    /// Used by <see cref="FindOrCreateParameter"/> to share one placeholder per name/value combination.
    /// </summary>
    public class DataForVariable {
      public readonly string variableName;
      public readonly string variableValue; // for factor vars

      public DataForVariable(string varName, string varValue) {
        this.variableName = varName;
        this.variableValue = varValue;
      }

      public override bool Equals(object obj) {
        var other = obj as DataForVariable;
        if (other == null) return false;
        // The static string.Equals is null-safe (the former instance call threw on a null field)
        // and keeps the ordinal semantics of string.Equals(string).
        return string.Equals(other.variableName, this.variableName, StringComparison.Ordinal) &&
               string.Equals(other.variableValue, this.variableValue, StringComparison.Ordinal);
      }

      public override int GetHashCode() {
        // Must agree with Equals: null fields hash to 0 instead of throwing.
        return (variableName?.GetHashCode() ?? 0) ^ (variableValue?.GetHashCode() ?? 0);
      }
    }
    #endregion

    /// <summary>
    /// Tries to convert <paramref name="tree"/> into a TensorFlow graph.
    /// </summary>
    /// <param name="tree">Tree to convert; conversion starts at <c>tree.Root.GetSubtree(0)</c>.</param>
    /// <param name="numRows">First dimension of every input placeholder.</param>
    /// <param name="vectorLength">
    /// If set, variable placeholders have shape (numRows, vectorLength); otherwise (numRows).
    /// Mean and Sum nodes can only be converted when this is set.
    /// </param>
    /// <param name="makeVariableWeightsVariable">
    /// If true, each variable weight becomes a trainable tf variable; otherwise the weight is a fixed factor.
    /// </param>
    /// <param name="addLinearScalingTerms">
    /// If true, the start-symbol output is scaled as <c>t * alpha + beta</c>.
    /// alpha and beta are trainable variables and are added to <paramref name="variables"/> first.
    /// </param>
    /// <param name="graph">The resulting output tensor, or null on failure.</param>
    /// <param name="parameters">Maps each created input placeholder to its variable name, or null on failure.</param>
    /// <param name="variables">Trainable tf variables in creation order, or null on failure.</param>
    /// <returns>True on success; false if some node raised <see cref="NotSupportedException"/> during conversion.</returns>
    public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, int? vectorLength,
      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
      out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {

      try {
        var converter = new TreeToTensorConverter(numRows, vectorLength, makeVariableWeightsVariable, addLinearScalingTerms);
        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
        parameters = converter.parameters;
        variables = converter.variables;
        return true;
      } catch (NotSupportedException) {
        graph = null;
        parameters = null;
        variables = null;
        return false;
      }
    }

    private readonly int numRows;
    private readonly int? vectorLength;
    private readonly bool makeVariableWeightsVariable;
    private readonly bool addLinearScalingTerms;

    // placeholder tensor -> variable name it must be fed with
    private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
    // trainable tf variables, in creation order
    private readonly List<Tensor> variables = new List<Tensor>();

    private TreeToTensorConverter(int numRows, int? vectorLength, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
      this.numRows = numRows;
      this.vectorLength = vectorLength;
      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
      this.addLinearScalingTerms = addLinearScalingTerms;
    }


    /// <summary>
    /// Recursively converts <paramref name="node"/> and its subtrees into a tensor expression.
    /// Every placeholder it creates is registered in <see cref="parameters"/>.
    /// Every trainable tf variable it creates is appended to <see cref="variables"/>.
    /// </summary>
    /// <exception cref="NotSupportedException">The node, or a descendant, cannot be converted.</exception>
    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Constant) {
        var value = ((ConstantTreeNode)node).Value;
        var var = tf.Variable(value, name: $"c_{variables.Count}", dtype: tf.float64);
        variables.Add(var);
        return var;
      }

      if (node.Symbol is Variable) {
        var varNode = (VariableTreeNodeBase)node;
        var shape = vectorLength.HasValue
          ? new TensorShape(numRows, vectorLength.Value)
          : new TensorShape(numRows);
        var par = tf.placeholder(tf.float64, shape: shape, name: varNode.VariableName);
        parameters.Add(par, varNode.VariableName);

        if (makeVariableWeightsVariable) {
          var w = tf.Variable(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64);
          variables.Add(w);
          return w * par;
        } else {
          return varNode.Weight * par;
        }
      }

      if (node.Symbol is FactorVariable) {
        // NOTE(review): one indicator placeholder is created per variable value.
        // However, `parameters` only records the variable name, so a caller cannot tell which
        // value a placeholder belongs to. This is presumably why IsCompatible still excludes
        // factor variables.
        var factorVarNode = (FactorVariableTreeNode)node;
        var products = new List<Tensor>();
        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
          var par = tf.placeholder(tf.float64, shape: new TensorShape(numRows), name: factorVarNode.VariableName);
          parameters.Add(par, factorVarNode.VariableName);

          var value = factorVarNode.GetValue(variableValue);
          // explicit float64, like every other tf.Variable created by this converter
          var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: tf.float64);
          variables.Add(wVar);

          products.Add(wVar * par);
        }

        return products.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Addition) {
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        return terms.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Subtraction) {
        // x0 - x1 - x2 ... is built as x0 + (-x1) + (-x2) ...; a single operand is negated
        var terms = new List<Tensor>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          var t = ConvertNode(node.GetSubtree(i));
          if (i > 0) t = -t;
          terms.Add(t);
        }

        if (terms.Count == 1) return -terms[0];
        else return terms.Aggregate((a, b) => a + b);
      }

      if (node.Symbol is Multiplication) {
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        if (terms.Count == 1) return terms[0];
        else return terms.Aggregate((a, b) => a * b);
      }

      if (node.Symbol is Division) {
        // a single operand is inverted; otherwise x0 / x1 / x2 ...
        var terms = new List<Tensor>();
        foreach (var subTree in node.Subtrees) {
          terms.Add(ConvertNode(subTree));
        }

        if (terms.Count == 1) return 1.0 / terms[0];
        else return terms.Aggregate((a, b) => a * (1.0 / b));
      }

      if (node.Symbol is Absolute) {
        var x1 = ConvertNode(node.GetSubtree(0));
        return tf.abs(x1);
      }

      if (node.Symbol is AnalyticQuotient) {
        // AQ(x1, x2) = x1 / sqrt(1 + x2^2)
        var x1 = ConvertNode(node.GetSubtree(0));
        var x2 = ConvertNode(node.GetSubtree(1));
        return x1 / tf.pow(1 + x2 * x2, 0.5);
      }

      if (node.Symbol is Logarithm) {
        return math_ops.log(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Exponential) {
        return math_ops.pow(
          Math.E,
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Square) {
        return tf.square(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is SquareRoot) {
        return math_ops.sqrt(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cube) {
        return math_ops.pow(
          ConvertNode(node.GetSubtree(0)), 3.0);
      }

      if (node.Symbol is CubeRoot) {
        // Real cube root. pow(x, 1/3) on its own is NaN for x < 0, so take the root of |x|
        // and restore the sign: cbrt(x) = sign(x) * |x|^(1/3).
        var x1 = ConvertNode(node.GetSubtree(0));
        return tf.sign(x1) * tf.pow(tf.abs(x1), 1.0 / 3.0);
      }

      if (node.Symbol is Sine) {
        return tf.sin(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Cosine) {
        return tf.cos(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Tangent) {
        return tf.tan(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is HyperbolicTangent) {
        // accepted by IsCompatible, so it must be convertible here as well
        return tf.tanh(
          ConvertNode(node.GetSubtree(0)));
      }

      if (node.Symbol is Mean) {
        // Reduces over axis 1. Variable placeholders only have that axis when vectorLength is set.
        if (!vectorLength.HasValue)
          throw new NotSupportedException($"Node symbol {node.Symbol} requires a vector length.");
        return tf.reduce_mean(
          ConvertNode(node.GetSubtree(0)),
          axis: new[] { 1 });
      }

      // TODO: StandardDeviation (reduce_std over axis 1) is not supported yet.

      if (node.Symbol is Sum) {
        // Reduces over axis 1. Variable placeholders only have that axis when vectorLength is set.
        if (!vectorLength.HasValue)
          throw new NotSupportedException($"Node symbol {node.Symbol} requires a vector length.");
        return tf.reduce_sum(
          ConvertNode(node.GetSubtree(0)),
          axis: new[] { 1 });
      }

      if (node.Symbol is StartSymbol) {
        if (addLinearScalingTerms) {
          // The scaling variables alpha and beta are created before any subtree variables,
          // so they come first in the variables list.
          var alpha = tf.Variable(1.0, name: $"alpha_{1.0}", dtype: tf.float64);
          var beta = tf.Variable(0.0, name: $"beta_{0.0}", dtype: tf.float64);
          variables.Add(alpha);
          variables.Add(beta);
          var t = ConvertNode(node.GetSubtree(0));
          return t * alpha + beta;
        } else return ConvertNode(node.GetSubtree(0));
      }

      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
    }

    /// <summary>
    /// Returns the placeholder for the given variable name / value combination, creating it on first use.
    /// Each factor variable value needs a binary indicator input, and each such indicator is only
    /// needed once.
    /// Not currently called by <see cref="ConvertNode"/>.
    /// </summary>
    private static Tensor FindOrCreateParameter(Dictionary<DataForVariable, Tensor> parameters, string varName, string varValue = "") {
      var data = new DataForVariable(varName, varValue);

      if (!parameters.TryGetValue(data, out var par)) {
        // not found -> create a new placeholder with an unknown number of rows
        par = tf.placeholder(tf.float64, shape: new TensorShape(-1), name: varName);
        parameters.Add(data, par);
      }
      return par;
    }

    /// <summary>
    /// Returns true if every node below the tree root has a symbol that <see cref="ConvertNode"/> handles.
    /// </summary>
    /// <remarks>
    /// Erf and Norm are not listed because ConvertNode has no translation for them.
    /// Factor variables are excluded (see the note in ConvertNode).
    /// Mean and Sum pass this check, but they only convert when a vector length is given to TryConvert.
    /// </remarks>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
          !(n.Symbol is Variable) &&
          !(n.Symbol is Constant) &&
          !(n.Symbol is Addition) &&
          !(n.Symbol is Subtraction) &&
          !(n.Symbol is Multiplication) &&
          !(n.Symbol is Division) &&
          !(n.Symbol is Logarithm) &&
          !(n.Symbol is Exponential) &&
          !(n.Symbol is SquareRoot) &&
          !(n.Symbol is Square) &&
          !(n.Symbol is Sine) &&
          !(n.Symbol is Cosine) &&
          !(n.Symbol is Tangent) &&
          !(n.Symbol is HyperbolicTangent) &&
          !(n.Symbol is StartSymbol) &&
          !(n.Symbol is Absolute) &&
          !(n.Symbol is AnalyticQuotient) &&
          !(n.Symbol is Cube) &&
          !(n.Symbol is CubeRoot) &&
          !(n.Symbol is Mean) &&
          !(n.Symbol is Sum)
        select n).Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.