Free cookie consent management tool by TermsFeed Policy Generator

Changeset 17493


Ignore:
Timestamp:
04/02/20 13:13:38 (4 years ago)
Author:
pfleck
Message:

#3040 Write optimized constants back to tree.

Location:
branches/3040_VectorBasedGP
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17489 r17493  
    111111        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    112112
     113      if (!success)
     114        return (ISymbolicExpressionTree)tree.Clone();
     115
    113116#if EXPLICIT_SHAPE
    114117      var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable);
     
    143146      variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1));
    144147
     148      List<NDArray> constants;
    145149      using (var session = tf.Session()) {
    146150        session.run(tf.global_variables_initializer());
     
    162166            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    163167        }
     168
     169        constants = variables.Select(v => session.run(v)).ToList();
    164170      }
    165171
    166       if (!success)
    167         return (ISymbolicExpressionTree)tree.Clone();
     172      if (applyLinearScaling)
     173        constants = constants.Skip(2).ToList();
     174      var newTree = (ISymbolicExpressionTree)tree.Clone();
     175      UpdateConstants(newTree, constants, updateVariableWeights);
    168176
     177      return newTree;
     178    }
    169179
    170       return null;
     180    private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
     181      int i = 0;
     182      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
     183        if (node is ConstantTreeNode constantTreeNode)
     184          constantTreeNode.Value = constants[i++].GetDouble(0, 0);
     185        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
     186          variableTreeNodeBase.Weight = constants[i++].GetDouble(0, 0);
     187        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
     188          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     189            factorVarTreeNode.Weights[j] = constants[i++].GetDouble(0, 0);
     190        }
     191      }
    171192    }
    172193
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs

    r17489 r17493  
    288288
    289289          var alpha_arr = np.array(1.0).reshape(1, 1);
    290           var alpha = tf.Variable(alpha_arr, name: $"alpha", dtype: tf.float64);
    291           var beta_arr = np.array(1.0).reshape(1, 1);
    292           var beta = tf.Variable(beta_arr, name: $"beta", dtype: tf.float64);
     290          var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: tf.float64);
     291          var beta_arr = np.array(0.0).reshape(1, 1);
     292          var beta = tf.Variable(beta_arr, name: "beta", dtype: tf.float64);
    293293#endif
    294294          //var alpha = tf.Variable(1.0, name: $"alpha_{1.0}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/);
Note: See TracChangeset for help on using the changeset viewer.