Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Transformation/TreeToAutoDiffTermTransformator.cs @ 15141

Last change on this file since 15141 was 14840, checked in by gkronber, 8 years ago

#2697 applied changes from r14378 again

File size: 13.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27
28namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
29  public class TreeToAutoDiffTermTransformator {
30    public delegate double ParametricFunction(double[] vars, double[] @params);
31    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
32
33    #region helper class
34    public class DataForVariable {
35      public readonly string variableName;
36      public readonly string variableValue; // for factor vars
37      public readonly int lag;
38
39      public DataForVariable(string varName, string varValue, int lag) {
40        this.variableName = varName;
41        this.variableValue = varValue;
42        this.lag = lag;
43      }
44
45      public override bool Equals(object obj) {
46        var other = obj as DataForVariable;
47        if (other == null) return false;
48        return other.variableName.Equals(this.variableName) &&
49               other.variableValue.Equals(this.variableValue) &&
50               other.lag == this.lag;
51      }
52
53      public override int GetHashCode() {
54        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
55      }
56    }
57    #endregion
58
59    #region derivations of functions
60    // create function factory for arctangent
61    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
62      eval: Math.Atan,
63      diff: x => 1 / (1 + x * x));
64    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
65      eval: Math.Sin,
66      diff: Math.Cos);
67    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
68       eval: Math.Cos,
69       diff: x => -Math.Sin(x));
70    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
71      eval: Math.Tan,
72      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
73    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
74      eval: alglib.errorfunction,
75      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
76    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
77      eval: alglib.normaldistribution,
78      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
79
80    #endregion
81
82    public static bool TryTransformToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
83      out List<DataForVariable> parameters, out double[] initialConstants,
84      out ParametricFunction func,
85      out ParametricFunctionGradient func_grad) {
86
87      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
88      var transformator = new TreeToAutoDiffTermTransformator(makeVariableWeightsVariable);
89      AutoDiff.Term term;
90      var success = transformator.TryTransformToAutoDiff(tree.Root.GetSubtree(0), out term);
91      if (success) {
92        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
93        var compiledTerm = term.Compile(transformator.variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
94        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
95        initialConstants = transformator.initialConstants.ToArray();
96        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
97        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
98      } else {
99        func = null;
100        func_grad = null;
101        parameters = null;
102        initialConstants = null;
103      }
104      return success;
105    }
106
107    // state for recursive transformation of trees
108    private readonly List<string> variableNames;
109    private readonly List<int> lags;
110    private readonly List<double> initialConstants;
111    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
112    private readonly List<AutoDiff.Variable> variables;
113    private readonly bool makeVariableWeightsVariable;
114
115    private TreeToAutoDiffTermTransformator(bool makeVariableWeightsVariable) {
116      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
117      this.variableNames = new List<string>();
118      this.lags = new List<int>();
119      this.initialConstants = new List<double>();
120      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
121      this.variables = new List<AutoDiff.Variable>();
122    }
123
124    private bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
125      if (node.Symbol is Constant) {
126        initialConstants.Add(((ConstantTreeNode)node).Value);
127        var var = new AutoDiff.Variable();
128        variables.Add(var);
129        term = var;
130        return true;
131      }
132      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
133        var varNode = node as VariableTreeNode;
134        var factorVarNode = node as BinaryFactorVariableTreeNode;
135        // factor variable values are only 0 or 1 and set in x accordingly
136        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
137        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
138
139        if (makeVariableWeightsVariable) {
140          initialConstants.Add(varNode.Weight);
141          var w = new AutoDiff.Variable();
142          variables.Add(w);
143          term = AutoDiff.TermBuilder.Product(w, par);
144        } else {
145          term = varNode.Weight * par;
146        }
147        return true;
148      }
149      if (node.Symbol is FactorVariable) {
150        var factorVarNode = node as FactorVariableTreeNode;
151        var products = new List<Term>();
152        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
153          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
154
155          initialConstants.Add(factorVarNode.GetValue(variableValue));
156          var wVar = new AutoDiff.Variable();
157          variables.Add(wVar);
158
159          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
160        }
161        term = AutoDiff.TermBuilder.Sum(products);
162        return true;
163      }
164      if (node.Symbol is LaggedVariable) {
165        var varNode = node as LaggedVariableTreeNode;
166        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
167
168        if (makeVariableWeightsVariable) {
169          initialConstants.Add(varNode.Weight);
170          var w = new AutoDiff.Variable();
171          variables.Add(w);
172          term = AutoDiff.TermBuilder.Product(w, par);
173        } else {
174          term = varNode.Weight * par;
175        }
176        return true;
177      }
178      if (node.Symbol is Addition) {
179        List<AutoDiff.Term> terms = new List<Term>();
180        foreach (var subTree in node.Subtrees) {
181          AutoDiff.Term t;
182          if (!TryTransformToAutoDiff(subTree, out t)) {
183            term = null;
184            return false;
185          }
186          terms.Add(t);
187        }
188        term = AutoDiff.TermBuilder.Sum(terms);
189        return true;
190      }
191      if (node.Symbol is Subtraction) {
192        List<AutoDiff.Term> terms = new List<Term>();
193        for (int i = 0; i < node.SubtreeCount; i++) {
194          AutoDiff.Term t;
195          if (!TryTransformToAutoDiff(node.GetSubtree(i), out t)) {
196            term = null;
197            return false;
198          }
199          if (i > 0) t = -t;
200          terms.Add(t);
201        }
202        if (terms.Count == 1) term = -terms[0];
203        else term = AutoDiff.TermBuilder.Sum(terms);
204        return true;
205      }
206      if (node.Symbol is Multiplication) {
207        List<AutoDiff.Term> terms = new List<Term>();
208        foreach (var subTree in node.Subtrees) {
209          AutoDiff.Term t;
210          if (!TryTransformToAutoDiff(subTree, out t)) {
211            term = null;
212            return false;
213          }
214          terms.Add(t);
215        }
216        if (terms.Count == 1) term = terms[0];
217        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
218        return true;
219
220      }
221      if (node.Symbol is Division) {
222        List<AutoDiff.Term> terms = new List<Term>();
223        foreach (var subTree in node.Subtrees) {
224          AutoDiff.Term t;
225          if (!TryTransformToAutoDiff(subTree, out t)) {
226            term = null;
227            return false;
228          }
229          terms.Add(t);
230        }
231        if (terms.Count == 1) term = 1.0 / terms[0];
232        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
233        return true;
234      }
235      if (node.Symbol is Logarithm) {
236        AutoDiff.Term t;
237        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
238          term = null;
239          return false;
240        } else {
241          term = AutoDiff.TermBuilder.Log(t);
242          return true;
243        }
244      }
245      if (node.Symbol is Exponential) {
246        AutoDiff.Term t;
247        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
248          term = null;
249          return false;
250        } else {
251          term = AutoDiff.TermBuilder.Exp(t);
252          return true;
253        }
254      }
255      if (node.Symbol is Square) {
256        AutoDiff.Term t;
257        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
258          term = null;
259          return false;
260        } else {
261          term = AutoDiff.TermBuilder.Power(t, 2.0);
262          return true;
263        }
264      }
265      if (node.Symbol is SquareRoot) {
266        AutoDiff.Term t;
267        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
268          term = null;
269          return false;
270        } else {
271          term = AutoDiff.TermBuilder.Power(t, 0.5);
272          return true;
273        }
274      }
275      if (node.Symbol is Sine) {
276        AutoDiff.Term t;
277        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
278          term = null;
279          return false;
280        } else {
281          term = sin(t);
282          return true;
283        }
284      }
285      if (node.Symbol is Cosine) {
286        AutoDiff.Term t;
287        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
288          term = null;
289          return false;
290        } else {
291          term = cos(t);
292          return true;
293        }
294      }
295      if (node.Symbol is Tangent) {
296        AutoDiff.Term t;
297        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
298          term = null;
299          return false;
300        } else {
301          term = tan(t);
302          return true;
303        }
304      }
305      if (node.Symbol is Erf) {
306        AutoDiff.Term t;
307        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
308          term = null;
309          return false;
310        } else {
311          term = erf(t);
312          return true;
313        }
314      }
315      if (node.Symbol is Norm) {
316        AutoDiff.Term t;
317        if (!TryTransformToAutoDiff(node.GetSubtree(0), out t)) {
318          term = null;
319          return false;
320        } else {
321          term = norm(t);
322          return true;
323        }
324      }
325      if (node.Symbol is StartSymbol) {
326        var alpha = new AutoDiff.Variable();
327        var beta = new AutoDiff.Variable();
328        variables.Add(beta);
329        variables.Add(alpha);
330        AutoDiff.Term branchTerm;
331        if (TryTransformToAutoDiff(node.GetSubtree(0), out branchTerm)) {
332          term = branchTerm * alpha + beta;
333          return true;
334        } else {
335          term = null;
336          return false;
337        }
338      }
339      term = null;
340      return false;
341    }
342
343
344    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
345    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
346    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
347      string varName, string varValue = "", int lag = 0) {
348      var data = new DataForVariable(varName, varValue, lag);
349
350      AutoDiff.Variable par = null;
351      if (!parameters.TryGetValue(data, out par)) {
352        // not found -> create new parameter and entries in names and values lists
353        par = new AutoDiff.Variable();
354        parameters.Add(data, par);
355      }
356      return par;
357    }
358
359    public static bool IsCompatible(ISymbolicExpressionTree tree) {
360      var containsUnknownSymbol = (
361        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
362        where
363        !(n.Symbol is Variable) &&
364        !(n.Symbol is LaggedVariable) &&
365        !(n.Symbol is Constant) &&
366        !(n.Symbol is Addition) &&
367        !(n.Symbol is Subtraction) &&
368        !(n.Symbol is Multiplication) &&
369        !(n.Symbol is Division) &&
370        !(n.Symbol is Logarithm) &&
371        !(n.Symbol is Exponential) &&
372        !(n.Symbol is SquareRoot) &&
373        !(n.Symbol is Square) &&
374        !(n.Symbol is Sine) &&
375        !(n.Symbol is Cosine) &&
376        !(n.Symbol is Tangent) &&
377        !(n.Symbol is Erf) &&
378        !(n.Symbol is Norm) &&
379        !(n.Symbol is StartSymbol)
380        select n).Any();
381      return !containsUnknownSymbol;
382    }
383  }
384}
Note: See TracBrowser for help on using the repository browser.