#region License Information
/* HeuristicLab
* Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
*
* This file is part of HeuristicLab.
*
* HeuristicLab is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* HeuristicLab is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
*/
#endregion
//#define EXPORT_GRAPH
//#define LOG_CONSOLE
//#define LOG_FILE
using System;
using System.Collections;
using System.Collections.Generic;
#if LOG_CONSOLE
using System.Diagnostics;
#endif
#if LOG_FILE
using System.Globalization;
using System.IO;
#endif
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
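/// <summary>
/// Constant optimization evaluator based on TensorFlow.NET: the symbolic expression tree is
/// converted into tensors and its constants (and, optionally, variable weights) are fitted by
/// gradient descent using the Adam optimizer.
/// </summary>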
[StorableType("63944BF6-62E5-4BE4-974C-D30AD8770F99")]
[Item("TensorFlowConstantOptimizationEvaluator", "")]
public class TensorFlowConstantOptimizationEvaluator : SymbolicRegressionConstantOptimizationEvaluator {
private const string MaximumIterationsName = "MaximumIterations";
private const string LearningRateName = "LearningRate";
//private static readonly TF_DataType DataType = tf.float64;
//private static readonly TF_DataType DataType = tf.float32;
#region Parameter Properties
public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
}
public IFixedValueParameter<DoubleValue> LearningRateParameter {
get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
}
#endregion
#region Properties
public int ConstantOptimizationIterations {
get { return ConstantOptimizationIterationsParameter.Value.Value; }
}
public double LearningRate {
get { return LearningRateParameter.Value.Value; }
}
#endregion
public TensorFlowConstantOptimizationEvaluator()
: base() {
Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constants of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "The learning rate used by the Adam optimizer during constant optimization.", new DoubleValue(0.001)));
}
protected TensorFlowConstantOptimizationEvaluator(TensorFlowConstantOptimizationEvaluator original, Cloner cloner)
: base(original, cloner) { }
public override IDeepCloneable Clone(Cloner cloner) {
return new TensorFlowConstantOptimizationEvaluator(this, cloner);
}
[StorableConstructor]
protected TensorFlowConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }
protected override ISymbolicExpressionTree OptimizeConstants(
ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows,
CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
return OptimizeTree(tree,
problemData, rows,
ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
ConstantOptimizationIterations, LearningRate,
cancellationToken);
}
public static ISymbolicExpressionTree OptimizeTree(ISymbolicExpressionTree tree,
IRegressionProblemData problemData, IEnumerable<int> rows,
bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {
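// Outline: convert the tree into TensorFlow tensors (eager mode), run up to maxIterations
// Adam gradient steps on a squared-error loss, then write the optimized values back into a
// cloned copy of the tree.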
const bool eager = true;
#if LOG_FILE
var directoryName = $"C:\\temp\\TFboard\\logdir\\TF_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
Directory.CreateDirectory(directoryName);
using var predictionTargetLossWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "PredictionTargetLos.csv")));
using var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
using var treeGradsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "TreeGrads.csv")));
using var lossGradsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "LossGrads.csv")));
predictionTargetLossWriter.WriteLine(string.Join(";", "Prediction", "Target", "Loss"));
weightsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"w_{i}")));
treeGradsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"Tg_{i}")));
lossGradsWriter.WriteLine(string.Join(";", Enumerable.Range(0, 4).Select(i => $"Lg_{i}")));
#endif
//foreach (var row in rows) {
bool prepared = TreeToTensorConverter.TryPrepareTree(
tree,
problemData, rows.ToList(),
//problemData, new List<int>() { row },
updateVariableWeights, applyLinearScaling,
eager,
out Dictionary<string, Tensor> inputFeatures, out Tensor target,
out Dictionary<ISymbolicExpressionTreeNode, ResourceVariable[]> variables);
if (!prepared)
return (ISymbolicExpressionTree)tree.Clone();
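// Gradient descent is performed with the Adam optimizer from the Keras API; the optimizer
// expects the learning rate as a float.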
var optimizer = keras.optimizers.Adam((float)learningRate);
for (int i = 0; i < maxIterations; i++) {
if (cancellationToken.IsCancellationRequested) break;
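// A persistent gradient tape is only needed when logging, because gradients are then
// requested more than once (for the prediction and for the loss) from the same tape.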
#if LOG_FILE || LOG_CONSOLE
using var tape = tf.GradientTape(persistent: true);
#else
using var tape = tf.GradientTape(persistent: false);
#endif
bool success = TreeToTensorConverter.TryEvaluate(
tree,
inputFeatures, variables,
updateVariableWeights, applyLinearScaling,
eager,
out Tensor prediction);
if (!success)
return (ISymbolicExpressionTree)tree.Clone();
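// Loss: mean squared error between the target values and the tree's prediction.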
var loss = tf.reduce_mean(tf.square(target - prediction));
progress?.Report(loss.ToArray<float>()[0]);
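// Collect all trainable variables of the tree and compute the gradient of the loss
// with respect to them.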
var variablesList = variables.Values.SelectMany(x => x).ToList();
var gradients = tape.gradient(loss, variablesList);
#if LOG_FILE
predictionTargetLossWriter.WriteLine(string.Join(";", new[] { prediction.ToArray<float>()[0], target.ToArray<float>()[0], loss.ToArray<float>()[0] }));
weightsWriter.WriteLine(string.Join(";", variablesList.Select(v => v.numpy().ToArray<float>()[0])));
treeGradsWriter.WriteLine(string.Join(";", tape.gradient(prediction, variablesList).Select(t => t.ToArray<float>()[0])));
lossGradsWriter.WriteLine(string.Join(";", tape.gradient(loss, variablesList).Select(t => t.ToArray<float>()[0])));
#endif
//break;
optimizer.apply_gradients(zip(gradients, variablesList));
}
//}
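// Copy the optimized values back into a clone of the tree; the cloner maps each original
// node to its clone so the new values can be assigned to the corresponding nodes.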
var cloner = new Cloner();
var newTree = cloner.Clone(tree);
var newConstants = variables.ToDictionary(
kvp => (ISymbolicExpressionTreeNode)cloner.GetClone(kvp.Key),
kvp => kvp.Value.Select(x => (double)(x.numpy().ToArray<float>()[0])).ToArray()
);
UpdateConstants(newTree, newConstants);
//var numRows = rows.Count();
//var variablesFeed = new Hashtable();
//foreach (var kvp in inputFeatures) {
// var variableName = kvp.Key;
// var variablePlaceholder = kvp.Value;
// if (problemData.Dataset.VariableHasType<double>(variableName)) {
// var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
// variablesFeed.Add(variablePlaceholder, np.array(data).reshape(new Shape(numRows, 1)));
// } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
// var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
// variablesFeed.Add(variablePlaceholder, np.array(data).reshape(new Shape(numRows, -1)));
// } else
// throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
//}
//var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
//variablesFeed.Add(target, np.array(targetData));
//using var session = tf.Session();
//var loss2 = tf.constant(1.23f, TF_DataType.TF_FLOAT);
//var graphOptimizer = tf.train.AdamOptimizer((float)learningRate);
//var minimizationOperations = graphOptimizer.minimize(loss2);
//var init = tf.global_variables_initializer();
//session.run(init);
//session.run((minimizationOperations, loss2), variablesFeed);
return newTree;
//#if EXPORT_GRAPH
// //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
// tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
// clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
//#endif
// //// features as feed items
// //var variablesFeed = new Hashtable();
// //foreach (var kvp in parameters) {
// // var variable = kvp.Key;
// // var variableName = kvp.Value;
// // if (problemData.Dataset.VariableHasType<double>(variableName)) {
// // var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
// // variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, 1)));
// // } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
// // var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
// // variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, -1)));
// // } else
// // throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
// //}
// //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
// //variablesFeed.Add(target, np.array(targetData));
}
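// Writes the optimized values back into the tree: constant nodes get a new value, variable
// nodes a new weight, and factor variable nodes one weight per factor level.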
private static void UpdateConstants(ISymbolicExpressionTree tree, Dictionary<ISymbolicExpressionTreeNode, double[]> constants) {
foreach (var kvp in constants) {
var node = kvp.Key;
var value = kvp.Value;
switch (node) {
case ConstantTreeNode constantTreeNode:
constantTreeNode.Value = value[0];
break;
case VariableTreeNodeBase variableTreeNodeBase:
variableTreeNodeBase.Weight = value[0];
break;
case FactorVariableTreeNode factorVarTreeNode: {
for (int i = 0; i < factorVarTreeNode.Weights.Length; i++) {
factorVarTreeNode.Weights[i] = value[i];
}
break;
}
}
}
}
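// Constant optimization is only possible if the converter supports all symbols in the tree.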
public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
return TreeToTensorConverter.IsCompatible(tree);
}
}
}