source: branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs @ 16500

Last change on this file since 16500 was 16500, checked in by mkommend, 14 months ago

#2974: Added intermediate version of new constants optimization for profiling.

File size: 15.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
30  public class TreeToAutoDiffTermConverter {
31    public delegate double ParametricFunction(double[] vars, double[] @params);
32
33    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        var other = obj as DataForVariable;
49        if (other == null) return false;
50        return other.variableName.Equals(this.variableName) &&
51               other.variableValue.Equals(this.variableValue) &&
52               other.lag == this.lag;
53      }
54
55      public override int GetHashCode() {
56        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
57      }
58    }
59    #endregion
60
61    #region derivations of functions
62    // create function factory for arctangent
63    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
64      eval: Math.Atan,
65      diff: x => 1 / (1 + x * x));
66
67    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
68      eval: Math.Sin,
69      diff: Math.Cos);
70
71    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
72      eval: Math.Cos,
73      diff: x => -Math.Sin(x));
74
75    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
76      eval: Math.Tan,
77      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
78
79    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
80      eval: alglib.errorfunction,
81      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
82
83    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
84      eval: alglib.normaldistribution,
85      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
86
87    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
88      eval: Math.Abs,
89      diff: x => Math.Sign(x)
90      );
91
92    #endregion
93
94    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms,
95      out List<DataForVariable> parameters, out double[] initialConstants,
96      out ParametricFunction func,
97      out ParametricFunctionGradient func_grad) {
98
99      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
100      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
101      AutoDiff.Term term;
102      try {
103        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
104
105        if (addLinearScalingTerms) {
106          // scaling variables α, β are given at the beginning of the parameter vector
107          var alpha = new AutoDiff.Variable();
108          var beta = new AutoDiff.Variable();
109          transformator.variables.Insert(0, alpha);
110          transformator.variables.Insert(0, beta);
111
112          term = term * alpha + beta;
113        }
114
115        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
116        var compiledTerm = term.Compile(transformator.variables.ToArray(),
117          parameterEntries.Select(kvp => kvp.Value).ToArray());
118
119        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
120        initialConstants = transformator.initialConstants.ToArray();
121        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
122        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
123        return true;
124      } catch (ConversionException) {
125        parameters = null;
126        initialConstants = null;
127        func = null;
128        func_grad = null;
129      }
130      return false;
131    }
132
133    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool addLinearScalingTerms, IEnumerable<DataForVariable> variables,
134      out IParametricCompiledTerm autoDiffTerm, out double[] initialConstants) {
135      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
136      //TODO change ctor
137      var transformator = new TreeToAutoDiffTermConverter(true);
138      var parameters = new AutoDiff.Variable[variables.Count()];
139
140      int i = 0;
141      foreach(var variable in variables) { 
142        var autoDiffVar = new AutoDiff.Variable();
143        transformator.parameters.Add(variable, autoDiffVar);
144        parameters[i] = autoDiffVar;
145        i++;
146      }
147
148      AutoDiff.Term term;
149      try {
150        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
151        if (addLinearScalingTerms) {
152          // scaling variables α, β are given at the end of the parameter vector
153          var alpha = new AutoDiff.Variable();
154          var beta = new AutoDiff.Variable();
155
156          term = term * alpha + beta;
157
158          transformator.variables.Add(alpha);
159          transformator.variables.Add(beta);
160
161          transformator.initialConstants.Add(1.0);
162          transformator.initialConstants.Add(0.0);
163        }
164
165        var compiledTerm = term.Compile(transformator.variables.ToArray(), parameters);
166        autoDiffTerm = compiledTerm;
167        initialConstants = transformator.initialConstants.ToArray();
168
169        return true;
170      } catch (ConversionException) {
171        autoDiffTerm = null;
172        initialConstants = null;
173      }
174      return false;
175    }
176
177    // state for recursive transformation of trees
178    private readonly List<double> initialConstants;
179    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
180    private readonly List<AutoDiff.Variable> variables;
181    private readonly bool makeVariableWeightsVariable;
182
183    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
184      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
185      this.initialConstants = new List<double>();
186      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
187      this.variables = new List<AutoDiff.Variable>();
188    }
189
190    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
191      if (node.Symbol is Constant) {
192        initialConstants.Add(((ConstantTreeNode)node).Value);
193        var var = new AutoDiff.Variable();
194        variables.Add(var);
195        return var;
196      }
197      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
198        var varNode = node as VariableTreeNodeBase;
199        var factorVarNode = node as BinaryFactorVariableTreeNode;
200        // factor variable values are only 0 or 1 and set in x accordingly
201        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
202        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
203
204        if (makeVariableWeightsVariable) {
205          initialConstants.Add(varNode.Weight);
206          var w = new AutoDiff.Variable();
207          variables.Add(w);
208          return AutoDiff.TermBuilder.Product(w, par);
209        } else {
210          return varNode.Weight * par;
211        }
212      }
213      if (node.Symbol is FactorVariable) {
214        var factorVarNode = node as FactorVariableTreeNode;
215        var products = new List<Term>();
216        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
217          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
218
219          initialConstants.Add(factorVarNode.GetValue(variableValue));
220          var wVar = new AutoDiff.Variable();
221          variables.Add(wVar);
222
223          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
224        }
225        return AutoDiff.TermBuilder.Sum(products);
226      }
227      if (node.Symbol is LaggedVariable) {
228        var varNode = node as LaggedVariableTreeNode;
229        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
230
231        if (makeVariableWeightsVariable) {
232          initialConstants.Add(varNode.Weight);
233          var w = new AutoDiff.Variable();
234          variables.Add(w);
235          return AutoDiff.TermBuilder.Product(w, par);
236        } else {
237          return varNode.Weight * par;
238        }
239      }
240      if (node.Symbol is Addition) {
241        List<AutoDiff.Term> terms = new List<Term>();
242        foreach (var subTree in node.Subtrees) {
243          terms.Add(ConvertToAutoDiff(subTree));
244        }
245        return AutoDiff.TermBuilder.Sum(terms);
246      }
247      if (node.Symbol is Subtraction) {
248        List<AutoDiff.Term> terms = new List<Term>();
249        for (int i = 0; i < node.SubtreeCount; i++) {
250          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
251          if (i > 0) t = -t;
252          terms.Add(t);
253        }
254        if (terms.Count == 1) return -terms[0];
255        else return AutoDiff.TermBuilder.Sum(terms);
256      }
257      if (node.Symbol is Multiplication) {
258        List<AutoDiff.Term> terms = new List<Term>();
259        foreach (var subTree in node.Subtrees) {
260          terms.Add(ConvertToAutoDiff(subTree));
261        }
262        if (terms.Count == 1) return terms[0];
263        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
264      }
265      if (node.Symbol is Division) {
266        List<AutoDiff.Term> terms = new List<Term>();
267        foreach (var subTree in node.Subtrees) {
268          terms.Add(ConvertToAutoDiff(subTree));
269        }
270        if (terms.Count == 1) return 1.0 / terms[0];
271        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
272      }
273      if (node.Symbol is Absolute) {
274        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
275        return abs(x1);
276      }
277      if (node.Symbol is AnalyticQuotient) {
278        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
279        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
280        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
281      }
282      if (node.Symbol is Logarithm) {
283        return AutoDiff.TermBuilder.Log(
284          ConvertToAutoDiff(node.GetSubtree(0)));
285      }
286      if (node.Symbol is Exponential) {
287        return AutoDiff.TermBuilder.Exp(
288          ConvertToAutoDiff(node.GetSubtree(0)));
289      }
290      if (node.Symbol is Square) {
291        return AutoDiff.TermBuilder.Power(
292          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
293      }
294      if (node.Symbol is SquareRoot) {
295        return AutoDiff.TermBuilder.Power(
296          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
297      }
298      if (node.Symbol is Cube) {
299        return AutoDiff.TermBuilder.Power(
300          ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
301      }
302      if (node.Symbol is CubeRoot) {
303        return AutoDiff.TermBuilder.Power(
304          ConvertToAutoDiff(node.GetSubtree(0)), 1.0 / 3.0);
305      }
306      if (node.Symbol is Sine) {
307        return sin(
308          ConvertToAutoDiff(node.GetSubtree(0)));
309      }
310      if (node.Symbol is Cosine) {
311        return cos(
312          ConvertToAutoDiff(node.GetSubtree(0)));
313      }
314      if (node.Symbol is Tangent) {
315        return tan(
316          ConvertToAutoDiff(node.GetSubtree(0)));
317      }
318      if (node.Symbol is Erf) {
319        return erf(
320          ConvertToAutoDiff(node.GetSubtree(0)));
321      }
322      if (node.Symbol is Norm) {
323        return norm(
324          ConvertToAutoDiff(node.GetSubtree(0)));
325      }
326      if (node.Symbol is StartSymbol) {
327        return ConvertToAutoDiff(node.GetSubtree(0));
328      }
329      throw new ConversionException();
330    }
331
332
333    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
334    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
335    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
336      string varName, string varValue = "", int lag = 0) {
337      var data = new DataForVariable(varName, varValue, lag);
338
339      AutoDiff.Variable par = null;
340      if (!parameters.TryGetValue(data, out par)) {
341        // not found -> create new parameter and entries in names and values lists
342        par = new AutoDiff.Variable();
343        parameters.Add(data, par);
344      }
345      return par;
346    }
347
348    public static bool IsCompatible(ISymbolicExpressionTree tree) {
349      var containsUnknownSymbol = (
350        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
351        where
352          !(n.Symbol is Variable) &&
353          !(n.Symbol is BinaryFactorVariable) &&
354          !(n.Symbol is FactorVariable) &&
355          !(n.Symbol is LaggedVariable) &&
356          !(n.Symbol is Constant) &&
357          !(n.Symbol is Addition) &&
358          !(n.Symbol is Subtraction) &&
359          !(n.Symbol is Multiplication) &&
360          !(n.Symbol is Division) &&
361          !(n.Symbol is Logarithm) &&
362          !(n.Symbol is Exponential) &&
363          !(n.Symbol is SquareRoot) &&
364          !(n.Symbol is Square) &&
365          !(n.Symbol is Sine) &&
366          !(n.Symbol is Cosine) &&
367          !(n.Symbol is Tangent) &&
368          !(n.Symbol is Erf) &&
369          !(n.Symbol is Norm) &&
370          !(n.Symbol is StartSymbol) &&
371          !(n.Symbol is Absolute) &&
372          !(n.Symbol is AnalyticQuotient) &&
373          !(n.Symbol is Cube) &&
374          !(n.Symbol is CubeRoot)
375        select n).Any();
376      return !containsUnknownSymbol;
377    }
378    #region exception class
379    [Serializable]
380    public class ConversionException : Exception {
381
382      public ConversionException() {
383      }
384
385      public ConversionException(string message) : base(message) {
386      }
387
388      public ConversionException(string message, Exception inner) : base(message, inner) {
389      }
390
391      protected ConversionException(
392        SerializationInfo info,
393        StreamingContext context) : base(info, context) {
394      }
395    }
396    #endregion
397  }
398}
Note: See TracBrowser for help on using the repository browser.