Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3136_Structural_GP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 18146

Last change on this file since 18146 was 18146, checked in by mkommend, 2 years ago

#3136: Merged trunk changes into branch.

File size: 15.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
30  public class TreeToAutoDiffTermConverter {
31    public delegate double ParametricFunction(double[] vars, double[] @params);
32
33    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        var other = obj as DataForVariable;
49        if (other == null) return false;
50        return other.variableName.Equals(this.variableName) &&
51               other.variableValue.Equals(this.variableValue) &&
52               other.lag == this.lag;
53      }
54
55      public override int GetHashCode() {
56        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
57      }
58    }
59    #endregion
60
61    #region derivations of functions
62    // create function factory for arctangent
63    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
64      eval: Math.Atan,
65      diff: x => 1 / (1 + x * x));
66
67    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
68      eval: Math.Sin,
69      diff: Math.Cos);
70
71    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
72      eval: Math.Cos,
73      diff: x => -Math.Sin(x));
74
75    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
76      eval: Math.Tan,
77      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
78    private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
79      eval: Math.Tanh,
80      diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));
81    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
82      eval: alglib.errorfunction,
83      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
84
85    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
86      eval: alglib.normaldistribution,
87      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
88
89    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
90      eval: Math.Abs,
91      diff: x => Math.Sign(x)
92      );
93
94    private static readonly Func<Term, UnaryFunc> cbrt = UnaryFunc.Factory(
95      eval: x => x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
96      diff: x => { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
97      );
98
99
100
101    #endregion
102
103    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
104      out List<DataForVariable> parameters, out double[] initialParamValues,
105      out ParametricFunction func,
106      out ParametricFunctionGradient func_grad) {
107
108      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
109      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms);
110      AutoDiff.Term term;
111      try {
112        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
113        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
114        var compiledTerm = term.Compile(transformator.variables.ToArray(),
115          parameterEntries.Select(kvp => kvp.Value).ToArray());
116        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
117        initialParamValues = transformator.initialParamValues.ToArray();
118        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
119        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
120        return true;
121      } catch (ConversionException) {
122        func = null;
123        func_grad = null;
124        parameters = null;
125        initialParamValues = null;
126      }
127      return false;
128    }
129
130    // state for recursive transformation of trees
131    private readonly List<double> initialParamValues;
132    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
133    private readonly List<AutoDiff.Variable> variables;
134    private readonly bool makeVariableWeightsVariable;
135    private readonly bool addLinearScalingTerms;
136
137    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
138      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
139      this.addLinearScalingTerms = addLinearScalingTerms;
140      this.initialParamValues = new List<double>();
141      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
142      this.variables = new List<AutoDiff.Variable>();
143    }
144
145    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
146      if (node.Symbol is Number) {
147        initialParamValues.Add(((NumberTreeNode)node).Value);
148        var var = new AutoDiff.Variable();
149        variables.Add(var);
150        return var;
151      }
152      if (node.Symbol is Constant) {
153        // constants are fixed in autodiff
154        return (node as ConstantTreeNode).Value;
155      }
156      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
157        var varNode = node as VariableTreeNodeBase;
158        var factorVarNode = node as BinaryFactorVariableTreeNode;
159        // factor variable values are only 0 or 1 and set in x accordingly
160        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
161        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
162
163        if (makeVariableWeightsVariable) {
164          initialParamValues.Add(varNode.Weight);
165          var w = new AutoDiff.Variable();
166          variables.Add(w);
167          return AutoDiff.TermBuilder.Product(w, par);
168        } else {
169          return varNode.Weight * par;
170        }
171      }
172      if (node.Symbol is FactorVariable) {
173        var factorVarNode = node as FactorVariableTreeNode;
174        var products = new List<Term>();
175        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
176          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
177
178          initialParamValues.Add(factorVarNode.GetValue(variableValue));
179          var wVar = new AutoDiff.Variable();
180          variables.Add(wVar);
181
182          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
183        }
184        return AutoDiff.TermBuilder.Sum(products);
185      }
186      if (node.Symbol is LaggedVariable) {
187        var varNode = node as LaggedVariableTreeNode;
188        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
189
190        if (makeVariableWeightsVariable) {
191          initialParamValues.Add(varNode.Weight);
192          var w = new AutoDiff.Variable();
193          variables.Add(w);
194          return AutoDiff.TermBuilder.Product(w, par);
195        } else {
196          return varNode.Weight * par;
197        }
198      }
199      if (node.Symbol is Addition) {
200        List<AutoDiff.Term> terms = new List<Term>();
201        foreach (var subTree in node.Subtrees) {
202          terms.Add(ConvertToAutoDiff(subTree));
203        }
204        return AutoDiff.TermBuilder.Sum(terms);
205      }
206      if (node.Symbol is Subtraction) {
207        List<AutoDiff.Term> terms = new List<Term>();
208        for (int i = 0; i < node.SubtreeCount; i++) {
209          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
210          if (i > 0) t = -t;
211          terms.Add(t);
212        }
213        if (terms.Count == 1) return -terms[0];
214        else return AutoDiff.TermBuilder.Sum(terms);
215      }
216      if (node.Symbol is Multiplication) {
217        List<AutoDiff.Term> terms = new List<Term>();
218        foreach (var subTree in node.Subtrees) {
219          terms.Add(ConvertToAutoDiff(subTree));
220        }
221        if (terms.Count == 1) return terms[0];
222        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
223      }
224      if (node.Symbol is Division) {
225        List<AutoDiff.Term> terms = new List<Term>();
226        foreach (var subTree in node.Subtrees) {
227          terms.Add(ConvertToAutoDiff(subTree));
228        }
229        if (terms.Count == 1) return 1.0 / terms[0];
230        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
231      }
232      if (node.Symbol is Absolute) {
233        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
234        return abs(x1);
235      }
236      if (node.Symbol is AnalyticQuotient) {
237        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
238        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
239        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
240      }
241      if (node.Symbol is Logarithm) {
242        return AutoDiff.TermBuilder.Log(
243          ConvertToAutoDiff(node.GetSubtree(0)));
244      }
245      if (node.Symbol is Exponential) {
246        return AutoDiff.TermBuilder.Exp(
247          ConvertToAutoDiff(node.GetSubtree(0)));
248      }
249      if (node.Symbol is Square) {
250        return AutoDiff.TermBuilder.Power(
251          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
252      }
253      if (node.Symbol is SquareRoot) {
254        return AutoDiff.TermBuilder.Power(
255          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
256      }
257      if (node.Symbol is Cube) {
258        return AutoDiff.TermBuilder.Power(
259          ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
260      }
261      if (node.Symbol is CubeRoot) {
262        return cbrt(ConvertToAutoDiff(node.GetSubtree(0)));
263      }
264      if (node.Symbol is Power) {
265        var powerNode = node.GetSubtree(1) as INumericTreeNode;
266        if (powerNode == null)
267          throw new NotSupportedException("Only numeric powers are allowed in parameter optimization. Try to use exp() and log() instead of the power symbol.");
268        var intPower = Math.Truncate(powerNode.Value);
269        if (intPower != powerNode.Value)
270          throw new NotSupportedException("Only integer powers are allowed in parameter optimization. Try to use exp() and log() instead of the power symbol.");
271        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), intPower);
272      }
273      if (node.Symbol is Sine) {
274        return sin(
275          ConvertToAutoDiff(node.GetSubtree(0)));
276      }
277      if (node.Symbol is Cosine) {
278        return cos(
279          ConvertToAutoDiff(node.GetSubtree(0)));
280      }
281      if (node.Symbol is Tangent) {
282        return tan(
283          ConvertToAutoDiff(node.GetSubtree(0)));
284      }
285      if (node.Symbol is HyperbolicTangent) {
286        return tanh(
287          ConvertToAutoDiff(node.GetSubtree(0)));
288      }
289      if (node.Symbol is Erf) {
290        return erf(
291          ConvertToAutoDiff(node.GetSubtree(0)));
292      }
293      if (node.Symbol is Norm) {
294        return norm(
295          ConvertToAutoDiff(node.GetSubtree(0)));
296      }
297      if (node.Symbol is StartSymbol) {
298        if (addLinearScalingTerms) {
299          // scaling variables α, β are given at the beginning of the parameter vector
300          var alpha = new AutoDiff.Variable();
301          var beta = new AutoDiff.Variable();
302          variables.Add(beta);
303          variables.Add(alpha);
304          var t = ConvertToAutoDiff(node.GetSubtree(0));
305          return t * alpha + beta;
306        } else return ConvertToAutoDiff(node.GetSubtree(0));
307      }
308      if (node.Symbol is SubFunctionSymbol) {
309        return ConvertToAutoDiff(node.GetSubtree(0));
310      }
311      throw new ConversionException();
312    }
313
314
315    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
316    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
317    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
318      string varName, string varValue = "", int lag = 0) {
319      var data = new DataForVariable(varName, varValue, lag);
320
321      if (!parameters.TryGetValue(data, out AutoDiff.Variable par)) {
322        // not found -> create new parameter and entries in names and values lists
323        par = new AutoDiff.Variable();
324        parameters.Add(data, par);
325      }
326      return par;
327    }
328
329    public static bool IsCompatible(ISymbolicExpressionTree tree) {
330      var containsUnknownSymbol = (
331        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
332        where
333          !(n.Symbol is Variable) &&
334          !(n.Symbol is BinaryFactorVariable) &&
335          !(n.Symbol is FactorVariable) &&
336          !(n.Symbol is LaggedVariable) &&
337          !(n.Symbol is Number) &&
338          !(n.Symbol is Constant) &&
339          !(n.Symbol is Addition) &&
340          !(n.Symbol is Subtraction) &&
341          !(n.Symbol is Multiplication) &&
342          !(n.Symbol is Division) &&
343          !(n.Symbol is Logarithm) &&
344          !(n.Symbol is Exponential) &&
345          !(n.Symbol is SquareRoot) &&
346          !(n.Symbol is Square) &&
347          !(n.Symbol is Sine) &&
348          !(n.Symbol is Cosine) &&
349          !(n.Symbol is Tangent) &&
350          !(n.Symbol is HyperbolicTangent) &&
351          !(n.Symbol is Erf) &&
352          !(n.Symbol is Norm) &&
353          !(n.Symbol is StartSymbol) &&
354          !(n.Symbol is Absolute) &&
355          !(n.Symbol is AnalyticQuotient) &&
356          !(n.Symbol is Cube) &&
357          !(n.Symbol is CubeRoot) &&
358          !(n.Symbol is Power) &&
359          !(n.Symbol is SubFunctionSymbol)
360        select n).Any();
361      return !containsUnknownSymbol;
362    }
363    #region exception class
364    [Serializable]
365    public class ConversionException : Exception {
366
367      public ConversionException() {
368      }
369
370      public ConversionException(string message) : base(message) {
371      }
372
373      public ConversionException(string message, Exception inner) : base(message, inner) {
374      }
375
376      protected ConversionException(
377        SerializationInfo info,
378        StreamingContext context) : base(info, context) {
379      }
380    }
381    #endregion
382  }
383}
Note: See TracBrowser for help on using the repository browser.