Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 15303

Last change on this file since 15303 was 14950, checked in by gkronber, 8 years ago

#2697: code improvement in TreeToAutoDiffTermConverter

File size: 12.3 KB
RevLine 
[14843]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[14950]25using System.Runtime.Serialization;
[14843]26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
30  public class TreeToAutoDiffTermConverter {
31    public delegate double ParametricFunction(double[] vars, double[] @params);
[14950]32
[14843]33    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        var other = obj as DataForVariable;
49        if (other == null) return false;
50        return other.variableName.Equals(this.variableName) &&
51               other.variableValue.Equals(this.variableValue) &&
52               other.lag == this.lag;
53      }
54
55      public override int GetHashCode() {
56        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
57      }
58    }
59    #endregion
60
61    #region derivations of functions
62    // create function factory for arctangent
63    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
64      eval: Math.Atan,
65      diff: x => 1 / (1 + x * x));
[14950]66
[14843]67    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
68      eval: Math.Sin,
69      diff: Math.Cos);
[14950]70
[14843]71    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
[14950]72      eval: Math.Cos,
73      diff: x => -Math.Sin(x));
74
[14843]75    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
76      eval: Math.Tan,
77      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
[14950]78
[14843]79    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
80      eval: alglib.errorfunction,
81      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
[14950]82
[14843]83    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
84      eval: alglib.normaldistribution,
85      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
86
87    #endregion
88
89    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
90      out List<DataForVariable> parameters, out double[] initialConstants,
91      out ParametricFunction func,
92      out ParametricFunctionGradient func_grad) {
93
94      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
95      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
96      AutoDiff.Term term;
[14950]97      try {
98        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
[14843]99        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
[14950]100        var compiledTerm = term.Compile(transformator.variables.ToArray(),
101          parameterEntries.Select(kvp => kvp.Value).ToArray());
[14843]102        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
103        initialConstants = transformator.initialConstants.ToArray();
104        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
105        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
[14950]106        return true;
107      } catch (ConversionException) {
[14843]108        func = null;
109        func_grad = null;
110        parameters = null;
111        initialConstants = null;
112      }
[14950]113      return false;
[14843]114    }
115
116    // state for recursive transformation of trees
[14950]117    private readonly
118    List<double> initialConstants;
[14843]119    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
120    private readonly List<AutoDiff.Variable> variables;
121    private readonly bool makeVariableWeightsVariable;
122
123    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
124      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
125      this.initialConstants = new List<double>();
126      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
127      this.variables = new List<AutoDiff.Variable>();
128    }
129
[14950]130    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
[14843]131      if (node.Symbol is Constant) {
132        initialConstants.Add(((ConstantTreeNode)node).Value);
133        var var = new AutoDiff.Variable();
134        variables.Add(var);
[14950]135        return var;
[14843]136      }
137      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
138        var varNode = node as VariableTreeNodeBase;
139        var factorVarNode = node as BinaryFactorVariableTreeNode;
140        // factor variable values are only 0 or 1 and set in x accordingly
141        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
142        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
143
144        if (makeVariableWeightsVariable) {
145          initialConstants.Add(varNode.Weight);
146          var w = new AutoDiff.Variable();
147          variables.Add(w);
[14950]148          return AutoDiff.TermBuilder.Product(w, par);
[14843]149        } else {
[14950]150          return varNode.Weight * par;
[14843]151        }
152      }
153      if (node.Symbol is FactorVariable) {
154        var factorVarNode = node as FactorVariableTreeNode;
155        var products = new List<Term>();
156        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
157          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
158
159          initialConstants.Add(factorVarNode.GetValue(variableValue));
160          var wVar = new AutoDiff.Variable();
161          variables.Add(wVar);
162
163          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
164        }
[14950]165        return AutoDiff.TermBuilder.Sum(products);
[14843]166      }
167      if (node.Symbol is LaggedVariable) {
168        var varNode = node as LaggedVariableTreeNode;
169        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
170
171        if (makeVariableWeightsVariable) {
172          initialConstants.Add(varNode.Weight);
173          var w = new AutoDiff.Variable();
174          variables.Add(w);
[14950]175          return AutoDiff.TermBuilder.Product(w, par);
[14843]176        } else {
[14950]177          return varNode.Weight * par;
[14843]178        }
179      }
180      if (node.Symbol is Addition) {
181        List<AutoDiff.Term> terms = new List<Term>();
182        foreach (var subTree in node.Subtrees) {
[14950]183          terms.Add(ConvertToAutoDiff(subTree));
[14843]184        }
[14950]185        return AutoDiff.TermBuilder.Sum(terms);
[14843]186      }
187      if (node.Symbol is Subtraction) {
188        List<AutoDiff.Term> terms = new List<Term>();
189        for (int i = 0; i < node.SubtreeCount; i++) {
[14950]190          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
[14843]191          if (i > 0) t = -t;
192          terms.Add(t);
193        }
[14950]194        if (terms.Count == 1) return -terms[0];
195        else return AutoDiff.TermBuilder.Sum(terms);
[14843]196      }
197      if (node.Symbol is Multiplication) {
198        List<AutoDiff.Term> terms = new List<Term>();
199        foreach (var subTree in node.Subtrees) {
[14950]200          terms.Add(ConvertToAutoDiff(subTree));
[14843]201        }
[14950]202        if (terms.Count == 1) return terms[0];
203        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
[14843]204      }
205      if (node.Symbol is Division) {
206        List<AutoDiff.Term> terms = new List<Term>();
207        foreach (var subTree in node.Subtrees) {
[14950]208          terms.Add(ConvertToAutoDiff(subTree));
[14843]209        }
[14950]210        if (terms.Count == 1) return 1.0 / terms[0];
211        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
[14843]212      }
213      if (node.Symbol is Logarithm) {
[14950]214        return AutoDiff.TermBuilder.Log(
215          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]216      }
217      if (node.Symbol is Exponential) {
[14950]218        return AutoDiff.TermBuilder.Exp(
219          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]220      }
221      if (node.Symbol is Square) {
[14950]222        return AutoDiff.TermBuilder.Power(
223          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
[14843]224      }
225      if (node.Symbol is SquareRoot) {
[14950]226        return AutoDiff.TermBuilder.Power(
227          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
[14843]228      }
229      if (node.Symbol is Sine) {
[14950]230        return sin(
231          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]232      }
233      if (node.Symbol is Cosine) {
[14950]234        return cos(
235          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]236      }
237      if (node.Symbol is Tangent) {
[14950]238        return tan(
239          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]240      }
241      if (node.Symbol is Erf) {
[14950]242        return erf(
243          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]244      }
245      if (node.Symbol is Norm) {
[14950]246        return norm(
247          ConvertToAutoDiff(node.GetSubtree(0)));
[14843]248      }
249      if (node.Symbol is StartSymbol) {
250        var alpha = new AutoDiff.Variable();
251        var beta = new AutoDiff.Variable();
252        variables.Add(beta);
253        variables.Add(alpha);
[14950]254        return ConvertToAutoDiff(node.GetSubtree(0)) * alpha + beta;
[14843]255      }
[14950]256      throw new ConversionException();
[14843]257    }
258
259
260    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
261    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
262    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
263      string varName, string varValue = "", int lag = 0) {
264      var data = new DataForVariable(varName, varValue, lag);
265
266      AutoDiff.Variable par = null;
267      if (!parameters.TryGetValue(data, out par)) {
268        // not found -> create new parameter and entries in names and values lists
269        par = new AutoDiff.Variable();
270        parameters.Add(data, par);
271      }
272      return par;
273    }
274
275    public static bool IsCompatible(ISymbolicExpressionTree tree) {
276      var containsUnknownSymbol = (
277        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
278        where
[14950]279          !(n.Symbol is Variable) &&
280          !(n.Symbol is BinaryFactorVariable) &&
281          !(n.Symbol is FactorVariable) &&
282          !(n.Symbol is LaggedVariable) &&
283          !(n.Symbol is Constant) &&
284          !(n.Symbol is Addition) &&
285          !(n.Symbol is Subtraction) &&
286          !(n.Symbol is Multiplication) &&
287          !(n.Symbol is Division) &&
288          !(n.Symbol is Logarithm) &&
289          !(n.Symbol is Exponential) &&
290          !(n.Symbol is SquareRoot) &&
291          !(n.Symbol is Square) &&
292          !(n.Symbol is Sine) &&
293          !(n.Symbol is Cosine) &&
294          !(n.Symbol is Tangent) &&
295          !(n.Symbol is Erf) &&
296          !(n.Symbol is Norm) &&
297          !(n.Symbol is StartSymbol)
[14843]298        select n).Any();
299      return !containsUnknownSymbol;
300    }
[14950]301    #region exception class
302    [Serializable]
303    public class ConversionException : Exception {
304
305      public ConversionException() {
306      }
307
308      public ConversionException(string message) : base(message) {
309      }
310
311      public ConversionException(string message, Exception inner) : base(message, inner) {
312      }
313
314      protected ConversionException(
315        SerializationInfo info,
316        StreamingContext context) : base(info, context) {
317      }
318    }
319    #endregion
[14843]320  }
321}
Note: See TracBrowser for help on using the repository browser.