Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/09/20 14:01:09 (5 years ago)
Author:
pfleck
Message:

#3040

  • Switched the whole TF graph to float (the Adam optimizer won't work with double).
  • Added progress reporting and cancellation support for the TF constant optimization.
  • Added optional logging to console and/or file for later plotting of costs, weights, and gradients.
Location:
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj

    r17475 r17502  
    4747    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    4848    <Prefer32Bit>false</Prefer32Bit>
     49    <LangVersion>7</LangVersion>
    4950  </PropertyGroup>
    5051  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
     
    5758    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    5859    <Prefer32Bit>false</Prefer32Bit>
     60    <LangVersion>7</LangVersion>
    5961  </PropertyGroup>
    6062  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' ">
     
    6769    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    6870    <Prefer32Bit>false</Prefer32Bit>
     71    <LangVersion>7</LangVersion>
    6972  </PropertyGroup>
    7073  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x64' ">
     
    7780    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    7881    <Prefer32Bit>false</Prefer32Bit>
     82    <LangVersion>7</LangVersion>
    7983  </PropertyGroup>
    8084  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x86' ">
     
    8791    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    8892    <Prefer32Bit>false</Prefer32Bit>
     93    <LangVersion>7</LangVersion>
    8994  </PropertyGroup>
    9095  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x86' ">
     
    97102    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    98103    <Prefer32Bit>false</Prefer32Bit>
     104    <LangVersion>7</LangVersion>
    99105  </PropertyGroup>
    100106  <ItemGroup>
     
    113119    </Reference>
    114120    <Reference Include="System.Drawing" />
     121    <Reference Include="System.Numerics" />
    115122    <Reference Include="System.Xml.Linq">
    116123      <RequiredTargetFramework>3.5</RequiredTargetFramework>
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17493 r17502  
    2020#endregion
    2121
    22 #define EXPLICIT_SHAPE
     22//#define EXPORT_GRAPH
     23//#define LOG_CONSOLE
     24//#define LOG_FILE
    2325
    2426using System;
    2527using System.Collections;
    2628using System.Collections.Generic;
     29#if LOG_CONSOLE
    2730using System.Diagnostics;
     31#endif
     32#if LOG_FILE
     33using System.Globalization;
     34using System.IO;
     35#endif
    2836using System.Linq;
    2937using System.Threading;
     
    4654    private const string LearningRateName = "LearningRate";
    4755
    48     #region Parameter Properties
     56    private static readonly TF_DataType DataType = tf.float32;
     57
     58#region Parameter Properties
    4959    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
    5060      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
     
    5363      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
    5464    }
    55     #endregion
    56 
    57     #region Properties
     65#endregion
     66
     67#region Properties
    5868    public int ConstantOptimizationIterations {
    5969      get { return ConstantOptimizationIterationsParameter.Value.Value; }
     
    6272      get { return LearningRateParameter.Value.Value; }
    6373    }
    64     #endregion
     74#endregion
    6575
    6676    public TensorFlowConstantOptimizationEvaluator()
    6777      : base() {
    6878      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree(0 indicates other or default stopping criterion).", new IntValue(10)));
    69       Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.01)));
     79      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.001)));
    7080    }
    7181
     
    8797        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
    8898        ConstantOptimizationIterations, LearningRate,
    89         cancellationToken, counter);
    90     }
    91 
    92     public static ISymbolicExpressionTree OptimizeTree(
    93       ISymbolicExpressionTree tree,
     99        cancellationToken);
     100    }
     101
     102    public static ISymbolicExpressionTree OptimizeTree(ISymbolicExpressionTree tree,
    94103      IRegressionProblemData problemData, IEnumerable<int> rows,
    95104      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
    96       CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
     105      CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {
    97106
    98107      int numRows = rows.Count();
     
    114123        return (ISymbolicExpressionTree)tree.Clone();
    115124
    116 #if EXPLICIT_SHAPE
    117       var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable);
    118 #endif
    119       // mse
    120       var costs = tf.reduce_sum(tf.square(target - prediction)) / (2.0 * numRows);
    121       var optimizer = tf.train.GradientDescentOptimizer((float)learningRate).minimize(costs);
     125      var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable);
     126      // MSE
     127      var cost = tf.reduce_mean(tf.square(prediction - target));
     128
     129      var optimizer = tf.train.AdamOptimizer((float)learningRate);
     130      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
     131      var optimizationOperation = optimizer.minimize(cost);
     132
     133#if EXPORT_GRAPH
     134      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
     135      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
     136        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
     137#endif
    122138
    123139      // features as feed items
     
    127143        var variableName = kvp.Value;
    128144        if (problemData.Dataset.VariableHasType<double>(variableName)) {
    129           var data = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
    130           //if (vectorLength.HasValue) {
    131           //  var vectorData = new double[numRows][];
    132           //  for (int i = 0; i < numRows; i++)
    133           //    vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
    134           //  variablesFeed.Add(variable, np.array(vectorData));
    135           //} else
     145          var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
    136146          variablesFeed.Add(variable, np.array(data, copy: false).reshape(numRows, 1));
    137           //} else if (problemData.Dataset.VariableHasType<string>(variableName)) {
    138           //  variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
    139147        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
    140           var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.ToArray()).ToArray();
     148          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray();
    141149          variablesFeed.Add(variable, np.array(data));
    142150        } else
    143151          throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
    144152      }
    145       var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    146       variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1));
     153      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
     154      variablesFeed.Add(target, np.array(targetData, copy: false));
     155
    147156
    148157      List<NDArray> constants;
    149158      using (var session = tf.Session()) {
     159
     160#if LOG_FILE
     161        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
     162        Directory.CreateDirectory(directoryName);
     163        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
     164        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
     165        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
     166#endif
     167
     168#if LOG_CONSOLE || LOG_FILE
     169        var gradients = optimizer.compute_gradients(cost);
     170#endif
     171
    150172        session.run(tf.global_variables_initializer());
    151173
    152         // https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
    153         tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false);
     174        progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
     175
     176
     177#if LOG_CONSOLE
     178        Trace.WriteLine("Costs:");
     179        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    154180
    155181        Trace.WriteLine("Weights:");
    156         foreach (var v in variables)
     182        foreach (var v in variables) {
    157183          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     184        }
     185
     186        Trace.WriteLine("Gradients:");
     187        foreach (var t in gradients) {
     188          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     189        }
     190#endif
     191
     192#if LOG_FILE
     193        costsWriter.WriteLine("MSE");
     194        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
     195
     196        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
     197        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     198
     199        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name)));
     200        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     201#endif
    158202
    159203        for (int i = 0; i < maxIterations; i++) {
    160 
    161           //optimizer.minimize(costs);
    162           session.run(optimizer, variablesFeed);
     204          if (cancellationToken.IsCancellationRequested)
     205            break;
     206
     207          session.run(optimizationOperation, variablesFeed);
     208
     209          progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
     210
     211#if LOG_CONSOLE
     212          Trace.WriteLine("Costs:");
     213          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    163214
    164215          Trace.WriteLine("Weights:");
    165           foreach (var v in variables)
     216          foreach (var v in variables) {
    166217            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    167         }
    168 
     218          }
     219
     220          Trace.WriteLine("Gradients:");
     221          foreach (var t in gradients) {
     222            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     223          }
     224#endif
     225
     226#if LOG_FILE
     227          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
     228          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     229          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     230#endif
     231        }
     232
     233#if LOG_FILE
     234        costsWriter.Close();
     235        weightsWriter.Close();
     236        gradientsWriter.Close();
     237#endif
    169238        constants = variables.Select(v => session.run(v)).ToList();
    170239      }
     
    182251      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    183252        if (node is ConstantTreeNode constantTreeNode)
    184           constantTreeNode.Value = constants[i++].GetDouble(0, 0);
     253          constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);
    185254        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
    186           variableTreeNodeBase.Weight = constants[i++].GetDouble(0, 0);
     255          variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);
    187256        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
    188257          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    189             factorVarTreeNode.Weights[j] = constants[i++].GetDouble(0, 0);
     258            factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);
    190259        }
    191260      }
Note: See TracChangeset for help on using the changeset viewer.