#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

//#define EXPORT_GRAPH
//#define LOG_CONSOLE
//#define LOG_FILE

using System;
using System.Collections;
using System.Collections.Generic;
#if LOG_CONSOLE
using System.Diagnostics;
#endif
#if LOG_FILE
using System.Globalization;
using System.IO;
#endif
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [StorableType("63944BF6-62E5-4BE4-974C-D30AD8770F99")]
  [Item("TensorFlowConstantOptimizationEvaluator", "")]
  public class TensorFlowConstantOptimizationEvaluator : SymbolicRegressionConstantOptimizationEvaluator {
    private const string MaximumIterationsName = "MaximumIterations";
    private const string LearningRateName = "LearningRate";

    //private static readonly TF_DataType DataType = tf.float64;
    //private static readonly TF_DataType DataType = tf.float32;

    #region Parameter Properties
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
    }
    public IFixedValueParameter<DoubleValue> LearningRateParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
    }
    #endregion

    #region Properties
    public int ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value.Value; }
    }
    public double LearningRate {
      get { return LearningRateParameter.Value.Value; }
    }
    #endregion

    public TensorFlowConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName,
        "Determines how many iterations should be calculated while optimizing the constants of a symbolic expression tree (0 indicates another or the default stopping criterion).",
        new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.001)));
    }

    protected TensorFlowConstantOptimizationEvaluator(TensorFlowConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TensorFlowConstantOptimizationEvaluator(this, cloner);
    }

    [StorableConstructor]
    protected TensorFlowConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }

    protected override ISymbolicExpressionTree OptimizeConstants(
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows,
      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
      return OptimizeTree(tree, problemData, rows,
        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights, ConstantOptimizationIterations, LearningRate,
        cancellationToken);
    }
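    // Optimizes the numeric constants (and, if enabled, the variable weights) of a symbolic
    // expression tree: the tree is converted into a TensorFlow.NET eager computation and the
    // mean squared error between prediction and target is minimized with the Adam optimizer.
    // The optimized values are written back into a clone of the tree.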
    public static ISymbolicExpressionTree OptimizeTree(ISymbolicExpressionTree tree,
      IRegressionProblemData problemData, IEnumerable<int> rows,
      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
      CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {

      const bool eager = true;

#if LOG_FILE
      var directoryName = $"C:\\temp\\TFboard\\logdir\\TF_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
      Directory.CreateDirectory(directoryName);
      using var predictionTargetLossWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "PredictionTargetLoss.csv")));
      using var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
      using var treeGradsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "TreeGrads.csv")));
      using var lossGradsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "LossGrads.csv")));

      predictionTargetLossWriter.WriteLine(string.Join(";", "Prediction", "Target", "Loss"));
      weightsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"w_{i}")));
      treeGradsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"Tg_{i}")));
      lossGradsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"Lg_{i}")));
#endif

      //foreach (var row in rows) {

      // Convert the tree into tensors: input features, the target vector, and the trainable
      // variables (constants and variable weights) per tree node.
      bool prepared = TreeToTensorConverter.TryPrepareTree(
        tree,
        problemData, rows.ToList(),
        //problemData, new List<int>() { row },
        updateVariableWeights, applyLinearScaling,
        eager,
        out Dictionary<string, Tensor> inputFeatures, out Tensor target,
        out Dictionary<ISymbolicExpressionTreeNode, ResourceVariable[]> variables);
      if (!prepared)
        return (ISymbolicExpressionTree)tree.Clone();

      var optimizer = keras.optimizers.Adam((float)learningRate);

      for (int i = 0; i < maxIterations; i++) {
        if (cancellationToken.IsCancellationRequested) break;

#if LOG_FILE || LOG_CONSOLE
        using var tape = tf.GradientTape(persistent: true);
#else
        using var tape = tf.GradientTape(persistent: false);
#endif

        bool success = TreeToTensorConverter.TryEvaluate(
          tree,
          inputFeatures, variables,
          updateVariableWeights, applyLinearScaling,
          eager,
          out Tensor prediction);
        if (!success)
          return (ISymbolicExpressionTree)tree.Clone();

        var loss = tf.reduce_mean(tf.square(target - prediction));

        progress?.Report(loss.ToArray<float>()[0]);

        var variablesList = variables.Values.SelectMany(x => x).ToList();
        var gradients = tape.gradient(loss, variablesList);

#if LOG_FILE
        predictionTargetLossWriter.WriteLine(string.Join(";", new[] { prediction.ToArray<float>()[0], target.ToArray<float>()[0], loss.ToArray<float>()[0] }));
        weightsWriter.WriteLine(string.Join(";", variablesList.Select(v => v.numpy().ToArray<float>()[0])));
        treeGradsWriter.WriteLine(string.Join(";", tape.gradient(prediction, variablesList).Select(t => t.ToArray<float>()[0])));
        lossGradsWriter.WriteLine(string.Join(";", tape.gradient(loss, variablesList).Select(t => t.ToArray<float>()[0])));
#endif

        //break;

        optimizer.apply_gradients(zip(gradients, variablesList));
      }

      //}

      // Write the optimized values back into a clone of the tree; the original tree stays untouched.
      var cloner = new Cloner();
      var newTree = cloner.Clone(tree);
      var newConstants = variables.ToDictionary(
        kvp => (ISymbolicExpressionTreeNode)cloner.GetClone(kvp.Key),
        kvp => kvp.Value.Select(x => (double)(x.numpy().ToArray<float>()[0])).ToArray()
      );
      UpdateConstants(newTree, newConstants);
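      // The commented-out block below appears to be a disabled graph-mode (session-based)
      // variant of the optimization; only the eager path above is active.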
      //var numRows = rows.Count();

      //var variablesFeed = new Hashtable();
      //foreach (var kvp in inputFeatures) {
      //  var variableName = kvp.Key;
      //  var variablePlaceholder = kvp.Value;
      //  if (problemData.Dataset.VariableHasType<double>(variableName)) {
      //    var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
      //    variablesFeed.Add(variablePlaceholder, np.array(data).reshape(new Shape(numRows, 1)));
      //  } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
      //    var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
      //    variablesFeed.Add(variablePlaceholder, np.array(data).reshape(new Shape(numRows, -1)));
      //  } else
      //    throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
      //}
      //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
      //variablesFeed.Add(target, np.array(targetData));

      //using var session = tf.Session();

      //var loss2 = tf.constant(1.23f, TF_DataType.TF_FLOAT);

      //var graphOptimizer = tf.train.AdamOptimizer((float)learningRate);
      //var minimizationOperations = graphOptimizer.minimize(loss2);

      //var init = tf.global_variables_initializer();
      //session.run(init);

      //session.run((minimizationOperations, loss2), variablesFeed);

      return newTree;

      //#if EXPORT_GRAPH
      //  //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
      //  tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
      //    clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
      //#endif

      //  //// features as feed items
      //  //var variablesFeed = new Hashtable();
      //  //foreach (var kvp in parameters) {
      //  //  var variable = kvp.Key;
      //  //  var variableName = kvp.Value;
      //  //  if (problemData.Dataset.VariableHasType<double>(variableName)) {
      //  //    var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
      //  //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, 1)));
      //  //  } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
      //  //    var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
      //  //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, -1)));
      //  //  } else
      //  //    throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
      //  //}
      //  //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
      //  //variablesFeed.Add(target, np.array(targetData));
    }

    // Writes optimized values back into the tree nodes: constants get the value, variables get
    // the weight, and factor variables get one weight per factor level.
    private static void UpdateConstants(ISymbolicExpressionTree tree, Dictionary<ISymbolicExpressionTreeNode, double[]> constants) {
      foreach (var kvp in constants) {
        var node = kvp.Key;
        var value = kvp.Value;

        switch (node) {
          case ConstantTreeNode constantTreeNode:
            constantTreeNode.Value = value[0];
            break;
          case VariableTreeNodeBase variableTreeNodeBase:
            variableTreeNodeBase.Weight = value[0];
            break;
          case FactorVariableTreeNode factorVarTreeNode: {
              for (int i = 0; i < factorVarTreeNode.Weights.Length; i++) {
                factorVarTreeNode.Weights[i] = value[i];
              }
              break;
            }
        }
      }
    }

    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return TreeToTensorConverter.IsCompatible(tree);
    }
  }
}
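
// A minimal usage sketch (hypothetical driver code, not part of this file): it assumes a
// compatible ISymbolicExpressionTree `tree` and an IRegressionProblemData `problemData`
// are already available and optimizes the tree's constants on the training rows.
//
//   if (TensorFlowConstantOptimizationEvaluator.CanOptimizeConstants(tree)) {
//     var optimizedTree = TensorFlowConstantOptimizationEvaluator.OptimizeTree(
//       tree, problemData, problemData.TrainingIndices,
//       applyLinearScaling: true, updateVariableWeights: true,
//       maxIterations: 10, learningRate: 0.001);
//   }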