Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2866_SymRegHyperbolicFunctions/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 17226

Last change on this file since 17226 was 16654, checked in by gkronber, 6 years ago

#2866: merged r16364:16653 from trunk to branch to prepare for trunk reintegration (resolving conflicts in the project file)

File size: 14.0 KB
RevLine 
[14843]1#region License Information
2/* HeuristicLab
[16654]3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[14843]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[14950]25using System.Runtime.Serialization;
[14843]26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
30  public class TreeToAutoDiffTermConverter {
31    public delegate double ParametricFunction(double[] vars, double[] @params);
[14950]32
[14843]33    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        var other = obj as DataForVariable;
49        if (other == null) return false;
50        return other.variableName.Equals(this.variableName) &&
51               other.variableValue.Equals(this.variableValue) &&
52               other.lag == this.lag;
53      }
54
55      public override int GetHashCode() {
56        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
57      }
58    }
59    #endregion
60
61    #region derivations of functions
62    // create function factory for arctangent
63    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
64      eval: Math.Atan,
65      diff: x => 1 / (1 + x * x));
[14950]66
[14843]67    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
68      eval: Math.Sin,
69      diff: Math.Cos);
[14950]70
[14843]71    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
[14950]72      eval: Math.Cos,
73      diff: x => -Math.Sin(x));
74
[14843]75    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
76      eval: Math.Tan,
77      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
[16531]78    private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
79      eval: Math.Tanh,
80      diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));
[14843]81    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
82      eval: alglib.errorfunction,
83      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
[14950]84
[14843]85    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
86      eval: alglib.normaldistribution,
87      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
88
[16356]89    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
90      eval: Math.Abs,
91      diff: x => Math.Sign(x)
92      );
93
[14843]94    #endregion
95
[15447]96    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
[14843]97      out List<DataForVariable> parameters, out double[] initialConstants,
98      out ParametricFunction func,
99      out ParametricFunctionGradient func_grad) {
100
101      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
[15447]102      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms);
[14843]103      AutoDiff.Term term;
[14950]104      try {
105        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
[14843]106        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
[14950]107        var compiledTerm = term.Compile(transformator.variables.ToArray(),
108          parameterEntries.Select(kvp => kvp.Value).ToArray());
[14843]109        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
110        initialConstants = transformator.initialConstants.ToArray();
111        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
112        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
[14950]113        return true;
114      } catch (ConversionException) {
[14843]115        func = null;
116        func_grad = null;
117        parameters = null;
118        initialConstants = null;
119      }
[14950]120      return false;
[14843]121    }
122
123    // state for recursive transformation of trees
[14950]124    private readonly
125    List<double> initialConstants;
[14843]126    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
127    private readonly List<AutoDiff.Variable> variables;
128    private readonly bool makeVariableWeightsVariable;
[15447]129    private readonly bool addLinearScalingTerms;
[14843]130
[15447]131    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
[14843]132      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
[15447]133      this.addLinearScalingTerms = addLinearScalingTerms;
[14843]134      this.initialConstants = new List<double>();
135      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
136      this.variables = new List<AutoDiff.Variable>();
137    }
138
[14950]139    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
[14843]140      if (node.Symbol is Constant) {
141        initialConstants.Add(((ConstantTreeNode)node).Value);
142        var var = new AutoDiff.Variable();
143        variables.Add(var);
[14950]144        return var;
[14843]145      }
146      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
147        var varNode = node as VariableTreeNodeBase;
148        var factorVarNode = node as BinaryFactorVariableTreeNode;
149        // factor variable values are only 0 or 1 and set in x accordingly
150        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
151        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
152
153        if (makeVariableWeightsVariable) {
154          initialConstants.Add(varNode.Weight);
155          var w = new AutoDiff.Variable();
156          variables.Add(w);
[14950]157          return AutoDiff.TermBuilder.Product(w, par);
[14843]158        } else {
[14950]159          return varNode.Weight * par;
[14843]160        }
161      }
162      if (node.Symbol is FactorVariable) {
163        var factorVarNode = node as FactorVariableTreeNode;
164        var products = new List<Term>();
165        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
166          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
167
168          initialConstants.Add(factorVarNode.GetValue(variableValue));
169          var wVar = new AutoDiff.Variable();
170          variables.Add(wVar);
171
172          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
173        }
[14950]174        return AutoDiff.TermBuilder.Sum(products);
[14843]175      }
176      if (node.Symbol is LaggedVariable) {
177        var varNode = node as LaggedVariableTreeNode;
178        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
179
180        if (makeVariableWeightsVariable) {
181          initialConstants.Add(varNode.Weight);
182          var w = new AutoDiff.Variable();
183          variables.Add(w);
[14950]184          return AutoDiff.TermBuilder.Product(w, par);
[14843]185        } else {
[14950]186          return varNode.Weight * par;
[14843]187        }
188      }
189      if (node.Symbol is Addition) {
190        List<AutoDiff.Term> terms = new List<Term>();
191        foreach (var subTree in node.Subtrees) {
[14950]192          terms.Add(ConvertToAutoDiff(subTree));
[14843]193        }
[14950]194        return AutoDiff.TermBuilder.Sum(terms);
[14843]195      }
196      if (node.Symbol is Subtraction) {
197        List<AutoDiff.Term> terms = new List<Term>();
198        for (int i = 0; i < node.SubtreeCount; i++) {
[14950]199          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
[14843]200          if (i > 0) t = -t;
201          terms.Add(t);
202        }
[14950]203        if (terms.Count == 1) return -terms[0];
204        else return AutoDiff.TermBuilder.Sum(terms);
[14843]205      }
206      if (node.Symbol is Multiplication) {
207        List<AutoDiff.Term> terms = new List<Term>();
208        foreach (var subTree in node.Subtrees) {
[14950]209          terms.Add(ConvertToAutoDiff(subTree));
[14843]210        }
[14950]211        if (terms.Count == 1) return terms[0];
212        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
[14843]213      }
214      if (node.Symbol is Division) {
215        List<AutoDiff.Term> terms = new List<Term>();
216        foreach (var subTree in node.Subtrees) {
[14950]217          terms.Add(ConvertToAutoDiff(subTree));
[14843]218        }
[14950]219        if (terms.Count == 1) return 1.0 / terms[0];
220        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
[14843]221      }
[16356]222      if (node.Symbol is Absolute) {
223        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
224        return abs(x1);
225      }
[16360]226      if (node.Symbol is AnalyticQuotient) {
[16356]227        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
228        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
229        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
230      }
[14843]231      if (node.Symbol is Logarithm) {
[14950]232        return AutoDiff.TermBuilder.Log(
233          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]234      }
235      if (node.Symbol is Exponential) {
[14950]236        return AutoDiff.TermBuilder.Exp(
237          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]238      }
239      if (node.Symbol is Square) {
[14950]240        return AutoDiff.TermBuilder.Power(
241          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
[14843]242      }
243      if (node.Symbol is SquareRoot) {
[14950]244        return AutoDiff.TermBuilder.Power(
245          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
[14843]246      }
[16356]247      if (node.Symbol is Cube) {
248        return AutoDiff.TermBuilder.Power(
249          ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
250      }
251      if (node.Symbol is CubeRoot) {
252        return AutoDiff.TermBuilder.Power(
253          ConvertToAutoDiff(node.GetSubtree(0)), 1.0/3.0);
254      }
[14843]255      if (node.Symbol is Sine) {
[14950]256        return sin(
257          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]258      }
259      if (node.Symbol is Cosine) {
[14950]260        return cos(
261          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]262      }
263      if (node.Symbol is Tangent) {
[14950]264        return tan(
265          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]266      }
[16531]267      if (node.Symbol is HyperbolicTangent) {
268        return tanh(
269          ConvertToAutoDiff(node.GetSubtree(0)));
270      }
[14843]271      if (node.Symbol is Erf) {
[14950]272        return erf(
273          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]274      }
275      if (node.Symbol is Norm) {
[14950]276        return norm(
277          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]278      }
279      if (node.Symbol is StartSymbol) {
[15447]280        if (addLinearScalingTerms) {
[15481]281          // scaling variables α, β are given at the beginning of the parameter vector
[15447]282          var alpha = new AutoDiff.Variable();
283          var beta = new AutoDiff.Variable();
284          variables.Add(beta);
285          variables.Add(alpha);
[15481]286          var t = ConvertToAutoDiff(node.GetSubtree(0));
[15480]287          return t * alpha + beta;
[15447]288        } else return ConvertToAutoDiff(node.GetSubtree(0));
[14843]289      }
[14950]290      throw new ConversionException();
[14843]291    }
292
293
294    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
295    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
296    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
297      string varName, string varValue = "", int lag = 0) {
298      var data = new DataForVariable(varName, varValue, lag);
299
300      AutoDiff.Variable par = null;
301      if (!parameters.TryGetValue(data, out par)) {
302        // not found -> create new parameter and entries in names and values lists
303        par = new AutoDiff.Variable();
304        parameters.Add(data, par);
305      }
306      return par;
307    }
308
309    public static bool IsCompatible(ISymbolicExpressionTree tree) {
310      var containsUnknownSymbol = (
311        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
312        where
[14950]313          !(n.Symbol is Variable) &&
314          !(n.Symbol is BinaryFactorVariable) &&
315          !(n.Symbol is FactorVariable) &&
316          !(n.Symbol is LaggedVariable) &&
317          !(n.Symbol is Constant) &&
318          !(n.Symbol is Addition) &&
319          !(n.Symbol is Subtraction) &&
320          !(n.Symbol is Multiplication) &&
321          !(n.Symbol is Division) &&
322          !(n.Symbol is Logarithm) &&
323          !(n.Symbol is Exponential) &&
324          !(n.Symbol is SquareRoot) &&
325          !(n.Symbol is Square) &&
326          !(n.Symbol is Sine) &&
327          !(n.Symbol is Cosine) &&
328          !(n.Symbol is Tangent) &&
[16531]329          !(n.Symbol is HyperbolicTangent) &&
[14950]330          !(n.Symbol is Erf) &&
331          !(n.Symbol is Norm) &&
[16356]332          !(n.Symbol is StartSymbol) &&
333          !(n.Symbol is Absolute) &&
[16360]334          !(n.Symbol is AnalyticQuotient) &&
[16356]335          !(n.Symbol is Cube) &&
336          !(n.Symbol is CubeRoot)
[14843]337        select n).Any();
338      return !containsUnknownSymbol;
339    }
[14950]340    #region exception class
341    [Serializable]
342    public class ConversionException : Exception {
343
344      public ConversionException() {
345      }
346
347      public ConversionException(string message) : base(message) {
348      }
349
350      public ConversionException(string message, Exception inner) : base(message, inner) {
351      }
352
353      protected ConversionException(
354        SerializationInfo info,
355        StreamingContext context) : base(info, context) {
356      }
357    }
358    #endregion
[14843]359  }
360}
Note: See TracBrowser for help on using the repository browser.