Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/01/20 15:49:03 (4 years ago)
Author:
pfleck
Message:

#3040 Added version with explicit array shapes for explicit broadcasting.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17476 r17489  
    1919 */
    2020#endregion
     21
     22#define EXPLICIT_SHAPE
    2123
    2224using System;
     
    9496      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
    9597
    96       var vectorVariables = tree.IterateNodesBreadth()
    97         .OfType<VariableTreeNodeBase>()
    98         .Where(node => problemData.Dataset.VariableHasType<DoubleVector>(node.VariableName))
    99         .Select(node => node.VariableName);
    100 
    101       int? vectorLength = null;
    102       if (vectorVariables.Any()) {
    103         vectorLength = vectorVariables.Select(var => problemData.Dataset.GetDoubleVectorValues(var, rows)).First().First().Count;
    104       }
    10598      int numRows = rows.Count();
     99      var variableLengths = problemData.AllowedInputVariables.ToDictionary(
     100        var => var,
     101        var => {
     102          if (problemData.Dataset.VariableHasType<double>(var)) return 1;
     103          if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
     104          throw new NotSupportedException($"Type of variable {var} is not supported.");
     105        });
    106106
    107107      bool success = TreeToTensorConverter.TryConvert(tree,
    108         numRows, vectorLength,
     108        numRows, variableLengths,
    109109        updateVariableWeights, applyLinearScaling,
    110110        out Tensor prediction,
    111111        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    112112
    113       var target = tf.placeholder(tf.float64, name: problemData.TargetVariable);
    114       int samples = rows.Count();
     113#if EXPLICIT_SHAPE
     114      var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable);
     115#endif
    115116      // mse
    116       var costs = tf.reduce_sum(tf.square(prediction - target)) / (2.0 * samples);
     117      var costs = tf.reduce_sum(tf.square(target - prediction)) / (2.0 * numRows);
    117118      var optimizer = tf.train.GradientDescentOptimizer((float)learningRate).minimize(costs);
    118119
     
    124125        if (problemData.Dataset.VariableHasType<double>(variableName)) {
    125126          var data = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
    126           if (vectorLength.HasValue) {
    127             var vectorData = new double[numRows][];
    128             for (int i = 0; i < numRows; i++)
    129               vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
    130             variablesFeed.Add(variable, np.array(vectorData));
    131           } else
    132             variablesFeed.Add(variable, np.array(data, copy: false));
     127          //if (vectorLength.HasValue) {
     128          //  var vectorData = new double[numRows][];
     129          //  for (int i = 0; i < numRows; i++)
     130          //    vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
     131          //  variablesFeed.Add(variable, np.array(vectorData));
     132          //} else
     133          variablesFeed.Add(variable, np.array(data, copy: false).reshape(numRows, 1));
    133134          //} else if (problemData.Dataset.VariableHasType<string>(variableName)) {
    134135          //  variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
     
    140141      }
    141142      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    142       variablesFeed.Add(target, np.array(targetData, copy: false));
    143 
     143      variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1));
    144144
    145145      using (var session = tf.Session()) {
    146146        session.run(tf.global_variables_initializer());
     147
     148        // https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
     149        tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false);
    147150
    148151        Trace.WriteLine("Weights:");
Note: See TracChangeset for help on using the changeset viewer.