Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 14851

Last change on this file since 14851 was 14851, checked in by gkronber, 7 years ago

#2697: removed unnecessary variables

File size: 13.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27
28namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
public class TreeToAutoDiffTermConverter {
  // Evaluates the compiled term. Following AutoDiff's parametric compiled term convention, 'vars' holds the values of the
  // optimizable AutoDiff variables (constants / weights) and '@params' holds the values of the data inputs
  // (in the order of the 'parameters' list returned by TryConvertToAutoDiff).
  public delegate double ParametricFunction(double[] vars, double[] @params);
  // Same inputs as ParametricFunction; returns (gradient with respect to 'vars', function value).
  public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);

  #region helper class
  /// <summary>
  /// Identifies one data input of the compiled function: a variable name, a factor value
  /// (empty for non-factor variables) and a time lag (0 for non-lagged variables).
  /// Instances are used as dictionary keys, so equality and hashing are value-based.
  /// </summary>
  public class DataForVariable {
    public readonly string variableName;
    public readonly string variableValue; // for factor vars
    public readonly int lag;

    public DataForVariable(string varName, string varValue, int lag) {
      this.variableName = varName;
      this.variableValue = varValue;
      this.lag = lag;
    }

    public override bool Equals(object obj) {
      var other = obj as DataForVariable;
      if (other == null) return false;
      // static string.Equals is ordinal like the instance method but does not throw on null fields
      return string.Equals(other.variableName, this.variableName) &&
             string.Equals(other.variableValue, this.variableValue) &&
             other.lag == this.lag;
    }

    public override int GetHashCode() {
      // null fields hash to 0 so that hashing stays consistent with the null-safe Equals above
      var nameHash = variableName == null ? 0 : variableName.GetHashCode();
      var valueHash = variableValue == null ? 0 : variableValue.GetHashCode();
      return nameHash ^ valueHash ^ lag;
    }
  }
  #endregion

  #region derivations of functions
  // Each factory pairs a double -> double evaluation with its first derivative so that AutoDiff can
  // evaluate and differentiate terms that contain these functions.

  // arctangent: d/dx atan(x) = 1 / (1 + x^2). Not referenced by any conversion branch in this class.
  private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    eval: Math.Atan,
    diff: x => 1 / (1 + x * x));
  private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    eval: Math.Sin,
    diff: Math.Cos);
  private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    eval: Math.Cos,
    diff: x => -Math.Sin(x));
  // d/dx tan(x) = 1 + tan(x)^2
  private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    eval: Math.Tan,
    diff: x => 1 + Math.Tan(x) * Math.Tan(x));
  // d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
  private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    eval: alglib.errorfunction,
    diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
  // alglib.normaldistribution is the standard normal CDF Phi(x); its derivative is the standard normal
  // density phi(x) = exp(-x^2 / 2) / sqrt(2 pi). (The previous formula was -x * phi(x), i.e. the derivative
  // of the density, and overflowed to NaN/Infinity for |x| > ~26.6 via exp(x^2).)
  private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    eval: alglib.normaldistribution,
    diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));

  #endregion

  /// <summary>
  /// Converts the first subtree of the root of <paramref name="tree"/> into an AutoDiff term and compiles it.
  /// </summary>
  /// <param name="tree">The tree to convert.</param>
  /// <param name="makeVariableWeightsVariable">
  /// If true, the weights of Variable, BinaryFactorVariable and LaggedVariable nodes become optimizable variables;
  /// otherwise they are baked into the term as fixed coefficients. Constants and FactorVariable weights are always optimizable.
  /// </param>
  /// <param name="parameters">The data inputs of the compiled function, in the order expected by the second argument of
  /// <paramref name="func"/> / <paramref name="func_grad"/>.</param>
  /// <param name="initialConstants">
  /// Initial values of the optimizable variables in creation order. When the converted node is a StartSymbol, its offset
  /// (beta) and scale (alpha) are the first two compiled variables (in that order) but have NO entry here; callers must
  /// supply initial values for them in front of these constants.
  /// </param>
  /// <param name="func">Evaluates the compiled term.</param>
  /// <param name="func_grad">Evaluates the compiled term and its gradient with respect to the optimizable variables.</param>
  /// <returns>false, with all outputs set to null, if the tree contains a symbol that cannot be converted.</returns>
  public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
    out List<DataForVariable> parameters, out double[] initialConstants,
    out ParametricFunction func,
    out ParametricFunctionGradient func_grad) {

    // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
    AutoDiff.Term term;
    if (!transformator.TryConvertToAutoDiff(tree.Root.GetSubtree(0), out term)) {
      func = null;
      func_grad = null;
      parameters = null;
      initialConstants = null;
      return false;
    }

    // materialize the dictionary once so that the keys (returned) and the values (compiled) have the same order
    var parameterEntries = transformator.parameters.ToArray();
    var compiledTerm = term.Compile(transformator.variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
    parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
    initialConstants = transformator.initialConstants.ToArray();
    func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
    func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
    return true;
  }

  // state for recursive transformation of trees
  // initial values of optimizable variables; aligned with 'variables' except for the StartSymbol's beta/alpha
  private readonly List<double> initialConstants;
  // data inputs ("parameters" in AutoDiff's terminology), one per distinct (name, value, lag) combination
  private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
  // optimizable AutoDiff variables (constants and weights), in creation order
  private readonly List<AutoDiff.Variable> variables;
  private readonly bool makeVariableWeightsVariable;

  private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
    this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    this.initialConstants = new List<double>();
    this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    this.variables = new List<AutoDiff.Variable>();
  }

  /// <summary>
  /// Recursively converts <paramref name="node"/> into an AutoDiff term, registering optimizable variables
  /// (with their initial values) and data parameters in this converter's state. Subtrees are converted left to right,
  /// which determines the order of the optimizable variables.
  /// </summary>
  /// <returns>false (with <paramref name="term"/> set to null) if the subtree contains an unsupported symbol.
  /// State registered before the failure is not rolled back; the static entry point discards the converter in that case.</returns>
  private bool TryConvertToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {
    var symbol = node.Symbol;
    term = null;
    List<Term> terms;

    if (symbol is Constant) {
      term = CreateOptimizableVariable(((ConstantTreeNode)node).Value);
      return true;
    }
    // NOTE(review): checked before Variable so that the lag is honored even if LaggedVariable derives from Variable;
    // if the two symbol types are unrelated, the order makes no difference.
    if (symbol is LaggedVariable) {
      var laggedNode = (LaggedVariableTreeNode)node;
      var par = FindOrCreateParameter(parameters, laggedNode.VariableName, string.Empty, laggedNode.Lag);
      term = WeightedParameter(laggedNode.Weight, par);
      return true;
    }
    if (symbol is Variable || symbol is BinaryFactorVariable) {
      var varNode = (VariableTreeNodeBase)node;
      var binaryFactorNode = node as BinaryFactorVariableTreeNode;
      // factor variable values are only 0 or 1 and set in x accordingly
      var varValue = binaryFactorNode != null ? binaryFactorNode.VariableValue : string.Empty;
      var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
      term = WeightedParameter(varNode.Weight, par);
      return true;
    }
    if (symbol is FactorVariable) {
      // sum over all values of the factor: weight(value) * indicator(value); weights are always optimizable
      var factorVarNode = (FactorVariableTreeNode)node;
      var products = new List<Term>();
      foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
        var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
        var weight = CreateOptimizableVariable(factorVarNode.GetValue(variableValue));
        products.Add(AutoDiff.TermBuilder.Product(weight, par));
      }
      term = AutoDiff.TermBuilder.Sum(products);
      return true;
    }
    if (symbol is Addition) {
      if (!TryConvertSubtrees(node, out terms)) return false;
      term = AutoDiff.TermBuilder.Sum(terms);
      return true;
    }
    if (symbol is Subtraction) {
      if (!TryConvertSubtrees(node, out terms)) return false;
      if (terms.Count == 1) {
        // unary minus
        term = -terms[0];
      } else {
        // a - b - c - ... is represented as a + (-b) + (-c) + ...
        for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
        term = AutoDiff.TermBuilder.Sum(terms);
      }
      return true;
    }
    if (symbol is Multiplication) {
      if (!TryConvertSubtrees(node, out terms)) return false;
      term = terms.Count == 1 ? terms[0] : terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
      return true;
    }
    if (symbol is Division) {
      if (!TryConvertSubtrees(node, out terms)) return false;
      // a single argument means the reciprocal; otherwise a / b / c / ... = a * (1/b) * (1/c) * ...
      term = terms.Count == 1 ? 1.0 / terms[0] : terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
      return true;
    }
    if (symbol is Logarithm) return TryConvertUnary(node, t => AutoDiff.TermBuilder.Log(t), out term);
    if (symbol is Exponential) return TryConvertUnary(node, t => AutoDiff.TermBuilder.Exp(t), out term);
    if (symbol is Square) return TryConvertUnary(node, t => AutoDiff.TermBuilder.Power(t, 2.0), out term);
    if (symbol is SquareRoot) return TryConvertUnary(node, t => AutoDiff.TermBuilder.Power(t, 0.5), out term);
    if (symbol is Sine) return TryConvertUnary(node, t => sin(t), out term);
    if (symbol is Cosine) return TryConvertUnary(node, t => cos(t), out term);
    if (symbol is Tangent) return TryConvertUnary(node, t => tan(t), out term);
    if (symbol is Erf) return TryConvertUnary(node, t => erf(t), out term);
    if (symbol is Norm) return TryConvertUnary(node, t => norm(t), out term);
    if (symbol is StartSymbol) {
      // linear scaling of the model output: branch * alpha + beta. beta and alpha are registered BEFORE the branch is
      // converted (so they come first in 'variables') and deliberately get no entry in 'initialConstants'.
      var alpha = new AutoDiff.Variable();
      var beta = new AutoDiff.Variable();
      variables.Add(beta);
      variables.Add(alpha);
      AutoDiff.Term branchTerm;
      if (!TryConvertToAutoDiff(node.GetSubtree(0), out branchTerm)) return false;
      term = branchTerm * alpha + beta;
      return true;
    }
    // unsupported symbol
    return false;
  }

  // Creates a new optimizable AutoDiff variable and records its initial value, keeping both lists aligned.
  private AutoDiff.Variable CreateOptimizableVariable(double initialValue) {
    initialConstants.Add(initialValue);
    var variable = new AutoDiff.Variable();
    variables.Add(variable);
    return variable;
  }

  // Multiplies a data parameter with its weight: an optimizable variable initialized with the weight if
  // makeVariableWeightsVariable is set, otherwise the weight as a fixed coefficient.
  private Term WeightedParameter(double weight, Term par) {
    if (makeVariableWeightsVariable) {
      return AutoDiff.TermBuilder.Product(CreateOptimizableVariable(weight), par);
    }
    return weight * par;
  }

  // Converts all subtrees of node from left to right; returns false (terms = null) as soon as one cannot be converted.
  private bool TryConvertSubtrees(ISymbolicExpressionTreeNode node, out List<Term> terms) {
    terms = new List<Term>(node.SubtreeCount);
    foreach (var subtree in node.Subtrees) {
      Term converted;
      if (!TryConvertToAutoDiff(subtree, out converted)) {
        terms = null;
        return false;
      }
      terms.Add(converted);
    }
    return true;
  }

  // Converts the first subtree of a unary node and applies 'wrap' to the result.
  private bool TryConvertUnary(ISymbolicExpressionTreeNode node, Func<Term, Term> wrap, out Term term) {
    Term argument;
    if (!TryConvertToAutoDiff(node.GetSubtree(0), out argument)) {
      term = null;
      return false;
    }
    term = wrap(argument);
    return true;
  }

  // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
  // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
  private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    string varName, string varValue = "", int lag = 0) {
    var key = new DataForVariable(varName, varValue, lag);

    AutoDiff.Variable par;
    if (!parameters.TryGetValue(key, out par)) {
      // not found -> create a new parameter for this (name, value, lag) combination
      par = new AutoDiff.Variable();
      parameters.Add(key, par);
    }
    return par;
  }

  /// <summary>
  /// Returns true if every node below the root of <paramref name="tree"/> uses a symbol that
  /// TryConvertToAutoDiff can convert. This list must be kept in sync with the conversion branches.
  /// </summary>
  public static bool IsCompatible(ISymbolicExpressionTree tree) {
    var containsUnknownSymbol = (
      from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
      where
      !(n.Symbol is Variable) &&
      !(n.Symbol is BinaryFactorVariable) &&
      !(n.Symbol is FactorVariable) &&
      !(n.Symbol is LaggedVariable) &&
      !(n.Symbol is Constant) &&
      !(n.Symbol is Addition) &&
      !(n.Symbol is Subtraction) &&
      !(n.Symbol is Multiplication) &&
      !(n.Symbol is Division) &&
      !(n.Symbol is Logarithm) &&
      !(n.Symbol is Exponential) &&
      !(n.Symbol is SquareRoot) &&
      !(n.Symbol is Square) &&
      !(n.Symbol is Sine) &&
      !(n.Symbol is Cosine) &&
      !(n.Symbol is Tangent) &&
      !(n.Symbol is Erf) &&
      !(n.Symbol is Norm) &&
      !(n.Symbol is StartSymbol)
      select n).Any();
    return !containsUnknownSymbol;
  }
}
382}
Note: See TracBrowser for help on using the repository browser.