
source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs @ 17721

Last change on this file since 17721 was 17721, checked in by pfleck, 4 years ago

#3040 First draft of different-vector-length strategies (cut, fill, resample, cycle, ...)

File size: 11.5 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

//#define EXPORT_GRAPH
//#define LOG_CONSOLE
//#define LOG_FILE

using System;
using System.Collections;
using System.Collections.Generic;
#if LOG_CONSOLE
using System.Diagnostics;
#endif
#if LOG_FILE
using System.Globalization;
using System.IO;
#endif
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [StorableType("63944BF6-62E5-4BE4-974C-D30AD8770F99")]
  [Item("TensorFlowConstantOptimizationEvaluator", "")]
  public class TensorFlowConstantOptimizationEvaluator : SymbolicRegressionConstantOptimizationEvaluator {
    private const string MaximumIterationsName = "MaximumIterations";
    private const string LearningRateName = "LearningRate";

    private static readonly TF_DataType DataType = tf.float32;

    #region Parameter Properties
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
    }
    public IFixedValueParameter<DoubleValue> LearningRateParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
    }
    #endregion

    #region Properties
    public int ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value.Value; }
    }
    public double LearningRate {
      get { return LearningRateParameter.Value.Value; }
    }
    #endregion

    public TensorFlowConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constants of a symbolic expression tree (0 indicates another or the default stopping criterion).", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "The learning rate of the optimizer.", new DoubleValue(0.001)));
    }

    protected TensorFlowConstantOptimizationEvaluator(TensorFlowConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TensorFlowConstantOptimizationEvaluator(this, cloner);
    }

    [StorableConstructor]
    protected TensorFlowConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }

    protected override ISymbolicExpressionTree OptimizeConstants(
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows,
      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
      return OptimizeTree(tree,
        problemData, rows,
        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
        ConstantOptimizationIterations, LearningRate,
        cancellationToken);
    }

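    // Optimizes the numeric constants (and, if enabled, the variable weights) of a symbolic
    // expression tree: the tree is converted into a TensorFlow graph, the MSE between
    // prediction and target is minimized with a gradient-based optimizer, and the trained
    // values are written back into a clone of the tree. If the conversion fails, an
    // unmodified clone of the input tree is returned.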
    public static ISymbolicExpressionTree OptimizeTree(ISymbolicExpressionTree tree,
      IRegressionProblemData problemData, IEnumerable<int> rows,
      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
      CancellationToken cancellationToken = default(CancellationToken), IProgress<double> progress = null) {

      int numRows = rows.Count();
      var variableLengths = problemData.AllowedInputVariables.ToDictionary(
        var => var,
        var => {
          if (problemData.Dataset.VariableHasType<double>(var)) return 1;
          if (problemData.Dataset.VariableHasType<DoubleVector>(var)) return problemData.Dataset.GetDoubleVectorValue(var, 0).Count;
          throw new NotSupportedException($"Type of variable {var} is not supported.");
        });
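      // note: for vector variables the length is taken from row 0, which implicitly assumes
      // that all rows of a vector variable share the same length (handling differing lengths
      // via cut/fill/resample/cycle strategies is what this branch explores)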

      bool success = TreeToTensorConverter.TryConvert(tree,
        numRows, variableLengths,
        updateVariableWeights, applyLinearScaling,
        out Tensor prediction,
        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
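      // `parameters` maps each created input placeholder to the dataset variable it represents;
      // `variables` holds the trainable tensors (constants and, if enabled, variable weights)
      // in the order the converter created them (UpdateConstants below relies on that order)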

      if (!success)
        return (ISymbolicExpressionTree)tree.Clone();

      var target = tf.placeholder(DataType, new TensorShape(numRows), name: problemData.TargetVariable);
      // MSE
      var cost = tf.reduce_mean(tf.square(target - prediction));

      var optimizer = tf.train.AdamOptimizer((float)learningRate);
      //var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
      var optimizationOperation = optimizer.minimize(cost);
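      // Adam keeps per-variable moment estimates in additional slot variables; these are
      // initialized together with the tree constants by global_variables_initializer below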

#if EXPORT_GRAPH
      //https://github.com/SciSharp/TensorFlow.NET/wiki/Debugging
      tf.train.export_meta_graph(@"C:\temp\TFboard\graph.meta", as_text: false,
        clear_devices: true, clear_extraneous_savers: false, strip_default_attrs: true);
#endif

      // features as feed items
      var variablesFeed = new Hashtable();
      foreach (var kvp in parameters) {
        var variable = kvp.Key;
        var variableName = kvp.Value;
        if (problemData.Dataset.VariableHasType<double>(variableName)) {
          var data = problemData.Dataset.GetDoubleValues(variableName, rows).Select(x => (float)x).ToArray();
          variablesFeed.Add(variable, np.array(data).reshape(numRows, 1));
        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.Select(y => (float)y).ToArray()).ToArray();
          variablesFeed.Add(variable, np.array(data));
        } else
          throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
      }
      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).Select(x => (float)x).ToArray();
      variablesFeed.Add(target, np.array(targetData));
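      // feed shapes: scalar inputs are fed as (numRows, 1) matrices, vector inputs as
      // (numRows, vectorLength) matrices, and the target as a flat vector of length numRows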


      List<NDArray> constants;
      using (var session = tf.Session()) {

#if LOG_FILE
        var directoryName = $"C:\\temp\\TFboard\\logdir\\manual_{DateTime.Now.ToString("yyyyMMddHHmmss")}_{maxIterations}_{learningRate.ToString(CultureInfo.InvariantCulture)}";
        Directory.CreateDirectory(directoryName);
        var costsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Costs.csv")));
        var weightsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Weights.csv")));
        var gradientsWriter = new StreamWriter(File.Create(Path.Combine(directoryName, "Gradients.csv")));
#endif

#if LOG_CONSOLE || LOG_FILE
        var gradients = optimizer.compute_gradients(cost);
#endif

        session.run(tf.global_variables_initializer());

        progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));


#if LOG_CONSOLE
        Trace.WriteLine("Costs:");
        Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");

        Trace.WriteLine("Weights:");
        foreach (var v in variables) {
          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
        }

        Trace.WriteLine("Gradients:");
        foreach (var t in gradients) {
          Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
        }
#endif

#if LOG_FILE
        costsWriter.WriteLine("MSE");
        costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));

        weightsWriter.WriteLine(string.Join(";", variables.Select(v => v.name)));
        weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));

        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => t.Item2.name)));
        gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
#endif

        for (int i = 0; i < maxIterations; i++) {
          if (cancellationToken.IsCancellationRequested)
            break;

          session.run(optimizationOperation, variablesFeed);

          progress?.Report(session.run(cost, variablesFeed)[0].GetValue<float>(0));

#if LOG_CONSOLE
          Trace.WriteLine("Costs:");
          Trace.WriteLine($"MSE: {session.run(cost, variablesFeed)[0].ToString(true)}");

          Trace.WriteLine("Weights:");
          foreach (var v in variables) {
            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
          }

          Trace.WriteLine("Gradients:");
          foreach (var t in gradients) {
            Trace.WriteLine($"{t.Item2.name}: {session.run(t.Item1, variablesFeed)[0].ToString(true)}");
          }
#endif

#if LOG_FILE
          costsWriter.WriteLine(session.run(cost, variablesFeed)[0].GetValue<float>(0).ToString(CultureInfo.InvariantCulture));
          weightsWriter.WriteLine(string.Join(";", variables.Select(v => session.run(v).GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
          gradientsWriter.WriteLine(string.Join(";", gradients.Select(t => session.run(t.Item1, variablesFeed)[0].GetValue<float>(0, 0).ToString(CultureInfo.InvariantCulture))));
#endif
        }

#if LOG_FILE
        costsWriter.Close();
        weightsWriter.Close();
        gradientsWriter.Close();
#endif
        constants = variables.Select(v => session.run(v)).ToList();
      }

      if (applyLinearScaling)
        constants = constants.Skip(2).ToList();
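      // the Skip(2) above drops the two linear-scaling parameters (offset and slope), which the
      // converter apparently emits as the first trainable variables; they only exist in the
      // TensorFlow graph and are not written back into the tree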
      var newTree = (ISymbolicExpressionTree)tree.Clone();
      UpdateConstants(newTree, constants, updateVariableWeights);

      return newTree;
    }

    private static void UpdateConstants(ISymbolicExpressionTree tree, IList<NDArray> constants, bool updateVariableWeights) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        if (node is ConstantTreeNode constantTreeNode)
          constantTreeNode.Value = constants[i++].GetValue<float>(0, 0);
        else if (node is VariableTreeNodeBase variableTreeNodeBase && updateVariableWeights)
          variableTreeNodeBase.Weight = constants[i++].GetValue<float>(0, 0);
        else if (node is FactorVariableTreeNode factorVarTreeNode && updateVariableWeights) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++].GetValue<float>(0, 0);
        }
      }
    }
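    // Writing the values back relies on the terminal nodes being visited in the same (prefix)
    // order in which TreeToTensorConverter created the corresponding trainable variables;
    // each optimized value is read from position (0, 0) of its NDArray.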

    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return TreeToTensorConverter.IsCompatible(tree);
    }
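
    // Usage sketch (hypothetical caller; `tree`, `problemData` and `rows` are assumed to be
    // provided by the surrounding algorithm):
    //   if (TensorFlowConstantOptimizationEvaluator.CanOptimizeConstants(tree)) {
    //     var optimized = TensorFlowConstantOptimizationEvaluator.OptimizeTree(
    //       tree, problemData, rows,
    //       applyLinearScaling: true, updateVariableWeights: true,
    //       maxIterations: 10, learningRate: 0.001);
    //   }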
  }
}