Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
04/09/20 14:01:09 (4 years ago)
Author:
pfleck
Message:

#3040

  • Switched whole TF-graph to float (Adam optimizer won't work with double).
  • Added progress and cancellation support for TF-const opt.
  • Added optional logging with console and/or file for later plotting.
Location:
branches/3040_VectorBasedGP
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Views/3.4/InteractiveSymbolicRegressionSolutionSimplifierView.cs

    r17489 r17502  
    6464      int reps = 0;
    6565
    66       do {
    67         prevResult = result;
    68         //tree = NonlinearLeastSquaresConstantOptimizationEvaluator.OptimizeTree(tree, regressionProblemData, regressionProblemData.TrainingIndices,
    69         //  applyLinearScaling: true, maxIterations: constOptIterations, updateVariableWeights: true,
    70         //  cancellationToken: cancellationToken, iterationCallback: (args, func, obj) => {
    71         //    double newProgressValue = progress.ProgressValue + (1.0 / (constOptIterations + 2) / maxRepetitions); // (constOptIterations + 2) iterations are reported
    72         //    progress.ProgressValue = Math.Min(newProgressValue, 1.0);
    73         //  });
    74         tree = TensorFlowConstantOptimizationEvaluator.OptimizeTree(tree, regressionProblemData, regressionProblemData.TrainingIndices,
    75           applyLinearScaling: true, updateVariableWeights: true, maxIterations: 10, learningRate: 0.001);
    76         result = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(model.Interpreter, tree,
    77           model.LowerEstimationLimit, model.UpperEstimationLimit, regressionProblemData, regressionProblemData.TrainingIndices, applyLinearScaling: true);
    78         reps++;
    79         improvement = result - prevResult;
    80       } while (improvement > minimumImprovement && reps < maxRepetitions &&
    81                progress.ProgressState != ProgressState.StopRequested &&
    82                progress.ProgressState != ProgressState.CancelRequested);
    83       return tree;
     66      //do {
     67      //  prevResult = result;
     68      //  tree = NonlinearLeastSquaresConstantOptimizationEvaluator.OptimizeTree(tree, regressionProblemData, regressionProblemData.TrainingIndices,
     69      //    applyLinearScaling: true, maxIterations: constOptIterations, updateVariableWeights: true,
     70      //    cancellationToken: cancellationToken, iterationCallback: (args, func, obj) => {
     71      //      double newProgressValue = progress.ProgressValue + (1.0 / (constOptIterations + 2) / maxRepetitions); // (constOptIterations + 2) iterations are reported
     72      //      progress.ProgressValue = Math.Min(newProgressValue, 1.0);
     73      //    });
     74      //  result = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(model.Interpreter, tree,
     75      //    model.LowerEstimationLimit, model.UpperEstimationLimit, regressionProblemData, regressionProblemData.TrainingIndices, applyLinearScaling: true);
     76      //  reps++;
     77      //  improvement = result - prevResult;
     78      //} while (improvement > minimumImprovement && reps < maxRepetitions &&
     79      //         progress.ProgressState != ProgressState.StopRequested &&
     80      //         progress.ProgressState != ProgressState.CancelRequested);
     81      //return tree;
     82      const int maxIterations = 1000;
     83      return TensorFlowConstantOptimizationEvaluator.OptimizeTree(tree, regressionProblemData,
     84        regressionProblemData.TrainingIndices,
     85        applyLinearScaling: false, updateVariableWeights: true, maxIterations: maxIterations, learningRate: 0.0001,
     86        cancellationToken: cancellationToken,
     87        progress: new SynchronousProgress<double>(cost => {
     88          var newProgress = progress.ProgressValue + (1.0 / (maxIterations + 1));
     89          progress.ProgressValue = Math.Min(newProgress, 1.0);
     90          progress.Message = $"MSE: {cost}";
     91        })
     92      );
     93    }
     94
     95    internal class SynchronousProgress<T> : IProgress<T> {
     96      private readonly Action<T> callback;
     97      public SynchronousProgress(Action<T> callback) {
     98        this.callback = callback;
     99      }
     100      public void Report(T value) {
     101        callback(value);
     102      }
    84103    }
    85104  }
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression-3.4.csproj

    r17475 r17502  
    4747    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    4848    <Prefer32Bit>false</Prefer32Bit>
     49    <LangVersion>7</LangVersion>
    4950  </PropertyGroup>
    5051  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|AnyCPU' ">
     
    5758    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    5859    <Prefer32Bit>false</Prefer32Bit>
     60    <LangVersion>7</LangVersion>
    5961  </PropertyGroup>
    6062  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' ">
     
    6769    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    6870    <Prefer32Bit>false</Prefer32Bit>
     71    <LangVersion>7</LangVersion>
    6972  </PropertyGroup>
    7073  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x64' ">
     
    7780    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    7881    <Prefer32Bit>false</Prefer32Bit>
     82    <LangVersion>7</LangVersion>
    7983  </PropertyGroup>
    8084  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x86' ">
     
    8791    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    8892    <Prefer32Bit>false</Prefer32Bit>
     93    <LangVersion>7</LangVersion>
    8994  </PropertyGroup>
    9095  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Release|x86' ">
     
    97102    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    98103    <Prefer32Bit>false</Prefer32Bit>
     104    <LangVersion>7</LangVersion>
    99105  </PropertyGroup>
    100106  <ItemGroup>
     
    113119    </Reference>
    114120    <Reference Include="System.Drawing" />
     121    <Reference Include="System.Numerics" />
    115122    <Reference Include="System.Xml.Linq">
    116123      <RequiredTargetFramework>3.5</RequiredTargetFramework>
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs

    r17493 r17502  
    2020#endregion
    2121
    22 #define EXPLICIT_SHAPE
     22//#define EXPORT_GRAPH
     23//#define LOG_CONSOLE
     24//#define LOG_FILE
    2325
    2426using System;
    2527using System.Collections;
    2628using System.Collections.Generic;
     29#if LOG_CONSOLE
    2730using System.Diagnostics;
     31#endif
     32#if LOG_FILE
     33using System.Globalization;
     34using System.IO;
     35#endif
    2836using System.Linq;
    2937using System.Threading;
     
    4654    private const string LearningRateName = "LearningRate";
    4755
    48     #region Parameter Properties
     56    private static readonly TF_DataType DataType = tf.float32;
     57
     58#region Parameter Properties
    4959    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
    5060      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
     
    5363      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
    5464    }
    55     #endregion
    56 
    57     #region Properties
     65#endregion
     66
     67#region Properties
    5868    public int ConstantOptimizationIterations {
    5969      get { return ConstantOptimizationIterationsParameter.Value.Value; }
     
    6272      get { return LearningRateParameter.Value.Value; }
    6373    }
    64     #endregion
     74#endregion
    6575
    6676    public TensorFlowConstantOptimizationEvaluator()
    6777      : base() {
    6878      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree(0 indicates other or default stopping criterion).", new IntValue(10)));
    69       Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.01)));
     79      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.001)));
    7080    }
    7181
     
    8797        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
    8898        ConstantOptimizationIterations, LearningRate,
    89         cancellationToken, counter);
    90     }
    91 
    92     public static ISymbolicExpressionTree OptimizeTree(
    93       ISymbolicExpressionTree tree,
     99        cancellationToken);
     100    }
     101
     102    public static ISymbolicExpressionTree OptimizeTree(ISymbolicExpressionTree tree,
    94103      IRegressionProblemData problemData, IEnumerable<int> rows,
    95104      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
    96       CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
     105      CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {
    97106
    98107      int numRows = rows.Count();
     
    114123        return (ISymbolicExpressionTree)tree.Clone();
    115124
    116 #if EXPLICIT_SHAPE
    117       var target = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: problemData.TargetVariable);
    118 #endif
    119       // mse
    120       var costs = tf.reduce_sum(tf.square(target - prediction)) / (2.0 * numRows);
    121       var optimizer = tf.train.GradientDescentOptimizer((float)learningRate).minimize(costs);
     125      var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable);
     126      // MSE
     127      var cost = tf.reduce_mean(tf.square(prediction - target));
     128
     129      var optimizer = tf.train.AdamOptimizer((float)learningRate);
     130      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
     131      var optimizationOperation = optimizer.minimize(cost);
     132
     133#if EXPORT_GRAPH
     134      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
     135      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
     136        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
     137#endif
    122138
    123139      // features as feed items
     
    127143        var variableName = kvp.Value;
    128144        if (problemData.Dataset.VariableHasType<double>(variableName)) {
    129           var data = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
    130           //if (vectorLength.HasValue) {
    131           //  var vectorData = new double[numRows][];
    132           //  for (int i = 0; i < numRows; i++)
    133           //    vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
    134           //  variablesFeed.Add(variable, np.array(vectorData));
    135           //} else
     145          var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
    136146          variablesFeed.Add(variable, np.array(data, copy: false).reshape(numRows, 1));
    137           //} else if (problemData.Dataset.VariableHasType<string>(variableName)) {
    138           //  variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
    139147        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
    140           var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.ToArray()).ToArray();
     148          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray();
    141149          variablesFeed.Add(variable, np.array(data));
    142150        } else
    143151          throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
    144152      }
    145       var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
    146       variablesFeed.Add(target, np.array(targetData, copy: false).reshape(numRows, 1));
     153      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
     154      variablesFeed.Add(target, np.array(targetData, copy: false));
     155
    147156
    148157      List<NDArray> constants;
    149158      using (var session = tf.Session()) {
     159
     160#if LOG_FILE
     161        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
     162        Directory.CreateDirectory(directoryName);
     163        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
     164        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
     165        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
     166#endif
     167
     168#if LOG_CONSOLE || LOG_FILE
     169        var gradients = optimizer.compute_gradients(cost);
     170#endif
     171
    150172        session.run(tf.global_variables_initializer());
    151173
    152         // https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
    153         tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false);
     174        progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
     175
     176
     177#if LOG_CONSOLE
     178        Trace.WriteLine("Costs:");
     179        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    154180
    155181        Trace.WriteLine("Weights:");
    156         foreach (var v in variables)
     182        foreach (var v in variables) {
    157183          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
     184        }
     185
     186        Trace.WriteLine("Gradients:");
     187        foreach (var t in gradients) {
     188          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     189        }
     190#endif
     191
     192#if LOG_FILE
     193        costsWriter.WriteLine("MSE");
     194        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
     195
     196        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
     197        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     198
     199        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name)));
     200        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     201#endif
    158202
    159203        for (int i = 0; i < maxIterations; i++) {
    160 
    161           //optimizer.minimize(costs);
    162           session.run(optimizer, variablesFeed);
     204          if (cancellationToken.IsCancellationRequested)
     205            break;
     206
     207          session.run(optimizationOperation, variablesFeed);
     208
     209          progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));
     210
     211#if LOG_CONSOLE
     212          Trace.WriteLine("Costs:");
     213          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");
    163214
    164215          Trace.WriteLine("Weights:");
    165           foreach (var v in variables)
     216          foreach (var v in variables) {
    166217            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
    167         }
    168 
     218          }
     219
     220          Trace.WriteLine("Gradients:");
     221          foreach (var t in gradients) {
     222            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
     223          }
     224#endif
     225
     226#if LOG_FILE
     227          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
     228          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     229          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
     230#endif
     231        }
     232
     233#if LOG_FILE
     234        costsWriter.Close();
     235        weightsWriter.Close();
     236        gradientsWriter.Close();
     237#endif
    169238        constants = variables.Select(v => session.run(v)).ToList();
    170239      }
     
    182251      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    183252        if (node is ConstantTreeNode constantTreeNode)
    184           constantTreeNode.Value = constants[i++].GetDouble(0, 0);
     253          constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);
    185254        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
    186           variableTreeNodeBase.Weight = constants[i++].GetDouble(0, 0);
     255          variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);
    187256        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
    188257          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    189             factorVarTreeNode.Weights[j] = constants[i++].GetDouble(0, 0);
     258            factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);
    190259        }
    191260      }
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToTensorConverter.cs

    r17493 r17502  
    2929using Tensorflow;
    3030using static Tensorflow.Binding;
    31 using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
    3231
    3332namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    3433  public class TreeToTensorConverter {
    3534
    36     #region helper class
    37     public class DataForVariable {
    38       public readonly string variableName;
    39       public readonly string variableValue; // for factor vars
    40 
    41       public DataForVariable(string varName, string varValue) {
    42         this.variableName = varName;
    43         this.variableValue = varValue;
    44       }
    45 
    46       public override bool Equals(object obj) {
    47         var other = obj as DataForVariable;
    48         if (other == null) return false;
    49         return other.variableName.Equals(this.variableName) &&
    50                other.variableValue.Equals(this.variableValue);
    51       }
    52 
    53       public override int GetHashCode() {
    54         return variableName.GetHashCode() ^ variableValue.GetHashCode();
    55       }
    56     }
    57     #endregion
     35    private static readonly TF_DataType DataType = tf.float32;
    5836
    5937    public static bool TryConvert(ISymbolicExpressionTree tree, int numRows, Dictionary<string, int> variableLengths,
    6038      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
    61       out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables
    62 /*, out double[] initialConstants*/) {
     39      out Tensor graph, out Dictionary<Tensor, string> parameters, out List<Tensor> variables) {
    6340
    6441      try {
     
    6643        graph = converter.ConvertNode(tree.Root.GetSubtree(0));
    6744
    68         //var parametersEntries = converter.parameters.ToList(); // guarantee same order for keys and values
    69         parameters = converter.parameters; // parametersEntries.Select(kvp => kvp.Value).ToList();
     45        parameters = converter.parameters;
    7046        variables = converter.variables;
    71         //initialConstants = converter.initialConstants.ToArray();
    7247        return true;
    7348      } catch (NotSupportedException) {
     
    7550        parameters = null;
    7651        variables = null;
    77         //initialConstants = null;
    7852        return false;
    7953      }
     
    8559    private readonly bool addLinearScalingTerms;
    8660
    87     //private readonly List<double> initialConstants = new List<double>();
    8861    private readonly Dictionary<Tensor, string> parameters = new Dictionary<Tensor, string>();
    8962    private readonly List<Tensor> variables = new List<Tensor>();
     
    9770
    9871
    99 
    10072    private Tensor ConvertNode(ISymbolicExpressionTreeNode node) {
    10173      if (node.Symbol is Constant) {
    102         var value = ((ConstantTreeNode)node).Value;
    103         //initialConstants.Add(value);
    104 #if EXPLICIT_SHAPE
    105         //var var = (RefVariable)tf.VariableV1(value, name: $"c_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 });
     74        var value = (float)((ConstantTreeNode)node).Value;
    10675        var value_arr = np.array(value).reshape(1, 1);
    107         var var = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: tf.float64);
    108 #endif
    109         //var var = tf.Variable(value, name: $"c_{variables.Count}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/);
     76        var var = tf.Variable(value_arr, name: $"c_{variables.Count}", dtype: DataType);
    11077        variables.Add(var);
    11178        return var;
     
    11885        //var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
    11986        //var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
    120 #if EXPLICIT_SHAPE
    121         var par = tf.placeholder(tf.float64, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
    122 #endif
     87        var par = tf.placeholder(DataType, new TensorShape(numRows, variableLengths[varNode.VariableName]), name: varNode.VariableName);
    12388        parameters.Add(par, varNode.VariableName);
    12489
    12590        if (makeVariableWeightsVariable) {
    126           //initialConstants.Add(varNode.Weight);
    127 #if EXPLICIT_SHAPE
    128           //var w = (RefVariable)tf.VariableV1(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 });
    129           var w_arr = np.array(varNode.Weight).reshape(1, 1);
    130           var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: tf.float64);
    131 #endif
    132           //var w = tf.Variable(varNode.Weight, name: $"w_{varNode.VariableName}_{variables.Count}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/);
     91          var w_arr = np.array((float)varNode.Weight).reshape(1, 1);
     92          var w = tf.Variable(w_arr, name: $"w_{varNode.VariableName}", dtype: DataType);
    13393          variables.Add(w);
    13494          return w * par;
     
    143103      //  foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    144104      //    //var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    145       //    var par = tf.placeholder(tf.float64, new TensorShape(numRows, 1), name: factorVarNode.VariableName);
     105      //    var par = tf.placeholder(DataType, new TensorShape(numRows, 1), name: factorVarNode.VariableName);
    146106      //    parameters.Add(par, factorVarNode.VariableName);
    147107
    148108      //    var value = factorVarNode.GetValue(variableValue);
    149109      //    //initialConstants.Add(value);
    150       //    var wVar = (RefVariable)tf.VariableV1(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: tf.float64, shape: new[] { 1, 1 });
     110      //    var wVar = (RefVariable)tf.VariableV1(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}", dtype: DataType, shape: new[] { 1, 1 });
    151111      //    //var wVar = tf.Variable(value, name: $"f_{factorVarNode.VariableName}_{variables.Count}"/*, shape: new[] { 1, 1 }*/);
    152112      //    variables.Add(wVar);
     
    159119
    160120      if (node.Symbol is Addition) {
    161         var terms = new List<Tensor>();
    162         foreach (var subTree in node.Subtrees) {
    163           terms.Add(ConvertNode(subTree));
    164         }
    165 
     121        var terms = node.Subtrees.Select(ConvertNode).ToList();
    166122        return terms.Aggregate((a, b) => a + b);
    167123      }
    168124
    169125      if (node.Symbol is Subtraction) {
    170         var terms = new List<Tensor>();
    171         for (int i = 0; i < node.SubtreeCount; i++) {
    172           var t = ConvertNode(node.GetSubtree(i));
    173           if (i > 0) t = -t;
    174           terms.Add(t);
    175         }
    176 
     126        var terms = node.Subtrees.Select(ConvertNode).ToList();
    177127        if (terms.Count == 1) return -terms[0];
    178         else return terms.Aggregate((a, b) => a + b);
     128        return terms.Aggregate((a, b) => a - b);
    179129      }
    180130
    181131      if (node.Symbol is Multiplication) {
    182         var terms = new List<Tensor>();
    183         foreach (var subTree in node.Subtrees) {
    184           terms.Add(ConvertNode(subTree));
    185         }
    186 
    187         if (terms.Count == 1) return terms[0];
    188         else return terms.Aggregate((a, b) => a * b);
     132        var terms = node.Subtrees.Select(ConvertNode).ToList();
     133        return terms.Aggregate((a, b) => a * b);
    189134      }
    190135
    191136      if (node.Symbol is Division) {
    192         var terms = new List<Tensor>();
    193         foreach (var subTree in node.Subtrees) {
    194           terms.Add(ConvertNode(subTree));
    195         }
    196 
    197         if (terms.Count == 1) return 1.0 / terms[0];
    198         else return terms.Aggregate((a, b) => a * (1.0 / b));
     137        var terms = node.Subtrees.Select(ConvertNode).ToList();
     138        if (terms.Count == 1) return 1.0f / terms[0];
     139        return terms.Aggregate((a, b) => a / b);
    199140      }
    200141
     
    207148        var x1 = ConvertNode(node.GetSubtree(0));
    208149        var x2 = ConvertNode(node.GetSubtree(1));
    209         return x1 / tf.pow(1 + x2 * x2, 0.5);
     150        return x1 / tf.pow(1.0f + x2 * x2, 0.5f);
    210151      }
    211152
     
    233174      if (node.Symbol is Cube) {
    234175        return math_ops.pow(
    235           ConvertNode(node.GetSubtree(0)), 3.0);
     176          ConvertNode(node.GetSubtree(0)), 3.0f);
    236177      }
    237178
    238179      if (node.Symbol is CubeRoot) {
    239180        return math_ops.pow(
    240           ConvertNode(node.GetSubtree(0)), 1.0 / 3.0);
     181          ConvertNode(node.GetSubtree(0)), 1.0f / 3.0f);
    241182        // TODO
    242183        // f: x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
     
    281222
    282223      if (node.Symbol is StartSymbol) {
     224        Tensor prediction;
    283225        if (addLinearScalingTerms) {
    284226          // scaling variables α, β are given at the beginning of the parameter vector
    285 #if EXPLICIT_SHAPE
    286           //var alpha = (RefVariable)tf.VariableV1(1.0, name: $"alpha_{1.0}", dtype: tf.float64, shape: new[] { 1, 1 });
    287           //var beta = (RefVariable)tf.VariableV1(0.0, name: $"beta_{0.0}", dtype: tf.float64, shape: new[] { 1, 1 });
    288 
    289           var alpha_arr = np.array(1.0).reshape(1, 1);
    290           var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: tf.float64);
    291           var beta_arr = np.array(0.0).reshape(1, 1);
    292           var beta = tf.Variable(beta_arr, name: "beta", dtype: tf.float64);
    293 #endif
    294           //var alpha = tf.Variable(1.0, name: $"alpha_{1.0}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/);
    295           //var beta = tf.Variable(0.0, name: $"beta_{0.0}", dtype: tf.float64/*, shape: new[] { 1, 1 }*/);
     227          var alpha_arr = np.array(1.0f).reshape(1, 1);
     228          var alpha = tf.Variable(alpha_arr, name: "alpha", dtype: DataType);
     229          var beta_arr = np.array(0.0f).reshape(1, 1);
     230          var beta = tf.Variable(beta_arr, name: "beta", dtype: DataType);
    296231          variables.Add(alpha);
    297232          variables.Add(beta);
    298233          var t = ConvertNode(node.GetSubtree(0));
    299           return t * alpha + beta;
    300         } else return ConvertNode(node.GetSubtree(0));
     234          prediction = t * alpha + beta;
     235        } else {
     236          prediction = ConvertNode(node.GetSubtree(0));
     237        }
     238
     239        return tf.reduce_sum(prediction, axis: new[] { 1 });
    301240      }
    302241
Note: See TracChangeset for help on using the changeset viewer.