Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/TensorFlowConstantOptimizationEvaluator.cs @ 17475

Last change on this file since 17475 was 17475, checked in by pfleck, 4 years ago

#3040 Updated HeuristicLab.Algorithms.DataAnalysis plugin and its dependencies to Framework 4.7.2 to avoid conflicting System.ValueTuple locations (mscorlib or NuGet).

File size: 6.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections;
24using System.Collections.Generic;
25using System.Linq;
26using System.Threading;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
31using HeuristicLab.Parameters;
32using HEAL.Attic;
33using Tensorflow;
34using static Tensorflow.Binding;
35using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
36
37namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
38  [StorableType("63944BF6-62E5-4BE4-974C-D30AD8770F99")]
39  [Item("TensorFlowConstantOptimizationEvaluator", "")]
40  public class TensorFlowConstantOptimizationEvaluator : SymbolicRegressionConstantOptimizationEvaluator {
41    private const string MaximumIterationsName = "MaximumIterations";
42    private const string LearningRateName = "LearningRate";
43
44    #region Parameter Properties
45    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
46      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumIterationsName]; }
47    }
48    public IFixedValueParameter<DoubleValue> LearningRateParameter {
49      get { return (IFixedValueParameter<DoubleValue>)Parameters[LearningRateName]; }
50    }
51    #endregion
52
53    #region Properties
54    public int ConstantOptimizationIterations {
55      get { return ConstantOptimizationIterationsParameter.Value.Value; }
56    }
57    public double LearningRate {
58      get { return LearningRateParameter.Value.Value; }
59    }
60    #endregion
61
62    public TensorFlowConstantOptimizationEvaluator()
63      : base() {
64      Parameters.Add(new FixedValueParameter<IntValue>(MaximumIterationsName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree(0 indicates other or default stopping criterion).", new IntValue(10)));
65      Parameters.Add(new FixedValueParameter<DoubleValue>(LearningRateName, "", new DoubleValue(0.01)));
66    }
67
68    protected TensorFlowConstantOptimizationEvaluator(TensorFlowConstantOptimizationEvaluator original, Cloner cloner)
69      : base(original, cloner) { }
70
71    public override IDeepCloneable Clone(Cloner cloner) {
72      return new TensorFlowConstantOptimizationEvaluator(this, cloner);
73    }
74
75    [StorableConstructor]
76    protected TensorFlowConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }
77
78    protected override ISymbolicExpressionTree OptimizeConstants(
79      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows,
80      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
81      return OptimizeTree(tree,
82        problemData, rows,
83        ApplyLinearScalingParameter.ActualValue.Value, UpdateVariableWeights,
84        ConstantOptimizationIterations, LearningRate,
85        cancellationToken, counter);
86    }
87
88    public static ISymbolicExpressionTree OptimizeTree(
89      ISymbolicExpressionTree tree,
90      IRegressionProblemData problemData, IEnumerable<int> rows,
91      bool applyLinearScaling, bool updateVariableWeights, int maxIterations, double learningRate,
92      CancellationToken cancellationToken = default(CancellationToken), EvaluationsCounter counter = null) {
93
94      bool success = TreeToTensorConverter.TryConvert(tree, updateVariableWeights, applyLinearScaling,
95        out Tensor prediction, out Dictionary<TreeToTensorConverter.DataForVariable, Tensor> variables/*, out double[] initialConstants*/);
96
97      var target = tf.placeholder(tf.float64, name: problemData.TargetVariable);
98      int samples = rows.Count();
99      // mse
100      var costs = tf.reduce_sum(tf.square(prediction - target)) / (2.0 * samples);
101      var optimizer = tf.train.GradientDescentOptimizer((float)learningRate);
102
103      // features as feed items
104      var variablesFeed = new Hashtable();
105      foreach (var kvp in variables) {
106        var variableName = kvp.Key.variableName;
107        var variable = kvp.Value;
108        if (problemData.Dataset.VariableHasType<double>(variableName))
109          variablesFeed.Add(variable, problemData.Dataset.GetDoubleValues(variableName, rows));
110        if (problemData.Dataset.VariableHasType<string>(variableName))
111          variablesFeed.Add(variable, problemData.Dataset.GetStringValues(variableName, rows));
112        if (problemData.Dataset.VariableHasType<DoubleVector>(variableName))
113          variablesFeed.Add(variable, problemData.Dataset.GetDoubleVectorValues(variableName, rows));
114        throw new NotSupportedException($"Type of the variable is not supported: {variableName}");
115      }
116      variablesFeed.Add(target, problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows));
117
118
119      using (var session = tf.Session()) {
120        for (int i = 0; i < maxIterations; i++) {
121          optimizer.minimize(costs);
122          var result = session.run(optimizer, variablesFeed);
123        }
124      }
125      optimizer.minimize(costs);
126
127      if (!success)
128        return (ISymbolicExpressionTree)tree.Clone();
129
130
131      return null;
132    }
133
134    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
135      return TreeToTensorConverter.IsCompatible(tree);
136    }
137  }
138}
Note: See TracBrowser for help on using the repository browser.