source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs @ 17476

Last change on this file since 17476 was 17476, checked in by pfleck, 4 years ago

#3040 Worked on TF-based constant optimization.

File size: 7.4 KB
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
using NumSharp;
using Tensorflow;
using static Tensorflow.Binding;
using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [StorableType("63944BF6-62E5-4BE4-974C-D30AD8770F99")]
  [Item("TensorFlowConstantOptimizationEvaluator", "Optimizes the numeric constants of a symbolic regression model using TensorFlow-based gradient descent.")]
  public class TensorFlowConstantOptimizationEvaluator : SymbolicRegressionConstantOptimizationEvaluator {
    private const string MaximumIterationsName = "MaximumIterations";
    private const string LearningRateName = "LearningRate";

    #region Parameter Properties
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
    }
    public IFixedValueParameter<DoubleValue> LearningRateParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
    }
    #endregion

    #region Properties
    public int ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value.Value; }
    }
    public double LearningRate {
      get { return LearningRateParameter.Value.Value; }
    }
    #endregion

    public TensorFlowConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constants of a symbolic expression tree (0 indicates another or the default stopping criterion).", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "The learning rate of the gradient descent optimizer.", new DoubleValue(0.01)));
    }

    protected TensorFlowConstantOptimizationEvaluator(TensorFlowConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new TensorFlowConstantOptimizationEvaluator(this, cloner);
    }

    [StorableConstructor]
    protected TensorFlowConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }

    protected override ISymbolicExpressionTree OptimizeConstants(
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows,
      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
      return OptimizeTree(tree,
        problemData, rows,
        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
        ConstantOptimizationIterations, LearningRate,
        cancellationToken, counter);
    }

    public static ISymbolicExpressionTree OptimizeTree(
      ISymbolicExpressionTree tree,
      IRegressionProblemData problemData, IEnumerable<int> rows,
      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {

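      // Collect all variables referenced by the tree that are vector-valued in the dataset.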
      var vectorVariables = tree.IterateNodesBreadth()
        .OfType<VariableTreeNodeBase>()
        .Where(node => problemData.Dataset.VariableHasType<DoubleVector>(node.VariableName))
        .Select(node => node.VariableName);

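      // The common vector length is taken from the first value of the first vector variable;
      // all vector values are assumed to have this length.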
      int? vectorLength = null;
      if (vectorVariables.Any()) {
        vectorLength = vectorVariables.Select(var => problemData.Dataset.GetDoubleVectorValues(var, rows)).First().First().Count;
      }
      int numRows = rows.Count();

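      // Translate the symbolic expression tree into a TensorFlow graph: dataset variables are
      // returned as placeholder tensors in 'parameters', the tree's constants/weights as trainable
      // variables in 'variables', and the model output as the 'prediction' tensor.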
      bool success = TreeToTensorConverter.TryConvert(tree,
        numRows, vectorLength,
        updateVariableWeights, applyLinearScaling,
        out Tensor prediction,
        out Dictionary<Tensor, string> parameters, out List<Tensor> variables/*, out double[] initialConstants*/);
      if (!success)
        return (ISymbolicExpressionTree)tree.Clone();

      var target = tf.placeholder(tf.float64, name: problemData.TargetVariable);
      // mean squared error (with the conventional 1/2 factor)
      var costs = tf.reduce_sum(tf.square(prediction - target)) / (2.0 * numRows);
      var optimizer = tf.train.GradientDescentOptimizer((float)learningRate).minimize(costs);

      // features as feed items
      var variablesFeed = new Hashtable();
      foreach (var kvp in parameters) {
        var variable = kvp.Key;
        var variableName = kvp.Value;
        if (problemData.Dataset.VariableHasType<double>(variableName)) {
          var data = problemData.Dataset.GetDoubleValues(variableName, rows).ToArray();
          if (vectorLength.HasValue) {
            var vectorData = new double[numRows][];
            for (int i = 0; i < numRows; i++)
              vectorData[i] = Enumerable.Repeat(data[i], vectorLength.Value).ToArray();
            variablesFeed.Add(variable, np.array(vectorData));
          } else
            variablesFeed.Add(variable, np.array(data, copy: false));
          //} else if (problemData.Dataset.VariableHasType<string>(variableName)) {
          //  variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
        } else if (problemData.Dataset.VariableHasType<DoubleVector>(variableName)) {
          var data = problemData.Dataset.GetDoubleVectorValues(variableName, rows).Select(x => x.ToArray()).ToArray();
          variablesFeed.Add(variable, np.array(data));
        } else
          throw new NotSupportedException($"The type of variable '{variableName}' is not supported.");
      }
      var targetData = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
      variablesFeed.Add(target, np.array(targetData, copy: false));

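      // Run plain gradient descent for maxIterations steps; each session.run(optimizer, variablesFeed)
      // updates the trainable variables (the tree's constants/weights) in place.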
      using (var session = tf.Session()) {
        session.run(tf.global_variables_initializer());

        Trace.WriteLine("Weights:");
        foreach (var v in variables)
          Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");

        for (int i = 0; i < maxIterations; i++) {

          //optimizer.minimize(costs);
          session.run(optimizer, variablesFeed);

          Trace.WriteLine("Weights:");
          foreach (var v in variables)
            Trace.WriteLine($"{v.name}: {session.run(v).ToString(true)}");
        }
      }

      // Writing the optimized values back into a copy of the tree is not implemented yet;
      // this revision returns null.
      return null;
    }

    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return TreeToTensorConverter.IsCompatible(tree);
    }
  }
}