Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs @ 17474

Last change on this file since 17474 was 17474, checked in by pfleck, 4 years ago

#3040 Started working on the TF constant opt evaluator.

File size: 10.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using Tensorflow;
29using static Tensorflow.Binding;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  public class TreeToTensorConverter {
33
34    #region helper class
35    public class DataForVariable {
36      public readonly string variableName;
37      public readonly string variableValue; // for factor vars
38
39      public DataForVariable(string varName, string varValue) {
40        this.variableName = varName;
41        this.variableValue = varValue;
42      }
43
44      public override bool Equals(object obj) {
45        var other = obj as DataForVariable;
46        if (other == null) return false;
47        return other.variableName.Equals(this.variableName) &&
48               other.variableValue.Equals(this.variableValue);
49      }
50
51      public override int GetHashCode() {
52        return variableName.GetHashCode() ^ variableValue.GetHashCode();
53      }
54    }
55    #endregion
56
57    public static bool TryConvert(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
58      out Tensor graph, out Dictionary<DataForVariable, Tensor> variables/*, out double[] initialConstants*/) {
59
60      try {
61        var converter = new TreeToTensorConverter(makeVariableWeightsVariable, addLinearScalingTerms);
62        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
63
64        //var parametersEntries = converter.parameters.ToList(); // guarantee same order for keys and values
65        variables = converter.parameters; // parametersEntries.Select(kvp => kvp.Value).ToList();
66        //initialConstants = converter.initialConstants.ToArray();
67        return true;
68      } catch (NotSupportedException) {
69        graph = null;
70        variables = null;
71        //initialConstants = null;
72        return false;
73      }
74    }
75
76    private readonly bool makeVariableWeightsVariable;
77    private readonly bool addLinearScalingTerms;
78
79    //private readonly List<double> initialConstants = new List<double>();
80    private readonly Dictionary<DataForVariable, Tensor> parameters = new Dictionary<DataForVariable, Tensor>();
81    private readonly List<Tensor> variables = new List<Tensor>();
82
83    private TreeToTensorConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
84      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
85      this.addLinearScalingTerms = addLinearScalingTerms;
86    }
87
88
89    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
90      if (node.Symbol is Constant) {
91        var value = ((ConstantTreeNode)node).Value;
92        //initialConstants.Add(value);
93        var var = tf.Variable(value);
94        variables.Add(var);
95        return var;
96      }
97
98      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
99        var varNode = node as VariableTreeNodeBase;
100        var factorVarNode = node as BinaryFactorVariableTreeNode;
101        // factor variable values are only 0 or 1 and set in x accordingly
102        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
103        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
104
105        if (makeVariableWeightsVariable) {
106          //initialConstants.Add(varNode.Weight);
107          var w = tf.Variable(varNode.Weight);
108          variables.Add(w);
109          return w * par;
110        } else {
111          return varNode.Weight * par;
112        }
113      }
114
115      if (node.Symbol is FactorVariable) {
116        var factorVarNode = node as FactorVariableTreeNode;
117        var products = new List<Tensor>();
118        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
119          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
120
121          var value = factorVarNode.GetValue(variableValue);
122          //initialConstants.Add(value);
123          var wVar = tf.Variable(value);
124          variables.Add(wVar);
125
126          products.add(wVar * par);
127        }
128
129        return tf.add_n(products.ToArray());
130      }
131
132      if (node.Symbol is Addition) {
133        var terms = new List<Tensor>();
134        foreach (var subTree in node.Subtrees) {
135          terms.Add(ConvertNode(subTree));
136        }
137
138        return tf.add_n(terms.ToArray());
139      }
140
141      if (node.Symbol is Subtraction) {
142        var terms = new List<Tensor>();
143        for (int i = 0; i < node.SubtreeCount; i++) {
144          var t = ConvertNode(node.GetSubtree(i));
145          if (i > 0) t = -t;
146          terms.Add(t);
147        }
148
149        if (terms.Count == 1) return -terms[0];
150        else return tf.add_n(terms.ToArray());
151      }
152
153      if (node.Symbol is Multiplication) {
154        var terms = new List<Tensor>();
155        foreach (var subTree in node.Subtrees) {
156          terms.Add(ConvertNode(subTree));
157        }
158
159        if (terms.Count == 1) return terms[0];
160        else return terms.Aggregate((a, b) => a * b);
161      }
162
163      if (node.Symbol is Division) {
164        var terms = new List<Tensor>();
165        foreach (var subTree in node.Subtrees) {
166          terms.Add(ConvertNode(subTree));
167        }
168
169        if (terms.Count == 1) return 1.0 / terms[0];
170        else return terms.Aggregate((a, b) => a * (1.0 / b));
171      }
172
173      if (node.Symbol is Absolute) {
174        var x1 = ConvertNode(node.GetSubtree(0));
175        return tf.abs(x1);
176      }
177
178      if (node.Symbol is AnalyticQuotient) {
179        var x1 = ConvertNode(node.GetSubtree(0));
180        var x2 = ConvertNode(node.GetSubtree(1));
181        return x1 / tf.pow(1 + x2 * x2, 0.5);
182      }
183
184      if (node.Symbol is Logarithm) {
185        return math_ops.log(
186          ConvertNode(node.GetSubtree(0)));
187      }
188
189      if (node.Symbol is Exponential) {
190        return math_ops.pow(
191          Math.E,
192          ConvertNode(node.GetSubtree(0)));
193      }
194
195      if (node.Symbol is Square) {
196        return tf.square(
197          ConvertNode(node.GetSubtree(0)));
198      }
199
200      if (node.Symbol is SquareRoot) {
201        return math_ops.sqrt(
202          ConvertNode(node.GetSubtree(0)));
203      }
204
205      if (node.Symbol is Cube) {
206        return math_ops.pow(
207          ConvertNode(node.GetSubtree(0)), 3.0);
208      }
209
210      if (node.Symbol is CubeRoot) {
211        return math_ops.pow(
212          ConvertNode(node.GetSubtree(0)), 1.0 / 3.0);
213        // TODO
214        // f: x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
215        // g:  { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
216      }
217
218      if (node.Symbol is Sine) {
219        return tf.sin(
220          ConvertNode(node.GetSubtree(0)));
221      }
222
223      if (node.Symbol is Cosine) {
224        return tf.cos(
225          ConvertNode(node.GetSubtree(0)));
226      }
227
228      if (node.Symbol is Tangent) {
229        return tf.tan(
230          ConvertNode(node.GetSubtree(0)));
231      }
232
233      if (node.Symbol is Mean) {
234        return tf.reduce_mean(
235          ConvertNode(node.GetSubtree(0)));
236      }
237
238      //if (node.Symbol is StandardDeviation) {
239      //  return tf.reduce_std(
240      //    ConvertNode(node.GetSubtree(0)));
241      //}
242
243      if (node.Symbol is Sum) {
244        return tf.reduce_sum(
245          ConvertNode(node.GetSubtree(0)));
246      }
247
248      if (node.Symbol is StartSymbol) {
249        if (addLinearScalingTerms) {
250          // scaling variables α, β are given at the beginning of the parameter vector
251          var alpha = tf.Variable(1.0);
252          var beta = tf.Variable(0.0);
253          variables.Add(beta);
254          variables.Add(alpha);
255          var t = ConvertNode(node.GetSubtree(0));
256          return t * alpha + beta;
257        } else return ConvertNode(node.GetSubtree(0));
258      }
259
260      throw new NotSupportedException($"Node symbol {node.Symbol} is not supported.");
261    }
262
263    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
264    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
265    private static Tensor FindOrCreateParameter(Dictionary<DataForVariable, Tensor> parameters, string varName, string varValue = "") {
266      var data = new DataForVariable(varName, varValue);
267
268      if (!parameters.TryGetValue(data, out var par)) {
269        // not found -> create new parameter and entries in names and values lists
270        par = tf.placeholder(tf.float64, name: varName);
271        parameters.Add(data, par);
272      }
273      return par;
274    }
275
276    public static bool IsCompatible(ISymbolicExpressionTree tree) {
277      var containsUnknownSymbol = (
278        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
279        where
280          !(n.Symbol is Variable) &&
281          !(n.Symbol is BinaryFactorVariable) &&
282          !(n.Symbol is FactorVariable) &&
283          !(n.Symbol is Constant) &&
284          !(n.Symbol is Addition) &&
285          !(n.Symbol is Subtraction) &&
286          !(n.Symbol is Multiplication) &&
287          !(n.Symbol is Division) &&
288          !(n.Symbol is Logarithm) &&
289          !(n.Symbol is Exponential) &&
290          !(n.Symbol is SquareRoot) &&
291          !(n.Symbol is Square) &&
292          !(n.Symbol is Sine) &&
293          !(n.Symbol is Cosine) &&
294          !(n.Symbol is Tangent) &&
295          !(n.Symbol is HyperbolicTangent) &&
296          !(n.Symbol is Erf) &&
297          !(n.Symbol is Norm) &&
298          !(n.Symbol is StartSymbol) &&
299          !(n.Symbol is Absolute) &&
300          !(n.Symbol is AnalyticQuotient) &&
301          !(n.Symbol is Cube) &&
302          !(n.Symbol is CubeRoot) &&
303          !(n.Symbol is Mean) &&
304          //!(n.Symbol is StandardDeviation) &&
305          !(n.Symbol is Sum)
306        select n).Any();
307      return !containsUnknownSymbol;
308    }
309  }
310}
Note: See TracBrowser for help on using the repository browser.