Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/02/20 13:13:38 (5 years ago)
Author:
pfleck
Message:

#3040 Write optimized constants back to tree.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17489 r17493  
    111111        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    112112
     113      if (!success)
     114        return (ISymbolicExpressionTree)tree.Clone();
     115
    113116#if EXPLICIT_SHAPE
    114117      var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable);
     
    143146      variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1));
    144147
     148      List<NDArray> constants;
    145149      using (var session = tf.Session()) {
    146150        session.run(tf.global_variables_initializer());
     
    162166            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    163167        }
     168
     169        constants = variables.Select(v => session.run(v)).ToList();
    164170      }
    165171
    166       if (!success)
    167         return (ISymbolicExpressionTree)tree.Clone();
     172      if (applyLinearScaling)
     173        constants = constants.Skip(2).ToList();
     174      var newTree = (ISymbolicExpressionTree)tree.Clone();
     175      UpdateConstants(newTree, constants, updateVariableWeights);
    168176
     177      return newTree;
     178    }
    169179
    170       return null;
     180    private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
     181      int i = 0;
     182      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
     183        if (node is ConstantTreeNode constantTreeNode)
     184          constantTreeNode.Value = constants[i++].GetDouble(0, 0);
     185        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
     186          variableTreeNodeBase.Weight = constants[i++].GetDouble(0, 0);
     187        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
     188          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     189            factorVarTreeNode.Weights[j] = constants[i++].GetDouble(0, 0);
     190        }
     191      }
    171192    }
    172193
Note: See TracChangeset for help on using the changeset viewer.