Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/12/20 17:51:39 (5 years ago)
Author:
pfleck
Message:

#3040 Worked on TF-based constant optimization.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17475 r17476  
    2323using System.Collections;
    2424using System.Collections.Generic;
     25using System.Diagnostics;
    2526using System.Linq;
    2627using System.Threading;
     
    3132using HeuristicLab.Parameters;
    3233using HEAL.Attic;
     34using NumSharp;
    3335using Tensorflow;
    3436using static Tensorflow.Binding;
     
    9294      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
    9395
    94       bool success = TreeToTensorConverter.TryConvert(tree, updateVariableWeights, applyLinearScaling,
    95         out Tensor prediction, out Dictionary<TreeToTensorConverter.DataForVariable, Tensor> variables/*, out double[] initialConstants*/);
     96      var vectorVariables = tree.IterateNodesBreadth()
     97        .OfType<VariableTreeNodeBase>()
     98        .Where(node => problemData.Dataset.VariableHasType<DoubleVector>(node.VariableName))
     99        .Select(node => node.VariableName);
     100
     101      int? vectorLength = null;
     102      if (vectorVariables.Any()) {
     103        vectorLength = vectorVariables.Select(var => problemData.Dataset.GetDoubleVectorValues(var, rows)).First().First().Count;
     104      }
     105      int numRows = rows.Count();
     106
     107      bool success = TreeToTensorConverter.TryConvert(tree,
     108        numRows, vectorLength,
     109        updateVariableWeights, applyLinearScaling,
     110        out Tensor prediction,
     111        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    96112
    97113      var target = tf.placeholder(tf.float64, name: problemData.TargetVariable);
     
    99115      // mse
    100116      var costs = tf.reduce_sum(tf.square(prediction - target)) / (2.0 * samples);
    101       var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
     117      var optimizer = tf.train.GradientDescentOptimizer((float)learningRate).minimize(costs);
    102118
    103119      // features as feed items
    104120      var variablesFeed = new Hashtable();
    105       foreach (var kvp in variables) {
    106         var variableName = kvp.Key.variableName;
    107         var variable = kvp.Value;
    108         if (problemData.Dataset.VariableHasType<double>(variableName))
    109           variablesFeed.Add(variable, problemData.Dataset.GetDoubleValues(variableName, rows));
    110         if (problemData.Dataset.VariableHasType<string>(variableName))
    111           variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
    112         if (problemData.Dataset.VariableHasType<DoubleVector>(variableName))
    113           variablesFeed.Add(variable, problemData.Dataset.GetDoubleVectorValues(variableName, rows));
    114         throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
     121      foreach (var kvp in parameters) {
     122        var variable = kvp.Key;
     123        var variableName = kvp.Value;
     124        if (problemData.Dataset.VariableHasType<double>(variableName)) {
     125          var data = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
     126          if (vectorLength.HasValue) {
     127            var vectorData = new double[numRows][];
     128            for (int i = 0; i < numRows; i++)
     129              vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
     130            variablesFeed.Add(variable, np.array(vectorData));
     131          } else
     132            variablesFeed.Add(variable, np.array(data, copy: false));
     133          //} else if (problemData.Dataset.VariableHasType<string>(variableName)) {
     134          //  variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
     135        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
     136          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.ToArray()).ToArray();
     137          variablesFeed.Add(variable, np.array(data));
     138        } else
     139          throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
    115140      }
    116       variablesFeed.Add(target, problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows));
     141      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
     142      variablesFeed.Add(target, np.array(targetData, copy: false));
    117143
    118144
    119145      using (var session = tf.Session()) {
     146        session.run(tf.global_variables_initializer());
     147
     148        Trace.WriteLine("Weights:");
     149        foreach (var v in variables)
     150          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     151
    120152        for (int i = 0; i < maxIterations; i++) {
    121           optimizer.minimize(costs);
    122           var result = session.run(optimizer, variablesFeed);
     153
     154          //optimizer.minimize(costs);
     155          session.run(optimizer, variablesFeed);
     156
     157          Trace.WriteLine("Weights:");
     158          foreach (var v in variables)
     159            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    123160        }
    124161      }
    125       optimizer.minimize(costs);
    126162
    127163      if (!success)
Note: See TracChangeset for help on using the changeset viewer.