Changeset 17493 for branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression
- Timestamp:
- 04/02/20 13:13:38 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs
r17489 r17493 111 111 out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/); 112 112 113 if (!success) 114 return (ISymbolicExpressionTree)tree.Clone(); 115 113 116 #if EXPLICIT_SHAPE 114 117 var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable); … … 143 146 variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1)); 144 147 148 List<NDArray> constants; 145 149 using (var session = tf.Session()) { 146 150 session.run(tf.global_variables_initializer()); … … 162 166 Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}"); 163 167 } 168 169 constants = variables.Select(v => session.run(v)).ToList(); 164 170 } 165 171 166 if (!success) 167 return (ISymbolicExpressionTree)tree.Clone(); 172 if (applyLinearScaling) 173 constants = constants.Skip(2).ToList(); 174 var newTree = (ISymbolicExpressionTree)tree.Clone(); 175 UpdateConstants(newTree, constants, updateVariableWeights); 168 176 177 return newTree; 178 } 169 179 170 return null; 180 private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) { 181 int i = 0; 182 foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) { 183 if (node is ConstantTreeNode constantTreeNode) 184 constantTreeNode.Value = constants[i++].GetDouble(0, 0); 185 else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights) 186 variableTreeNodeBase.Weight = constants[i++].GetDouble(0, 0); 187 else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) { 188 for (int j = 0; j < factorVarTreeNode.Weights.Length; j++) 189 factorVarTreeNode.Weights[j] = constants[i++].GetDouble(0, 0); 190 } 191 } 171 192 } 172 193
Note: See TracChangeset
for help on using the changeset viewer.