
Timestamp: 03/22/22 13:28:56 (3 years ago)
Author: pfleck
Message: #3040: Updated to a newer TensorFlow.NET version.

  • Removed ILMerge from TensorFlow.NET.
  • Temporarily removed DiffSharp.
  • Switched to a locally built HEAL.Attic with a specific Protobuf version that is compatible with TensorFlow.NET (and adapted the versions of other NuGet dependencies accordingly).
Location: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
Files: 4 edited

  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj

    --- r17930
    +++ r18239
    @@ -47,5 +47,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
    @@ -58,5 +58,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' ">
    @@ -69,5 +69,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x64' ">
    @@ -80,5 +80,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x86' ">
    @@ -91,5 +91,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x86' ">
    @@ -102,5 +102,5 @@
         <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
         <Prefer32Bit>false</Prefer32Bit>
    -    <LangVersion>7</LangVersion>
    +    <LangVersion>latest</LangVersion>
       </PropertyGroup>
       <ItemGroup>
    @@ -109,8 +109,4 @@
           <HintPath>..\..\bin\ALGLIB-3.7.0.dll</HintPath>
           <Private>False</Private>
    -    </Reference>
    -    <Reference Include="DiffSharp.Merged, Version=0.8.4.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=AMD64">
    -      <SpecificVersion>False</SpecificVersion>
    -      <HintPath>..\..\bin\DiffSharp.Merged.dll</HintPath>
         </Reference>
         <Reference Include="MathNet.Numerics">
    @@ -133,7 +129,11 @@
         <Reference Include="System.Data" />
         <Reference Include="System.Xml" />
    -    <Reference Include="TensorFlow.NET.Merged, Version=0.15.0.0, Culture=neutral, PublicKeyToken=ba48961d6f65dcec, processorArchitecture=MSIL">
    +    <Reference Include="Tensorflow.Binding, Version=0.70.1.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=AMD64">
           <SpecificVersion>False</SpecificVersion>
    -      <HintPath>..\..\bin\TensorFlow.NET.Merged.dll</HintPath>
    +      <HintPath>..\..\bin\Tensorflow.Binding.dll</HintPath>
    +    </Reference>
    +    <Reference Include="Tensorflow.Keras, Version=0.7.0.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=AMD64">
    +      <SpecificVersion>False</SpecificVersion>
    +      <HintPath>..\..\bin\Tensorflow.Keras.dll</HintPath>
         </Reference>
       </ItemGroup>
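
    The reference swap matches the commit note about dropping ILMerge: instead of a single merged TensorFlow.NET.Merged.dll, the project now references the two stock TensorFlow.NET 0.70 assemblies, Tensorflow.Binding (the core tf API) and Tensorflow.Keras (the keras API with the optimizers). A minimal sketch of how C# code consumes the two assemblies; the using directives are the ones added in the evaluator diff below, while the demo class itself is hypothetical:

        using Tensorflow;                  // core types, from Tensorflow.Binding.dll
        using static Tensorflow.Binding;   // brings the tf entry point into scope
        using static Tensorflow.KerasApi;  // brings the keras entry point into scope (Tensorflow.Keras.dll)

        static class ReferenceSmokeTest {
          static void Main() {
            var t = tf.constant(1.5f);                 // exercises Tensorflow.Binding
            var adam = keras.optimizers.Adam(0.001f);  // exercises Tensorflow.Keras
          }
        }
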
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/Plugin.cs.frame

    --- r17786
    +++ r18239
    @@ -43,5 +43,5 @@
       [PluginDependency("HeuristicLab.MathNet.Numerics", "4.9.0")]
       [PluginDependency("HeuristicLab.TensorFlowNet", "0.15.0")]
    -  [PluginDependency("HeuristicLab.DiffSharp", "0.7.7")]
    +  //[PluginDependency("HeuristicLab.DiffSharp", "0.7.7")]
       public class HeuristicLabProblemsDataAnalysisSymbolicRegressionPlugin : PluginBase {
       }
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/NonlinearLeastSquaresVectorConstantOptimizationEvaluator.cs

    --- r17930
    +++ r18239
    @@ -19,4 +19,6 @@
      */
     #endregion
    +
    +#if INCLUDE_DIFFSHARP
     
     using System;
    @@ -194,2 +196,4 @@
       }
     }
    +
    +#endif
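
    Rather than deleting the DiffSharp-based evaluator, the commit fences the whole file behind a conditional-compilation symbol so it can be re-enabled once DiffSharp returns ("temporarily removed DiffSharp" above). A minimal sketch of the pattern; the DefineConstants hint shows where the symbol would be turned back on and is not something this commit adds:

        // With INCLUDE_DIFFSHARP undefined, the compiler skips the entire file body.
        // To re-enable: add INCLUDE_DIFFSHARP to <DefineConstants> in the csproj.
        #if INCLUDE_DIFFSHARP

        using System;
        // ... DiffSharp-based evaluator implementation ...

        #endif
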
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    --- r17721
    +++ r18239
    @@ -25,5 +25,4 @@
     
     using System;
    -using System.Collections;
     using System.Collections.Generic;
     #if LOG_CONSOLE
    @@ -42,7 +41,8 @@
     using HeuristicLab.Parameters;
     using HEAL.Attic;
    -using NumSharp;
     using Tensorflow;
    +using Tensorflow.NumPy;
     using static Tensorflow.Binding;
    +using static Tensorflow.KerasApi;
     using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
     
    @@ -54,4 +54,5 @@
         private const string LearningRateName = "LearningRate";
     
    +    //private static readonly TF_DataType DataType = tf.float64;
         private static readonly TF_DataType DataType = tf.float32;
     
    @@ -105,160 +106,296 @@
           CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {
     
    -      int numRows = rows.Count();
    -      var variableLengths = problemData.AllowedInputVariables.ToDictionary(
    -        var => var,
    -        var => {
    -          if (problemData.Dataset.VariableHasType<double>(var)) return 1;
    -          if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
    -          throw new NotSupportedException($"Type of variable {var} is not supported.");
    -        });
    -
    -      bool success = TreeToTensorConverter.TryConvert(tree,
    -        numRows, variableLengths,
    +      const bool eager = true;
    +
    +      bool prepared = TreeToTensorConverter.TryPrepareTree(
    +        tree,
    +        problemData, rows.ToList(),
             updateVariableWeights, applyLinearScaling,
    -        out Tensor prediction,
    -        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
    -
    -      if (!success)
    +        eager,
    +        out Dictionary<string, Tensor> inputFeatures, out Tensor target,
    +        out Dictionary<ISymbolicExpressionTreeNode, ResourceVariable[]> variables);
    +      if (!prepared)
             return (ISymbolicExpressionTree)tree.Clone();
     
    -      var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable);
    -      // MSE
    -      var cost = tf.reduce_mean(tf.square(target - prediction));
    -
    -      var optimizer = tf.train.AdamOptimizer((float)learningRate);
    -      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
    -      var optimizationOperation = optimizer.minimize(cost);
    -
    -#if EXPORT_GRAPH
    -      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
    -      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
    -        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
    -#endif
    -
    -      // features as feed items
    -      var variablesFeed = new Hashtable();
    -      foreach (var kvp in parameters) {
    -        var variable = kvp.Key;
    -        var variableName = kvp.Value;
    -        if (problemData.Dataset.VariableHasType<double>(variableName)) {
    -          var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
    -          variablesFeed.Add(variable, np.array(data).reshape(numRows, 1));
    -        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
    -          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray();
    -          variablesFeed.Add(variable, np.array(data));
    -        } else
    -          throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
    +      var optimizer = keras.optimizers.Adam((float)learningRate);
    +
    +      for (int i = 0; i < maxIterations; i++) {
    +        if (cancellationToken.IsCancellationRequested) break;
    +
    +        using var tape = tf.GradientTape();
    +
    +        bool success = TreeToTensorConverter.TryEvaluate(
    +          tree,
    +          inputFeatures, variables,
    +          updateVariableWeights, applyLinearScaling,
    +          eager,
    +          out Tensor prediction);
    +        if (!success)
    +          return (ISymbolicExpressionTree)tree.Clone();
    +
    +        var loss = tf.reduce_mean(tf.square(target - prediction));
    +
    +        progress?.Report(loss.ToArray<float>()[0]);
    +
    +        var variablesList = variables.Values.SelectMany(x => x).ToList();
    +        var gradients = tape.gradient(loss, variablesList);
    +
    +        optimizer.apply_gradients(zip(gradients, variablesList));
           }
    -      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
    -      variablesFeed.Add(target, np.array(targetData));
    -
    -
    -      List<NDArray> constants;
    -      using (var session = tf.Session()) {
    -
    -#if LOG_FILE
    -        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
    -        Directory.CreateDirectory(directoryName);
    -        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
    -        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
    -        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
    -#endif
    -
    -#if LOG_CONSOLE || LOG_FILE
    -        var gradients = optimizer.compute_gradients(cost);
    -#endif
    -
    -        session.run(tf.global_variables_initializer());
    -
    -        progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
    -
    -
    -#if LOG_CONSOLE
    -        Trace.WriteLine("Costs:");
    -        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    -
    -        Trace.WriteLine("Weights:");
    -        foreach (var v in variables) {
    -          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    -        }
    -
    -        Trace.WriteLine("Gradients:");
    -        foreach (var t in gradients) {
    -          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    -        }
    -#endif
    -
    -#if LOG_FILE
    -        costsWriter.WriteLine("MSE");
    -        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
    -
    -        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
    -        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    -
    -        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name)));
    -        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    -#endif
    -
    -        for (int i = 0; i < maxIterations; i++) {
    -          if (cancellationToken.IsCancellationRequested)
    +
    +      var cloner = new Cloner();
    +      var newTree = cloner.Clone(tree);
    +      var newConstants = variables.ToDictionary(
    +        kvp => (ISymbolicExpressionTreeNode)cloner.GetClone(kvp.Key),
    +        kvp => kvp.Value.Select(x => (double)(x.numpy().ToArray<float>()[0])).ToArray()
    +      );
    +      UpdateConstants(newTree, newConstants);
    +
    +
    +      return newTree;
    +
    +
    +
    +
    +
    +//      //int numRows = rows.Count();
    +
    +
    +
    +
    +
    +
    +//      var variableLengths = problemData.AllowedInputVariables.ToDictionary(
    +//        var => var,
    +//        var => {
    +//          if (problemData.Dataset.VariableHasType<double>(var)) return 1;
    +//          if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
    +//          throw new NotSupportedException($"Type of variable {var} is not supported.");
    +//        });
    +
    +//      var variablesDict = problemData.AllowedInputVariables.ToDictionary(
    +//        var => var,
    +//        var => {
    +//          if (problemData.Dataset.VariableHasType<double>(var)) {
    +//            var data = problemData.Dataset.GetDoubleValues(var, rows).Select(x => (float)x).ToArray();
    +//            return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, 1)), DataType);
    +//          } else if (problemData.Dataset.VariableHasType<DoubleVector>(var)) {
    +//            var data = problemData.Dataset.GetDoubleVectorValues(var, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
    +//            return tf.convert_to_tensor(np.array(data).reshape(new Shape(numRows, -1)), DataType);
    +//          } else  throw new NotSupportedException($"Type of the variable is not supported: {var}");
    +//        }
    +//      );
    +
    +//      using var tape = tf.GradientTape(persistent: true);
    +
    +//      bool success = TreeToTensorConverter.TryEvaluateEager(tree,
    +//        numRows, variablesDict,
    +//        updateVariableWeights, applyLinearScaling,
    +//        out Tensor prediction,
    +//        out Dictionary<Tensor, string> parameters, out List<ResourceVariable> variables);
    +
    +//      //bool success = TreeToTensorConverter.TryConvert(tree,
    +//      //  numRows, variableLengths,
    +//      //  updateVariableWeights, applyLinearScaling,
    +//      //  out Tensor prediction,
    +//      //  out Dictionary<Tensor, string> parameters, out List<Tensor> variables);
    +
    +//      if (!success)
    +//        return (ISymbolicExpressionTree)tree.Clone();
    +
    +//      //var target = tf.placeholder(DataType, new Shape(numRows), name: problemData.TargetVariable);
    +//      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
    +//      var target = tf.convert_to_tensor(np.array(targetData).reshape(new Shape(numRows)), DataType);
    +//      // MSE
    +//      var cost = tf.reduce_sum(tf.square(prediction - target));
    +
    +//      tape.watch(cost);
    +
    +//      //var optimizer = tf.train.AdamOptimizer((float)learningRate);
    +//      //var optimizer = tf.train.AdamOptimizer(tf.constant(learningRate, DataType));
    +//      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
    +//      //var optimizer = tf.train.GradientDescentOptimizer(tf.constant(learningRate, DataType));
    +//      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
    +//      //var optimizer = tf.train.AdamOptimizer((float)learningRate);
    +//      //var optimizationOperation = optimizer.minimize(cost);
    +//      var optimizer = keras.optimizers.Adam((float)learningRate);
    +
    +//      #if EXPORT_GRAPH
    +//      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
    +//      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
    +//        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
    +//#endif
    +
    +//      //// features as feed items
    +//      //var variablesFeed = new Hashtable();
    +//      //foreach (var kvp in parameters) {
    +//      //  var variable = kvp.Key;
    +//      //  var variableName = kvp.Value;
    +//      //  if (problemData.Dataset.VariableHasType<double>(variableName)) {
    +//      //    var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
    +//      //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, 1)));
    +//      //  } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
    +//      //    var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).SelectMany(x => x.Select(y => (float)y)).ToArray();
    +//      //    variablesFeed.Add(variable, np.array(data).reshape(new Shape(numRows, -1)));
    +//      //  } else
    +//      //    throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
    +//      //}
    +//      //var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
    +//      //variablesFeed.Add(target, np.array(targetData));
    +
    +
    +//      List<NDArray> constants;
    +//      //using (var session = tf.Session()) {
    +
    +//#if LOG_FILE
    +//        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
    +//        Directory.CreateDirectory(directoryName);
    +//        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
    +//        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
    +//        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
    +//#endif
    +
    +//      //session.run(tf.global_variables_initializer());
    +
    +//#if LOG_CONSOLE || LOG_FILE
    +//        var gradients = optimizer.compute_gradients(cost);
    +//#endif
    +
    +//      //var vars = variables.Select(v => session.run(v, variablesFeed)[0].ToArray<float>()[0]).ToList();
    +//      //var gradient = optimizer.compute_gradients(cost)
    +//      //  .Where(g => g.Item1 != null)
    +//      //  //.Select(g => session.run(g.Item1, variablesFeed)[0].GetValue<float>(0)).
    +//      //  .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0])
    +//      //  .ToList();
    +
    +//      //var gradientPrediction = optimizer.compute_gradients(prediction)
    +//      //  .Where(g => g.Item1 != null)
    +//      //  .Select(g => session.run(g.Item1, variablesFeed)[0].ToArray<float>()[0])
    +//      //  .ToList();
    +
    +
    +//      //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]);
    +//      progress?.Report(cost.ToArray<float>()[0]);
    +
    +
    +
    +
    +
    +//#if LOG_CONSOLE
    +//        Trace.WriteLine("Costs:");
    +//        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    +
    +//        Trace.WriteLine("Weights:");
    +//        foreach (var v in variables) {
    +//          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    +//        }
    +
    +//        Trace.WriteLine("Gradients:");
    +//        foreach (var t in gradients) {
    +//          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    +//        }
    +//#endif
    +
    +//#if LOG_FILE
    +//        costsWriter.WriteLine("MSE");
    +//        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture));
    +
    +//        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
    +//        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
    +
    +//        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.Name)));
    +//        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
    +//#endif
    +
    +//        for (int i = 0; i < maxIterations; i++) {
    +//          if (cancellationToken.IsCancellationRequested)
    +//            break;
    +
    +
    +//        var gradients = tape.gradient(cost, variables);
    +//        //optimizer.apply_gradients(gradients.Zip(variables, Tuple.Create<Tensor, IVariableV1>).ToArray());
    +//        optimizer.apply_gradients(zip(gradients, variables));
    +
    +
    +//        //session.run(optimizationOperation, variablesFeed);
    +
    +//        progress?.Report(cost.ToArray<float>()[0]);
    +//        //progress?.Report(session.run(cost, variablesFeed)[0].ToArray<float>()[0]);
    +
    +//#if LOG_CONSOLE
    +//          Trace.WriteLine("Costs:");
    +//          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    +
    +//          Trace.WriteLine("Weights:");
    +//          foreach (var v in variables) {
    +//            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    +//          }
    +
    +//          Trace.WriteLine("Gradients:");
    +//          foreach (var t in gradients) {
    +//            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    +//          }
    +//#endif
    +
    +//#if LOG_FILE
    +//          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture));
    +//          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
    +//          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].ToArray<float>()[0].ToString(CultureInfo.InvariantCulture))));
    +//#endif
    +//      }
    +
    +//#if LOG_FILE
    +//        costsWriter.Close();
    +//        weightsWriter.Close();
    +//        gradientsWriter.Close();
    +//#endif
    +//      //constants = variables.Select(v => session.run(v)).ToList();
    +//      constants = variables.Select(v => v.numpy()).ToList();
    +//      //}
    +
    +//      if (applyLinearScaling)
    +//        constants = constants.Skip(2).ToList();
    +//      var newTree = (ISymbolicExpressionTree)tree.Clone();
    +//      UpdateConstants(newTree, constants, updateVariableWeights);
    +
    +//      return newTree;
    +    }
    +
    +    private static void UpdateConstants(ISymbolicExpressionTree tree, Dictionary<ISymbolicExpressionTreeNode, double[]> constants) {
    +      foreach (var kvp in constants) {
    +        var node = kvp.Key;
    +        var value = kvp.Value;
    +
    +        switch (node) {
    +          case ConstantTreeNode constantTreeNode:
    +            constantTreeNode.Value = value[0];
                 break;
    -
    -          session.run(optimizationOperation, variablesFeed);
    -
    -          progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
    -
    -#if LOG_CONSOLE
    -          Trace.WriteLine("Costs:");
    -          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    -
    -          Trace.WriteLine("Weights:");
    -          foreach (var v in variables) {
    -            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    +          case VariableTreeNodeBase variableTreeNodeBase:
    +            variableTreeNodeBase.Weight = value[0];
    +            break;
    +          case FactorVariableTreeNode factorVarTreeNode: {
    +            for (int i = 0; i < factorVarTreeNode.Weights.Length; i++) {
    +              factorVarTreeNode.Weights[i] = value[i];
    +            }
    +            break;
               }
    -
    -          Trace.WriteLine("Gradients:");
    -          foreach (var t in gradients) {
    -            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
    -          }
    -#endif
    -
    -#if LOG_FILE
    -          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
    -          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    -          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
    -#endif
    -        }
    -
    -#if LOG_FILE
    -        costsWriter.Close();
    -        weightsWriter.Close();
    -        gradientsWriter.Close();
    -#endif
    -        constants = variables.Select(v => session.run(v)).ToList();
    -      }
    -
    -      if (applyLinearScaling)
    -        constants = constants.Skip(2).ToList();
    -      var newTree = (ISymbolicExpressionTree)tree.Clone();
    -      UpdateConstants(newTree, constants, updateVariableWeights);
    -
    -      return newTree;
    -    }
    -
    -    private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
    -      int i = 0;
    -      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    -        if (node is ConstantTreeNode constantTreeNode)
    -          constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);
    -        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
    -          variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);
    -        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
    -          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    -            factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);
             }
           }
         }
     
    +    //private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
    +    //  int i = 0;
    +    //  foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    +    //    if (node is ConstantTreeNode constantTreeNode) {
    +    //      constantTreeNode.Value = constants[i++].ToArray<float>()[0];
    +    //    } else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights) {
    +    //      variableTreeNodeBase.Weight = constants[i++].ToArray<float>()[0];
    +    //    } else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
    +    //      for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    +    //        factorVarTreeNode.Weights[j] = constants[i++].ToArray<float>()[0];
    +    //    }
    +    //  }
    +    //}
    +
         public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
           return TreeToTensorConverter.IsCompatible(tree);
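
    Stripped of the HeuristicLab plumbing, the rewritten evaluator replaces the old graph-mode pipeline (placeholders, a feed dictionary, Session.run) with TensorFlow.NET's eager mode: each iteration re-evaluates the tree under a gradient tape, computes the MSE loss directly, and lets the Keras Adam optimizer apply the gradients; afterwards the fitted ResourceVariable values are written back into a clone of the tree. A condensed, self-contained sketch of that loop, assuming TF.NET ~0.70 and substituting a trivial model (prediction = w * x) for the converted expression tree; exact operator overloads and signatures may differ between TF.NET versions:

        using System.Collections.Generic;
        using Tensorflow;                  // Tensor, ResourceVariable
        using Tensorflow.NumPy;            // np
        using static Tensorflow.Binding;   // tf, zip
        using static Tensorflow.KerasApi;  // keras

        static class EagerOptimizationSketch {
          // Fits a single weight w so that w * x approximates y; w stands in for the
          // tree's trainable constants (the ResourceVariable[] per tree node above).
          public static float FitWeight(float[] xData, float[] yData, float learningRate, int maxIterations) {
            Tensor x = tf.convert_to_tensor(np.array(xData), tf.float32);
            Tensor y = tf.convert_to_tensor(np.array(yData), tf.float32);
            ResourceVariable w = tf.Variable(1f);                 // hypothetical single "constant"
            var variables = new List<ResourceVariable> { w };
            var optimizer = keras.optimizers.Adam(learningRate);

            for (int i = 0; i < maxIterations; i++) {
              using var tape = tf.GradientTape();                 // records the forward pass
              Tensor prediction = w * x;                          // forward pass (the converted tree in the real code)
              Tensor loss = tf.reduce_mean(tf.square(y - prediction));   // MSE, as in the diff
              var gradients = tape.gradient(loss, variables);
              optimizer.apply_gradients(zip(gradients, variables));      // one Adam step
            }

            // Read the fitted value back out, as UpdateConstants does via numpy().
            return w.numpy().ToArray<float>()[0];
          }
        }

    The same shape is visible in the committed loop: TryEvaluate is the forward pass, tf.reduce_mean(tf.square(target - prediction)) is the loss, and apply_gradients(zip(...)) performs the update, with the tape re-created on every iteration because eager mode records one forward pass at a time.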