Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs

Last change on this file was 18220, checked in by gkronber, 3 years ago

#3136: reintegrated structure-template GP branch into trunk

File size: 16.3 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[14950]25using System.Runtime.Serialization;
[14843]26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
30  public class TreeToAutoDiffTermConverter {
31    public delegate double ParametricFunction(double[] vars, double[] @params);
[14950]32
[14843]33    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        var other = obj as DataForVariable;
49        if (other == null) return false;
50        return other.variableName.Equals(this.variableName) &&
51               other.variableValue.Equals(this.variableValue) &&
52               other.lag == this.lag;
53      }
54
55      public override int GetHashCode() {
56        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
57      }
58    }
59    #endregion
60
61    #region derivations of functions
62    // create function factory for arctangent
63    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
64      eval: Math.Atan,
65      diff: x => 1 / (1 + x * x));
[14950]66
[14843]67    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
68      eval: Math.Sin,
69      diff: Math.Cos);
[14950]70
[14843]71    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
[14950]72      eval: Math.Cos,
73      diff: x => -Math.Sin(x));
74
[14843]75    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
76      eval: Math.Tan,
77      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
[16656]78    private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
79      eval: Math.Tanh,
80      diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));
[14843]81    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
82      eval: alglib.errorfunction,
83      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
[14950]84
[14843]85    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
86      eval: alglib.normaldistribution,
87      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
88
[16356]89    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
90      eval: Math.Abs,
91      diff: x => Math.Sign(x)
92      );
93
[16905]94    private static readonly Func<Term, UnaryFunc> cbrt = UnaryFunc.Factory(
95      eval: x => x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
96      diff: x => { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
97      );
98
99
100
[14843]101    #endregion
102
[15447]103    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
[18132]104      out List<DataForVariable> parameters, out double[] initialParamValues,
[14843]105      out ParametricFunction func,
106      out ParametricFunctionGradient func_grad) {
107
[18220]108      return TryConvertToAutoDiff(tree, makeVariableWeightsVariable, addLinearScalingTerms, Enumerable.Empty<ISymbolicExpressionTreeNode>(),
109        out parameters, out initialParamValues, out func, out func_grad);
110    }
111
112    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms, IEnumerable<ISymbolicExpressionTreeNode> excludedNodes,
113      out List<DataForVariable> parameters, out double[] initialParamValues,
114      out ParametricFunction func,
115      out ParametricFunctionGradient func_grad) {
116
[14843]117      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
[18220]118      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms, excludedNodes);
[14843]119      AutoDiff.Term term;
[14950]120      try {
121        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
[14843]122        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
[14950]123        var compiledTerm = term.Compile(transformator.variables.ToArray(),
124          parameterEntries.Select(kvp => kvp.Value).ToArray());
[14843]125        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
[18132]126        initialParamValues = transformator.initialParamValues.ToArray();
[14843]127        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
128        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
[14950]129        return true;
130      } catch (ConversionException) {
[14843]131        func = null;
132        func_grad = null;
133        parameters = null;
[18132]134        initialParamValues = null;
[14843]135      }
[14950]136      return false;
[14843]137    }
138
139    // state for recursive transformation of trees
[18132]140    private readonly List<double> initialParamValues;
[14843]141    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
142    private readonly List<AutoDiff.Variable> variables;
143    private readonly bool makeVariableWeightsVariable;
[15447]144    private readonly bool addLinearScalingTerms;
[18220]145    private readonly HashSet<ISymbolicExpressionTreeNode> excludedNodes;
[14843]146
[18220]147    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms, IEnumerable<ISymbolicExpressionTreeNode> excludedNodes) {
[14843]148      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
[15447]149      this.addLinearScalingTerms = addLinearScalingTerms;
[18220]150      this.excludedNodes = new HashSet<ISymbolicExpressionTreeNode>(excludedNodes);
151
[18132]152      this.initialParamValues = new List<double>();
[14843]153      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
154      this.variables = new List<AutoDiff.Variable>();
155    }
156
[14950]157    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
[18132]158      if (node.Symbol is Number) {
159        initialParamValues.Add(((NumberTreeNode)node).Value);
[14843]160        var var = new AutoDiff.Variable();
161        variables.Add(var);
[14950]162        return var;
[14843]163      }
[18132]164      if (node.Symbol is Constant) {
165        // constants are fixed in autodiff
166        return (node as ConstantTreeNode).Value;
167      }
[14843]168      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
169        var varNode = node as VariableTreeNodeBase;
170        var factorVarNode = node as BinaryFactorVariableTreeNode;
171        // factor variable values are only 0 or 1 and set in x accordingly
172        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
173        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
174
[18220]175        if (makeVariableWeightsVariable && !excludedNodes.Contains(node)) {
[18132]176          initialParamValues.Add(varNode.Weight);
[14843]177          var w = new AutoDiff.Variable();
178          variables.Add(w);
[14950]179          return AutoDiff.TermBuilder.Product(w, par);
[14843]180        } else {
[14950]181          return varNode.Weight * par;
[14843]182        }
183      }
184      if (node.Symbol is FactorVariable) {
185        var factorVarNode = node as FactorVariableTreeNode;
186        var products = new List<Term>();
187        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
188          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
189
[18220]190          if (makeVariableWeightsVariable && !excludedNodes.Contains(node)) {
191            initialParamValues.Add(factorVarNode.GetValue(variableValue));
192            var wVar = new AutoDiff.Variable();
193            variables.Add(wVar);
[14843]194
[18220]195            products.Add(AutoDiff.TermBuilder.Product(wVar, par));
196          } else {
197            var weight = factorVarNode.GetValue(variableValue);
198            products.Add(weight * par);
199          }
200
[14843]201        }
[14950]202        return AutoDiff.TermBuilder.Sum(products);
[14843]203      }
204      if (node.Symbol is LaggedVariable) {
205        var varNode = node as LaggedVariableTreeNode;
206        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
207
[18220]208        if (makeVariableWeightsVariable && !excludedNodes.Contains(node)) {
[18132]209          initialParamValues.Add(varNode.Weight);
[14843]210          var w = new AutoDiff.Variable();
211          variables.Add(w);
[14950]212          return AutoDiff.TermBuilder.Product(w, par);
[14843]213        } else {
[14950]214          return varNode.Weight * par;
[14843]215        }
216      }
217      if (node.Symbol is Addition) {
218        List<AutoDiff.Term> terms = new List<Term>();
219        foreach (var subTree in node.Subtrees) {
[14950]220          terms.Add(ConvertToAutoDiff(subTree));
[14843]221        }
[14950]222        return AutoDiff.TermBuilder.Sum(terms);
[14843]223      }
224      if (node.Symbol is Subtraction) {
225        List<AutoDiff.Term> terms = new List<Term>();
226        for (int i = 0; i < node.SubtreeCount; i++) {
[14950]227          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
[14843]228          if (i > 0) t = -t;
229          terms.Add(t);
230        }
[14950]231        if (terms.Count == 1) return -terms[0];
232        else return AutoDiff.TermBuilder.Sum(terms);
[14843]233      }
234      if (node.Symbol is Multiplication) {
235        List<AutoDiff.Term> terms = new List<Term>();
236        foreach (var subTree in node.Subtrees) {
[14950]237          terms.Add(ConvertToAutoDiff(subTree));
[14843]238        }
[14950]239        if (terms.Count == 1) return terms[0];
240        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
[14843]241      }
242      if (node.Symbol is Division) {
243        List<AutoDiff.Term> terms = new List<Term>();
244        foreach (var subTree in node.Subtrees) {
[14950]245          terms.Add(ConvertToAutoDiff(subTree));
[14843]246        }
[14950]247        if (terms.Count == 1) return 1.0 / terms[0];
248        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
[14843]249      }
[16356]250      if (node.Symbol is Absolute) {
251        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
252        return abs(x1);
253      }
[16360]254      if (node.Symbol is AnalyticQuotient) {
[16356]255        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
256        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
257        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
258      }
[14843]259      if (node.Symbol is Logarithm) {
[14950]260        return AutoDiff.TermBuilder.Log(
261          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]262      }
263      if (node.Symbol is Exponential) {
[14950]264        return AutoDiff.TermBuilder.Exp(
265          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]266      }
267      if (node.Symbol is Square) {
[14950]268        return AutoDiff.TermBuilder.Power(
269          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
[14843]270      }
271      if (node.Symbol is SquareRoot) {
[14950]272        return AutoDiff.TermBuilder.Power(
273          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
[14843]274      }
[16356]275      if (node.Symbol is Cube) {
276        return AutoDiff.TermBuilder.Power(
277          ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
278      }
279      if (node.Symbol is CubeRoot) {
[16905]280        return cbrt(ConvertToAutoDiff(node.GetSubtree(0)));
[16356]281      }
[17817]282      if (node.Symbol is Power) {
[18132]283        var powerNode = node.GetSubtree(1) as INumericTreeNode;
[17817]284        if (powerNode == null)
[18132]285          throw new NotSupportedException("Only numeric powers are allowed in parameter optimization. Try to use exp() and log() instead of the power symbol.");
[17817]286        var intPower = Math.Truncate(powerNode.Value);
287        if (intPower != powerNode.Value)
288          throw new NotSupportedException("Only integer powers are allowed in parameter optimization. Try to use exp() and log() instead of the power symbol.");
289        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), intPower);
290      }
[14843]291      if (node.Symbol is Sine) {
[14950]292        return sin(
293          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]294      }
295      if (node.Symbol is Cosine) {
[14950]296        return cos(
297          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]298      }
299      if (node.Symbol is Tangent) {
[14950]300        return tan(
301          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]302      }
[16656]303      if (node.Symbol is HyperbolicTangent) {
304        return tanh(
305          ConvertToAutoDiff(node.GetSubtree(0)));
306      }
[14843]307      if (node.Symbol is Erf) {
[14950]308        return erf(
309          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]310      }
311      if (node.Symbol is Norm) {
[14950]312        return norm(
313          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]314      }
315      if (node.Symbol is StartSymbol) {
[15447]316        if (addLinearScalingTerms) {
[15481]317          // scaling variables α, β are given at the beginning of the parameter vector
[15447]318          var alpha = new AutoDiff.Variable();
319          var beta = new AutoDiff.Variable();
320          variables.Add(beta);
321          variables.Add(alpha);
[15481]322          var t = ConvertToAutoDiff(node.GetSubtree(0));
[15480]323          return t * alpha + beta;
[15447]324        } else return ConvertToAutoDiff(node.GetSubtree(0));
[14843]325      }
[18220]326      if (node.Symbol is SubFunctionSymbol) {
327        return ConvertToAutoDiff(node.GetSubtree(0));
328      }
[14950]329      throw new ConversionException();
[14843]330    }
331
332
333    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
334    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
335    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
336      string varName, string varValue = "", int lag = 0) {
337      var data = new DataForVariable(varName, varValue, lag);
338
[18132]339      if (!parameters.TryGetValue(data, out AutoDiff.Variable par)) {
[14843]340        // not found -> create new parameter and entries in names and values lists
341        par = new AutoDiff.Variable();
342        parameters.Add(data, par);
343      }
344      return par;
345    }
346
347    public static bool IsCompatible(ISymbolicExpressionTree tree) {
348      var containsUnknownSymbol = (
349        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
350        where
[14950]351          !(n.Symbol is Variable) &&
352          !(n.Symbol is BinaryFactorVariable) &&
353          !(n.Symbol is FactorVariable) &&
354          !(n.Symbol is LaggedVariable) &&
[18132]355          !(n.Symbol is Number) &&
[14950]356          !(n.Symbol is Constant) &&
357          !(n.Symbol is Addition) &&
358          !(n.Symbol is Subtraction) &&
359          !(n.Symbol is Multiplication) &&
360          !(n.Symbol is Division) &&
361          !(n.Symbol is Logarithm) &&
362          !(n.Symbol is Exponential) &&
363          !(n.Symbol is SquareRoot) &&
364          !(n.Symbol is Square) &&
365          !(n.Symbol is Sine) &&
366          !(n.Symbol is Cosine) &&
367          !(n.Symbol is Tangent) &&
[16656]368          !(n.Symbol is HyperbolicTangent) &&
[14950]369          !(n.Symbol is Erf) &&
370          !(n.Symbol is Norm) &&
[16356]371          !(n.Symbol is StartSymbol) &&
372          !(n.Symbol is Absolute) &&
[16360]373          !(n.Symbol is AnalyticQuotient) &&
[16356]374          !(n.Symbol is Cube) &&
[17817]375          !(n.Symbol is CubeRoot) &&
[18220]376          !(n.Symbol is Power) &&
377          !(n.Symbol is SubFunctionSymbol)
[14843]378        select n).Any();
379      return !containsUnknownSymbol;
380    }
[14950]381    #region exception class
382    [Serializable]
383    public class ConversionException : Exception {
384
385      public ConversionException() {
386      }
387
388      public ConversionException(string message) : base(message) {
389      }
390
391      public ConversionException(string message, Exception inner) : base(message, inner) {
392      }
393
394      protected ConversionException(
395        SerializationInfo info,
396        StreamingContext context) : base(info, context) {
397      }
398    }
399    #endregion
[14843]400  }
401}
Note: See TracBrowser for help on using the repository browser.