Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
03/22/22 13:28:56 (2 years ago)
Author:
pfleck
Message:

#3040 Updated to newer TensorFlow.NET version.

  • Removed IL Merge from TensorFlow.NET.
  • Temporarily removed DiffSharp.
  • Changed to a locally built Attic with a specific Protobuf version that is compatible with TensorFlow.NET. (Also adapted the versions of other NuGet dependencies accordingly.)
File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17721 r18239  
    2525
    2626using System;
    27 using System.Collections;
    2827using System.Collections.Generic;
    2928#if LOG_CONSOLE
     
    4241using HeuristicLab.Parameters;
    4342using HEAL.Attic;
    44 using NumSharp;
    4543using Tensorflow;
     44using Tensorflow.NumPy;
    4645using static Tensorflow.Binding;
     46using static Tensorflow.KerasApi;
    4747using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
    4848
     
    5454    private const string LearningRateName = "LearningRate";
    5555
     56    //private static readonly TF_DataType DataType = tf.float64;
    5657    private static readonly TF_DataType DataType = tf.float32;
    5758
     
    105106      CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {
    106107
    107       int numRows = rows.Count();
    108       var variableLengths = problemData.AllowedInputVariables.ToDictionary(
    109         var => var,
    110         var => {
    111           if (problemData.Dataset.VariableHasType<double>(var)) return 1;
    112           if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
    113           throw new NotSupportedException($"Type of variable {var} is not supported.");
    114         });
    115 
    116       bool success = TreeToTensorConverter.TryConvert(tree,
    117         numRows, variableLengths,
     108      const bool eager = true;
     109
     110     bool prepared = TreeToTensorConverter.TryPrepareTree(
     111        tree,
     112        problemData, rows.ToList(),
    118113        updateVariableWeights, applyLinearScaling,
    119         out Tensor prediction,
    120         out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    121 
    122       if (!success)
     114        eager,
     115        out Dictionary<string, Tensor> inputFeatures, out Tensor target,
     116        out Dictionary<ISymbolicExpressionTreeNode, ResourceVariable[]> variables);
     117      if (!prepared)
    123118        return (ISymbolicExpressionTree)tree.Clone();
    124119
    125       var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable);
    126       // MSE
    127       var cost = tf.reduce_mean(tf.square(target - prediction));
    128 
    129       var optimizer = tf.train.AdamOptimizer((float)learningRate);
    130       //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
    131       var optimizationOperation = optimizer.minimize(cost);
    132 
    133 #if EXPORT_GRAPH
    134       //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
    135       tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
    136         clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
    137 #endif
    138 
    139       // features as feed items
    140       var variablesFeed = new Hashtable();
    141       foreach (var kvp in parameters) {
    142         var variable = kvp.Key;
    143         var variableName = kvp.Value;
    144         if (problemData.Dataset.VariableHasType<double>(variableName)) {
    145           var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
    146           variablesFeed.Add(variable, np.array(data).reshape(numRows, 1));
    147         } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
    148           var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray();
    149           variablesFeed.Add(variable, np.array(data));
    150         } else
    151           throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
     120      var optimizer = keras.optimizers.Adam((float)learningRate);
     121     
     122      for (int i = 0; i < maxIterations; i++) {
     123        if (cancellationToken.IsCancellationRequested) break;
     124
     125        using var tape = tf.GradientTape();
     126
     127        bool success = TreeToTensorConverter.TryEvaluate(
     128          tree,
     129          inputFeatures, variables,
     130          updateVariableWeights, applyLinearScaling,
     131          eager,
     132          out Tensor prediction);
     133        if (!success)
     134          return (ISymbolicExpressionTree)tree.Clone();
     135
     136        var loss = tf.reduce_mean(tf.square(target - prediction));
     137
     138        progress?.Report(loss.ToArray<float>()[0]);
     139
     140        var variablesList = variables.Values.SelectMany(x => x).ToList();
     141        var gradients = tape.gradient(loss, variablesList);
     142       
     143        optimizer.apply_gradients(zip(gradients, variablesList));
    152144      }
    153       var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
    154       variablesFeed.Add(target, np.array(targetData));
    155 
    156 
    157       List<NDArray> constants;
    158       using (var session = tf.Session()) {
    159 
    160 #if LOG_FILE
    161         var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
    162         Directory.CreateDirectory(directoryName);
    163         var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
    164         var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
    165         var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
    166 #endif
    167 
    168 #if LOG_CONSOLE || LOG_FILE
    169         var gradients = optimizer.compute_gradients(cost);
    170 #endif
    171 
    172         session.run(tf.global_variables_initializer());
    173 
    174         progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
    175 
    176 
    177 #if LOG_CONSOLE
    178         Trace.WriteLine("Costs:");
    179         Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    180 
    181         Trace.WriteLine("Weights:");
    182         foreach (var v in variables) {
    183           Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    184         }
    185 
    186         Trace.WriteLine("Gradients:");
    187         foreach (var t in gradients) {
    188           Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    189         }
    190 #endif
    191 
    192 #if LOG_FILE
    193         costsWriter.WriteLine("MSE");
    194         costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
    195 
    196         weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
    197         weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    198 
    199         gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name)));
    200         gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    201 #endif
    202 
    203         for (int i = 0; i < maxIterations; i++) {
    204           if (cancellationToken.IsCancellationRequested)
     145     
     146      var cloner = new Cloner();
     147      var newTree = cloner.Clone(tree);
     148      var newConstants = variables.ToDictionary(
     149        kvp => (ISymbolicExpressionTreeNode)cloner.GetClone(kvp.Key),
     150        kvp => kvp.Value.Select(x => (double)(x.numpy().ToArray<float>()[0])).ToArray()
     151      );
     152      UpdateConstants(newTree, newConstants);
     153
     154
     155      return newTree;
     156
     157
     158     
     159
     160
     161//      //int numRows = rows.Count();
     162
     163
     164
     165
     166
     167
     168//      var variableLengths = problemData.AllowedInputVariables.ToDictionary(
     169//        var => var,
     170//        var => {
     171//          if (problemData.Dataset.VariableHasType<double>(var)) return 1;
     172//          if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
     173//          throw new NotSupportedException($"Type of variable {var} is not supported.");
     174//        });
     175
     176//      var variablesDict = problemData.AllowedInputVariables.ToDictionary(
     177//        var => var,
     178//        var => {
     179//          if (problemData.Dataset.VariableHasType<double>(var)) {
     180//            var data = problemData.Dataset.GetDoubleValues(var, rows).Select(x => (float)x).ToArray();
     181//            return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, 1)), DataType);
     182//          } else if (problemData.Dataset.VariableHasType<DoubleVector>(var)) {
     183//            var data = problemData.Dataset.GetDoubleVectorValues(var, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
     184//            return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, -1)), DataType);
     185//          } else  throw new NotSupportedException($"Type of the variable is not supported: {var}");
     186//        }
     187//      );
     188
     189//      using var tape = tf.GradientTape(persistent: true);
     190
     191//      bool success = TreeToTensorConverter.TryEvaluateEager(tree,
     192//        numRows, variablesDict,
     193//        updateVariableWeights, applyLinearScaling,
     194//        out Tensor prediction,
     195//        out Dictionary<Tensor, string> parameters, out List<ResourceVariable> variables);
     196
     197//      //bool success = TreeToTensorConverter.TryConvert(tree,
     198//      //  numRows, variableLengths,
     199//      //  updateVariableWeights, applyLinearScaling,
     200//      //  out Tensor prediction,
     201//      //  out Dictionary<Tensor, string> parameters, out List<Tensor> variables);
     202
     203//      if (!success)
     204//        return (ISymbolicExpressionTree)tree.Clone();
     205
     206//      //var target = tf.placeholder(DataType, new Shape(numRows), name: problemData.TargetVariable);
     207//      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
     208//      var target = tf.convert_to_tensor(np.array(targetData).reshape(new Shape(numRows)), DataType);
     209//      // MSE
     210//      var cost = tf.reduce_sum(tf.square(prediction - target));
     211
     212//      tape.watch(cost);
     213
     214//      //var optimizer = tf.train.AdamOptimizer((float)learningRate);
     215//      //var optimizer = tf.train.AdamOptimizer(tf.constant(learningRate, DataType));
     216//      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
     217//      //var optimizer = tf.train.GradientDescentOptimizer(tf.constant(learningRate, DataType));
     218//      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
     219//      //var optimizer = tf.train.AdamOptimizer((float)learningRate);
     220//      //var optimizationOperation = optimizer.minimize(cost);
     221//      var optimizer = keras.optimizers.Adam((float)learningRate);
     222
     223//      #if EXPORT_GRAPH
     224//      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
     225//      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
     226//        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
     227//#endif
     228
     229//      //// features as feed items
     230//      //var variablesFeed = new Hashtable();
     231//      //foreach (var kvp in parameters) {
     232//      //  var variable = kvp.Key;
     233//      //  var variableName = kvp.Value;
     234//      //  if (problemData.Dataset.VariableHasType<double>(variableName)) {
     235//      //    var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
     236//      //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, 1)));
     237//      //  } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
     238//      //    var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
     239//      //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, -1)));
     240//      //  } else
     241//      //    throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
     242//      //}
     243//      //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
     244//      //variablesFeed.Add(target, np.array(targetData));
     245
     246
     247//      List<NDArray> constants;
     248//      //using (var session = tf.Session()) {
     249
     250//#if LOG_FILE
     251//        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
     252//        Directory.CreateDirectory(directoryName);
     253//        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
     254//        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
     255//        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
     256//#endif
     257
     258//      //session.run(tf.global_variables_initializer());
     259
     260//#if LOG_CONSOLE || LOG_FILE
     261//        var gradients = optimizer.compute_gradients(cost);
     262//#endif
     263
     264//      //var vars = variables.Select(v => session.run(v, variablesFeed)[0].ToArray<float>()[0]).ToList();
     265//      //var gradient = optimizer.compute_gradients(cost)
     266//      //  .Where(g => g.Item1 != null)
     267//      //  //.Select(g => session.run(g.Item1, variablesFeed)[0].GetValue<float>(0)).
     268//      //  .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0])
     269//      //  .ToList();
     270
     271//      //var gradientPrediction = optimizer.compute_gradients(prediction)
     272//      //  .Where(g => g.Item1 != null)
     273//      //  .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0])
     274//      //  .ToList();
     275
     276
     277//      //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]);
     278//      progress?.Report(cost.ToArray<float>()[0]);
     279
     280
     281     
     282
     283
     284//#if LOG_CONSOLE
     285//        Trace.WriteLine("Costs:");
     286//        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
     287
     288//        Trace.WriteLine("Weights:");
     289//        foreach (var v in variables) {
     290//          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     291//        }
     292
     293//        Trace.WriteLine("Gradients:");
     294//        foreach (var t in gradients) {
     295//          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     296//        }
     297//#endif
     298
     299//#if LOG_FILE
     300//        costsWriter.WriteLine("MSE");
     301//        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture));
     302
     303//        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
     304//        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
     305
     306//        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.Name)));
     307//        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
     308//#endif
     309
     310//        for (int i = 0; i < maxIterations; i++) {
     311//          if (cancellationToken.IsCancellationRequested)
     312//            break;
     313
     314         
     315//        var gradients = tape.gradient(cost, variables);
     316//        //optimizer.apply_gradients(gradients.Zip(variables, Tuple.Create<Tensor, IVariableV1>).ToArray());
     317//        optimizer.apply_gradients(zip(gradients, variables));
     318       
     319
     320//        //session.run(optimizationOperation, variablesFeed);
     321
     322//        progress?.Report(cost.ToArray<float>()[0]);
     323//        //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]);
     324
     325//#if LOG_CONSOLE
     326//          Trace.WriteLine("Costs:");
     327//          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
     328
     329//          Trace.WriteLine("Weights:");
     330//          foreach (var v in variables) {
     331//            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     332//          }
     333
     334//          Trace.WriteLine("Gradients:");
     335//          foreach (var t in gradients) {
     336//            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     337//          }
     338//#endif
     339
     340//#if LOG_FILE
     341//          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture));
     342//          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
     343//          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
     344//#endif
     345//      }
     346
     347//#if LOG_FILE
     348//        costsWriter.Close();
     349//        weightsWriter.Close();
     350//        gradientsWriter.Close();
     351//#endif
     352//      //constants = variables.Select(v => session.run(v)).ToList();
     353//      constants = variables.Select(v => v.numpy()).ToList();
     354//      //}
     355
     356//      if (applyLinearScaling)
     357//        constants = constants.Skip(2).ToList();
     358//      var newTree = (ISymbolicExpressionTree)tree.Clone();
     359//      UpdateConstants(newTree, constants, updateVariableWeights);
     360
     361//      return newTree;
     362    }
     363
     364    private static void UpdateConstants(ISymbolicExpressionTree tree, Dictionary<ISymbolicExpressionTreeNode, double[]> constants) {
     365      foreach (var kvp in constants) {
     366        var node = kvp.Key;
     367        var value = kvp.Value;
     368
     369        switch (node) {
     370          case ConstantTreeNode constantTreeNode:
     371            constantTreeNode.Value = value[0];
    205372            break;
    206 
    207           session.run(optimizationOperation, variablesFeed);
    208 
    209           progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
    210 
    211 #if LOG_CONSOLE
    212           Trace.WriteLine("Costs:");
    213           Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    214 
    215           Trace.WriteLine("Weights:");
    216           foreach (var v in variables) {
    217             Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     373          case VariableTreeNodeBase variableTreeNodeBase:
     374            variableTreeNodeBase.Weight = value[0];
     375            break;
     376          case FactorVariableTreeNode factorVarTreeNode: {
     377            for (int i = 0; i < factorVarTreeNode.Weights.Length; i++) {
     378              factorVarTreeNode.Weights[i] = value[i];
     379            }
     380            break;
    218381          }
    219 
    220           Trace.WriteLine("Gradients:");
    221           foreach (var t in gradients) {
    222             Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    223           }
    224 #endif
    225 
    226 #if LOG_FILE
    227           costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
    228           weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    229           gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    230 #endif
    231         }
    232 
    233 #if LOG_FILE
    234         costsWriter.Close();
    235         weightsWriter.Close();
    236         gradientsWriter.Close();
    237 #endif
    238         constants = variables.Select(v => session.run(v)).ToList();
    239       }
    240 
    241       if (applyLinearScaling)
    242         constants = constants.Skip(2).ToList();
    243       var newTree = (ISymbolicExpressionTree)tree.Clone();
    244       UpdateConstants(newTree, constants, updateVariableWeights);
    245 
    246       return newTree;
    247     }
    248 
    249     private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
    250       int i = 0;
    251       foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    252         if (node is ConstantTreeNode constantTreeNode)
    253           constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);
    254         else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
    255           variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);
    256         else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
    257           for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    258             factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);
    259382        }
    260383      }
    261384    }
    262385
     386    //private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
     387    //  int i = 0;
     388    //  foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
     389    //    if (node is ConstantTreeNode constantTreeNode) {
     390    //      constantTreeNode.Value = constants[i++].ToArray<float>()[0];
     391    //    } else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights) {
     392    //      variableTreeNodeBase.Weight = constants[i++].ToArray<float>()[0];
     393    //    } else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
     394    //      for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     395    //        factorVarTreeNode.Weights[j] = constants[i++].ToArray<float>()[0];
     396    //    }
     397    //  }
     398    //}
     399
    263400    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
    264401      return TreeToTensorConverter.IsCompatible(tree);
Note: See TracChangeset for help on using the changeset viewer.