source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 14390

Last change on this file since 14390 was 14390, checked in by gkronber, 4 years ago

#2697:

  • renaming of folder "Transformation" to "Converters" to distinguish between transformations for variables (from data preprocessing) and classes for transformation of trees.
  • renamed SymbolicDataAnalysisExpressionTreeSimplifier -> TreeSimplifier
  • Implemented a converter to create a linear model as a symbolic expression tree
File size: 11.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27
28namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
/// <summary>
/// Converts a symbolic expression tree into an AutoDiff term, so that the tree's output and its gradient
/// with respect to the tree's numeric constants (and optionally its variable weights) can be evaluated.
/// </summary>
public class TreeToAutoDiffTermConverter {
  /// <summary>Evaluates a converted tree.</summary>
  /// <param name="vars">Values of the optimizable terms (linear scaling terms if present, constants,
  /// optional variable weights), in the order in which the conversion registered them.</param>
  /// <param name="params">Values of the input variables, in the order of the variable names returned by the conversion.</param>
  /// <returns>The value of the converted tree.</returns>
  public delegate double ParametricFunction(double[] vars, double[] @params);

  /// <summary>
  /// Evaluates a converted tree together with its gradient with respect to <c>vars</c>.
  /// The result is the (gradient, value) tuple produced by AutoDiff's <c>ICompiledTerm.Differentiate</c>.
  /// </summary>
  public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);

  #region derivations of functions
  // Each factory pairs a function with its first derivative so AutoDiff can use it as a unary function.

  // arctangent: d/dx atan(x) = 1 / (1 + x^2); currently not referenced by the converter
  private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    eval: Math.Atan,
    diff: x => 1 / (1 + x * x));
  private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    eval: Math.Sin,
    diff: Math.Cos);
  private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    eval: Math.Cos,
    diff: x => -Math.Sin(x));
  // d/dx tan(x) = 1 + tan(x)^2
  private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    eval: Math.Tan,
    diff: x => 1 + Math.Tan(x) * Math.Tan(x));
  // d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
  private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    eval: alglib.errorfunction,
    diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
  // alglib.normaldistribution is the standard normal CDF Phi(x); its derivative is the standard
  // normal density phi(x) = exp(-x^2 / 2) / sqrt(2 * pi). Written directly in this form so that
  // no intermediate exp(x^2) can overflow for large |x|.
  private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    eval: alglib.normaldistribution,
    diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));

  #endregion

  /// <summary>
  /// Converts the first subtree of the tree's root into a compiled AutoDiff term.
  /// </summary>
  /// <param name="tree">The tree to convert.</param>
  /// <param name="makeVariableWeightsVariable">If true, the weight of every (lagged) variable node becomes an
  /// optimizable term initialized with the node's weight; otherwise the weight is a fixed factor.</param>
  /// <param name="variableNames">Names of the input variables, one entry per variable node in left-to-right
  /// order (a variable occurring several times in the tree is listed several times).</param>
  /// <param name="lags">Lag of each entry of <paramref name="variableNames"/> (0 for non-lagged variables).</param>
  /// <param name="initialConstants">Initial values of the optimizable terms created for constants and, if
  /// enabled, variable weights, in left-to-right order. The linear scaling terms are not included (see remarks).</param>
  /// <param name="func">Evaluates the converted tree.</param>
  /// <param name="func_grad">Evaluates the converted tree and its gradient with respect to <c>vars</c>.</param>
  /// <returns>
  /// True on success; false if the tree contains a symbol that cannot be converted, in which case
  /// every out parameter is null.
  /// </returns>
  /// <remarks>
  /// If the converted node is a StartSymbol (in HeuristicLab trees typically the child of the root), the
  /// resulting term is <c>alpha * f + beta</c>. Beta and alpha are registered before any other optimizable
  /// term, so they occupy indices 0 and 1 of <c>vars</c>. They have no entry in
  /// <paramref name="initialConstants"/>, so <c>vars</c> must be two elements longer than it.
  /// </remarks>
  public static bool TryTransformToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
    out string[] variableNames, out int[] lags, out double[] initialConstants,
    out ParametricFunction func,
    out ParametricFunctionGradient func_grad) {

    // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
    AutoDiff.Term term;
    var success = transformator.TryTransformToAutoDiff(tree.Root.GetSubtree(0), out term);
    if (success) {
      // AutoDiff differentiates w.r.t. the first argument array (constants/weights) and treats the
      // second one (input variable values) as fixed parameters.
      var compiledTerm = term.Compile(transformator.variables.ToArray(), transformator.parameters.ToArray());
      variableNames = transformator.variableNames.ToArray();
      lags = transformator.lags.ToArray();
      initialConstants = transformator.initialConstants.ToArray();
      func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
      func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
    } else {
      func = null;
      func_grad = null;
      variableNames = null;
      lags = null;
      initialConstants = null;
    }
    return success;
  }

  // state for recursive transformation of trees
  // variableNames, lags and parameters are filled in lockstep: one entry per (lagged) variable node.
  private readonly List<string> variableNames;
  private readonly List<int> lags;
  // initial values for the optimizable terms in 'variables' (except the linear scaling terms)
  private readonly List<double> initialConstants;
  // AutoDiff variables for the input data (not differentiated)
  private readonly List<AutoDiff.Variable> parameters;
  // AutoDiff variables that are differentiated: scaling terms, constants, optional variable weights
  private readonly List<AutoDiff.Variable> variables;
  private readonly bool makeVariableWeightsVariable;

  private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
    this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    this.variableNames = new List<string>();
    this.lags = new List<int>();
    this.initialConstants = new List<double>();
    this.parameters = new List<AutoDiff.Variable>();
    this.variables = new List<AutoDiff.Variable>();
  }

  // Recursively converts a node. Children are visited left to right, so terms are registered in the
  // lists in that order; callers rely on this order when indexing vars/params.
  private bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
    term = null;
    var symbol = node.Symbol;

    if (symbol is Constant) {
      // every constant becomes an optimizable term initialized with the constant's value
      initialConstants.Add(((ConstantTreeNode)node).Value);
      var c = new AutoDiff.Variable();
      variables.Add(c);
      term = c;
      return true;
    }
    // LaggedVariable is tested before Variable so that the more specific symbol wins; otherwise, should
    // LaggedVariable derive from Variable, lagged variables would be converted with lag 0.
    if (symbol is LaggedVariable) {
      var laggedNode = (LaggedVariableTreeNode)node;
      term = CreateVariableTerm(laggedNode.VariableName, laggedNode.Lag, laggedNode.Weight);
      return true;
    }
    if (symbol is Variable) {
      var varNode = (VariableTreeNode)node;
      term = CreateVariableTerm(varNode.VariableName, 0, varNode.Weight);
      return true;
    }

    List<AutoDiff.Term> terms;
    if (symbol is Addition) {
      if (!TryTransformSubtrees(node, out terms)) return false;
      if (terms.Count == 1) term = terms[0];
      else term = AutoDiff.TermBuilder.Sum(terms);
      return true;
    }
    if (symbol is Subtraction) {
      if (!TryTransformSubtrees(node, out terms)) return false;
      if (terms.Count == 1) {
        // unary minus
        term = -terms[0];
      } else {
        // a - b - c - ... = a + (-b) + (-c) + ...
        for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
        term = AutoDiff.TermBuilder.Sum(terms);
      }
      return true;
    }
    if (symbol is Multiplication) {
      if (!TryTransformSubtrees(node, out terms)) return false;
      if (terms.Count == 1) term = terms[0];
      else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
      return true;
    }
    if (symbol is Division) {
      if (!TryTransformSubtrees(node, out terms)) return false;
      // a single argument means the reciprocal; otherwise a / b / c / ... = a * (1/b) * (1/c) * ...
      if (terms.Count == 1) term = 1.0 / terms[0];
      else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
      return true;
    }

    if (symbol is Logarithm) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Log(t), out term);
    if (symbol is Exponential) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Exp(t), out term);
    if (symbol is Square) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Power(t, 2.0), out term);
    if (symbol is SquareRoot) return TryTransformUnary(node, t => AutoDiff.TermBuilder.Power(t, 0.5), out term);
    if (symbol is Sine) return TryTransformUnary(node, t => sin(t), out term);
    if (symbol is Cosine) return TryTransformUnary(node, t => cos(t), out term);
    if (symbol is Tangent) return TryTransformUnary(node, t => tan(t), out term);
    if (symbol is Erf) return TryTransformUnary(node, t => erf(t), out term);
    if (symbol is Norm) return TryTransformUnary(node, t => norm(t), out term);

    if (symbol is StartSymbol) {
      // linear scaling: result = alpha * f + beta. beta and alpha are registered before the subtree is
      // converted, so when this is the first node visited they are vars[0] (beta) and vars[1] (alpha).
      // No initial values are added to initialConstants for them.
      var alpha = new AutoDiff.Variable(); // TODO
      var beta = new AutoDiff.Variable();
      variables.Add(beta);
      variables.Add(alpha);
      AutoDiff.Term branchTerm;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), out branchTerm)) return false;
      term = branchTerm * alpha + beta;
      return true;
    }

    // unsupported symbol
    return false;
  }

  // Registers one (lagged) input variable and returns its term: weight * input, where the weight is an
  // optimizable term if makeVariableWeightsVariable is set and a fixed factor otherwise.
  private AutoDiff.Term CreateVariableTerm(string variableName, int lag, double weight) {
    var par = new AutoDiff.Variable();
    parameters.Add(par);
    variableNames.Add(variableName);
    lags.Add(lag);

    if (makeVariableWeightsVariable) {
      initialConstants.Add(weight);
      var w = new AutoDiff.Variable();
      variables.Add(w);
      return AutoDiff.TermBuilder.Product(w, par);
    }
    return weight * par;
  }

  // Converts all subtrees of node from left to right; stops at the first subtree that cannot be converted.
  private bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, out List<AutoDiff.Term> terms) {
    terms = new List<AutoDiff.Term>();
    foreach (var subtree in node.Subtrees) {
      AutoDiff.Term t;
      if (!TryTransformToAutoDiff(subtree, out t)) {
        terms = null;
        return false;
      }
      terms.Add(t);
    }
    return true;
  }

  // Converts the first subtree of node and applies the unary function f to the resulting term.
  private bool TryTransformUnary(ISymbolicExpressionTreeNode node, Func<AutoDiff.Term, AutoDiff.Term> f, out AutoDiff.Term term) {
    AutoDiff.Term argument;
    if (!TryTransformToAutoDiff(node.GetSubtree(0), out argument)) {
      term = null;
      return false;
    }
    term = f(argument);
    return true;
  }

  /// <summary>
  /// Checks whether every node below the root uses a symbol that the converter can handle.
  /// </summary>
  /// <param name="tree">The tree to check.</param>
  /// <returns>True if the tree can be converted by <c>TryTransformToAutoDiff</c>.</returns>
  public static bool IsCompatible(ISymbolicExpressionTree tree) {
    // this list must stay in sync with the symbols handled in TryTransformToAutoDiff
    var containsUnknownSymbol = (
      from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
      where
      !(n.Symbol is Variable) &&
      !(n.Symbol is LaggedVariable) &&
      !(n.Symbol is Constant) &&
      !(n.Symbol is Addition) &&
      !(n.Symbol is Subtraction) &&
      !(n.Symbol is Multiplication) &&
      !(n.Symbol is Division) &&
      !(n.Symbol is Logarithm) &&
      !(n.Symbol is Exponential) &&
      !(n.Symbol is SquareRoot) &&
      !(n.Symbol is Square) &&
      !(n.Symbol is Sine) &&
      !(n.Symbol is Cosine) &&
      !(n.Symbol is Tangent) &&
      !(n.Symbol is Erf) &&
      !(n.Symbol is Norm) &&
      !(n.Symbol is StartSymbol)
      select n).Any();
    return !containsUnknownSymbol;
  }
}
332}
Note: See TracBrowser for help on using the repository browser.