- Timestamp:
- 04/01/20 15:49:03 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs
r17476 r17489 20 20 #endregion 21 21 22 #define EXPLICIT_SHAPE 23 22 24 using System; 23 25 using System.Collections.Generic; 24 26 using System.Linq; 25 27 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 28 using NumSharp; 26 29 using Tensorflow; 27 30 using static Tensorflow.Binding; … … 54 57 #endregion 55 58 56 public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, int? vectorLength,59 public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths, 57 60 bool makeVariableWeightsVariable, bool addLinearScalingTerms, 58 61 out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables … … 60 63 61 64 try { 62 var converter = new TreeToTensorConverter(numRows, v ectorLength, makeVariableWeightsVariable, addLinearScalingTerms);65 var converter = new TreeToTensorConverter(numRows, variableLengths, makeVariableWeightsVariable, addLinearScalingTerms); 63 66 graph = converter.ConvertNode(tree.Root.GetSubtree(0)); 64 67 … … 78 81 79 82 private readonly int numRows; 80 private readonly int? vectorLength;83 private readonly Dictionary<string, int> variableLengths; 81 84 private readonly bool makeVariableWeightsVariable; 82 85 private readonly bool addLinearScalingTerms; … … 86 89 private readonly List<Tensor> variables = new List<Tensor>(); 87 90 88 private TreeToTensorConverter(int numRows, int? vectorLength, bool makeVariableWeightsVariable, bool addLinearScalingTerms) {91 private TreeToTensorConverter(int numRows, Dictionary<string, int> variableLengths, bool makeVariableWeightsVariable, bool addLinearScalingTerms) { 89 92 this.numRows = numRows; 90 this.v ectorLength = vectorLength;93 this.variableLengths = variableLengths; 91 94 this.makeVariableWeightsVariable = makeVariableWeightsVariable; 92 95 this.addLinearScalingTerms = addLinearScalingTerms; 93 96 } 97 94 98 95 99 … … 98 102 var value = ((ConstantTreeNode)node).Value; 99 103 //initialConstants.Add(value); 100 var var = tf.Variable(value, name: $"c_{variables.Count}", dtype: tf.float64); 104 #if EXPLICIT_SHAPE 105 //var var = (RefVariable)tf.VariableV1(value, name: $"c_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 }); 106 var value_arr = np.array(value).reshape(1, 1); 107 var var = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: tf.float64); 108 #endif 109 //var var = tf.Variable(value, name: $"c_{variables.Count}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/); 101 110 variables.Add(var); 102 111 return var; … … 109 118 //var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty; 110 119 //var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue); 111 var shape = vectorLength.HasValue 112 ? new TensorShape(numRows, vectorLength.Value) 113 : new TensorShape(numRows); 114 var par = tf.placeholder(tf.float64, shape: shape, name: varNode.VariableName); 120 #if EXPLICIT_SHAPE 121 var par = tf.placeholder(tf.float64, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName); 122 #endif 115 123 parameters.Add(par, varNode.VariableName); 116 124 117 125 if (makeVariableWeightsVariable) { 118 126 //initialConstants.Add(varNode.Weight); 119 var w = tf.Variable(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64); 127 #if EXPLICIT_SHAPE 128 //var w = (RefVariable)tf.VariableV1(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 }); 129 var w_arr = np.array(varNode.Weight).reshape(1, 1); 130 var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: tf.float64); 131 #endif 132 //var w = tf.Variable(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/); 120 133 variables.Add(w); 121 134 return w * par; … … 125 138 } 126 139 127 if (node.Symbol is FactorVariable) { 128 var factorVarNode = node as FactorVariableTreeNode; 129 var products = new List<Tensor>(); 130 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 131 //var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 132 var par = tf.placeholder(tf.float64, shape: new TensorShape(numRows), name: factorVarNode.VariableName); 133 parameters.Add(par, factorVarNode.VariableName); 134 135 var value = factorVarNode.GetValue(variableValue); 136 //initialConstants.Add(value); 137 var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}"); 138 variables.Add(wVar); 139 140 products.add(wVar * par); 141 } 142 143 return products.Aggregate((a, b) => a + b); 144 } 140 //if (node.Symbol is FactorVariable) { 141 // var factorVarNode = node as FactorVariableTreeNode; 142 // var products = new List<Tensor>(); 143 // foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 144 // //var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 145 // var par = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: factorVarNode.VariableName); 146 // parameters.Add(par, factorVarNode.VariableName); 147 148 // var value = factorVarNode.GetValue(variableValue); 149 // //initialConstants.Add(value); 150 // var wVar = (RefVariable)tf.VariableV1(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 }); 151 // //var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}"/*, shape: new[] { 1, 1 }*/); 152 // variables.Add(wVar); 153 154 // products.add(wVar * par); 155 // } 156 157 // return products.Aggregate((a, b) => a + b); 158 //} 145 159 146 160 if (node.Symbol is Addition) { … … 248 262 return tf.reduce_mean( 249 263 ConvertNode(node.GetSubtree(0)), 250 axis: new[] { 1 }); 264 axis: new[] { 1 }, 265 keepdims: true); 251 266 } 252 267 … … 261 276 return tf.reduce_sum( 262 277 ConvertNode(node.GetSubtree(0)), 263 axis: new[] { 1 }); 278 axis: new[] { 1 }, 279 keepdims: true); 264 280 } 265 281 … … 267 283 if (addLinearScalingTerms) { 268 284 // scaling variables α, β are given at the beginning of the parameter vector 269 var alpha = tf.Variable(1.0, name: $"alpha_{1.0}", dtype: tf.float64); 270 var beta = tf.Variable(0.0, name: $"beta_{0.0}", dtype: tf.float64); 285 #if EXPLICIT_SHAPE 286 //var alpha = (RefVariable)tf.VariableV1(1.0, name: $"alpha_{1.0}", dtype: tf.float64, shape: new[] { 1, 1 }); 287 //var beta = (RefVariable)tf.VariableV1(0.0, name: $"beta_{0.0}", dtype: tf.float64, shape: new[] { 1, 1 }); 288 289 var alpha_arr = np.array(1.0).reshape(1, 1); 290 var alpha = tf.Variable(alpha_arr, name: $"alpha", dtype: tf.float64); 291 var beta_arr = np.array(1.0).reshape(1, 1); 292 var beta = tf.Variable(beta_arr, name: $"beta", dtype: tf.float64); 293 #endif 294 //var alpha = tf.Variable(1.0, name: $"alpha_{1.0}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/); 295 //var beta = tf.Variable(0.0, name: $"beta_{0.0}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/); 271 296 variables.Add(alpha); 272 297 variables.Add(beta); … … 277 302 278 303 throw new NotSupportedException($"Node symbol {node.Symbol} is not supported."); 279 }280 281 // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination282 // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available283 private static Tensor FindOrCreateParameter(Dictionary<DataForVariable, Tensor> parameters, string varName, string varValue = "") {284 var data = new DataForVariable(varName, varValue);285 286 if (!parameters.TryGetValue(data, out var par)) {287 // not found -> create new parameter and entries in names and values lists288 par = tf.placeholder(tf.float64, shape: new TensorShape(-1), name: varName);289 parameters.Add(data, par);290 }291 return par;292 304 } 293 305
Note: See TracChangeset
for help on using the changeset viewer.