source: branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/AutoDiffConverter.cs @ 16507

Last change on this file since 16507 was 16507, checked in by mkommend, 15 months ago

#2974: First stable version of new CoOp.

File size: 12.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
29namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization{
30  public class AutoDiffConverter {
31
32    /// <summary>
33    /// Converts a symbolic expression tree into a parametetric AutoDiff term.
34    /// </summary>
35    /// <param name="tree">The tree the should be converted.</param>
36    /// <param name="addLinearScalingTerms">A flag that determines whether linear scaling terms should be added to the parametric term.</param>
37    /// <param name="numericNodes">The nodes that contain numeric coefficents that should be added as variables in the term.</param>
38    /// <param name="variableData">The variable information that is used to create parameters in the term.</param>
39    /// <param name="autoDiffTerm">The resulting parametric AutoDiff term.</param>
40    /// <returns>A flag to see if the conversion has succeeded.</returns>
41    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool addLinearScalingTerms,
42      IEnumerable<ISymbolicExpressionTreeNode> numericNodes, IEnumerable<VariableData> variableData,
43      out IParametricCompiledTerm autoDiffTerm) {
44      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
45      var transformator = new AutoDiffConverter(numericNodes, variableData);
46      AutoDiff.Term term;
47
48      try {
49        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
50        if (addLinearScalingTerms) {
51          // scaling variables α, β are given at the end of the parameter vector
52          var alpha = new AutoDiff.Variable();
53          var beta = new AutoDiff.Variable();
54
55          term = term * alpha + beta;
56
57          transformator.variables.Add(alpha);
58          transformator.variables.Add(beta);
59        }
60        var compiledTerm = term.Compile(transformator.variables.ToArray(), transformator.parameters.Values.ToArray());
61        autoDiffTerm = compiledTerm;
62        return true;
63      } catch (ConversionException) {
64        autoDiffTerm = null;
65      }
66      return false;
67    }
68
69    // state for recursive transformation of trees
70    private readonly HashSet<ISymbolicExpressionTreeNode> nodesForOptimization;
71    private readonly Dictionary<VariableData, AutoDiff.Variable> parameters;
72    private readonly List<AutoDiff.Variable> variables;
73
74    private AutoDiffConverter(IEnumerable<ISymbolicExpressionTreeNode> nodesForOptimization, IEnumerable<VariableData> variableData) {
75      this.nodesForOptimization = new HashSet<ISymbolicExpressionTreeNode>(nodesForOptimization);
76      this.parameters = variableData.ToDictionary(k => k, v => new AutoDiff.Variable());
77      this.variables = new List<AutoDiff.Variable>();
78    }
79
80    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
81      if (node.Symbol is Constant) {
82        var constantNode = node as ConstantTreeNode;
83        var value = constantNode.Value;
84        if (nodesForOptimization.Contains(node)) {
85          AutoDiff.Variable var = new AutoDiff.Variable();
86          variables.Add(var);
87          return var;
88        } else {
89          return value;
90        }
91      }
92      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
93        var varNode = node as VariableTreeNodeBase;
94        var factorVarNode = node as BinaryFactorVariableTreeNode;
95        // factor variable values are only 0 or 1 and set in x accordingly
96        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
97        var data = new VariableData(varNode.VariableName, varValue, 0);
98        var par = parameters[data];
99        var value = varNode.Weight;
100
101        if (nodesForOptimization.Contains(node)) {
102          AutoDiff.Variable var = new AutoDiff.Variable();
103          variables.Add(var);
104          return AutoDiff.TermBuilder.Product(var, par);
105        } else {
106          return AutoDiff.TermBuilder.Product(value, par);
107        }
108      }
109      if (node.Symbol is FactorVariable) {
110        var factorVarNode = node as FactorVariableTreeNode;
111        var products = new List<Term>();
112        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
113          var data = new VariableData(factorVarNode.VariableName, variableValue, 0);
114          var par = parameters[data];
115          var value = factorVarNode.GetValue(variableValue);
116
117          if (nodesForOptimization.Contains(node)) {
118            var wVar = new AutoDiff.Variable();
119            variables.Add(wVar);
120
121            products.Add(AutoDiff.TermBuilder.Product(wVar, par));
122          } else {
123            products.Add(AutoDiff.TermBuilder.Product(value, par));
124          }
125        }
126        return AutoDiff.TermBuilder.Sum(products);
127      }
128      if (node.Symbol is LaggedVariable) {
129        var varNode = node as LaggedVariableTreeNode;
130        var data = new VariableData(varNode.VariableName, string.Empty, varNode.Lag);
131        var par = parameters[data];
132        var value = varNode.Weight;
133
134        if (nodesForOptimization.Contains(node)) {
135          AutoDiff.Variable var = new AutoDiff.Variable();
136          variables.Add(var);
137          return AutoDiff.TermBuilder.Product(var, par);
138        } else {
139          return AutoDiff.TermBuilder.Product(value, par);
140        }
141
142      }
143      if (node.Symbol is Addition) {
144        List<AutoDiff.Term> terms = new List<Term>();
145        foreach (var subTree in node.Subtrees) {
146          terms.Add(ConvertToAutoDiff(subTree));
147        }
148        return AutoDiff.TermBuilder.Sum(terms);
149      }
150      if (node.Symbol is Subtraction) {
151        List<AutoDiff.Term> terms = new List<Term>();
152        for (int i = 0; i < node.SubtreeCount; i++) {
153          AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
154          if (i > 0) t = -t;
155          terms.Add(t);
156        }
157        if (terms.Count == 1) return -terms[0];
158        else return AutoDiff.TermBuilder.Sum(terms);
159      }
160      if (node.Symbol is Multiplication) {
161        List<AutoDiff.Term> terms = new List<Term>();
162        foreach (var subTree in node.Subtrees) {
163          terms.Add(ConvertToAutoDiff(subTree));
164        }
165        if (terms.Count == 1) return terms[0];
166        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
167      }
168      if (node.Symbol is Division) {
169        List<AutoDiff.Term> terms = new List<Term>();
170        foreach (var subTree in node.Subtrees) {
171          terms.Add(ConvertToAutoDiff(subTree));
172        }
173        if (terms.Count == 1) return 1.0 / terms[0];
174        else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
175      }
176      if (node.Symbol is Absolute) {
177        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
178        return abs(x1);
179      }
180      if (node.Symbol is AnalyticQuotient) {
181        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
182        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
183        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
184      }
185      if (node.Symbol is Logarithm) {
186        return AutoDiff.TermBuilder.Log(
187          ConvertToAutoDiff(node.GetSubtree(0)));
188      }
189      if (node.Symbol is Exponential) {
190        return AutoDiff.TermBuilder.Exp(
191          ConvertToAutoDiff(node.GetSubtree(0)));
192      }
193      if (node.Symbol is Square) {
194        return AutoDiff.TermBuilder.Power(
195          ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
196      }
197      if (node.Symbol is SquareRoot) {
198        return AutoDiff.TermBuilder.Power(
199          ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
200      }
201      if (node.Symbol is Cube) {
202        return AutoDiff.TermBuilder.Power(
203          ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
204      }
205      if (node.Symbol is CubeRoot) {
206        return AutoDiff.TermBuilder.Power(
207          ConvertToAutoDiff(node.GetSubtree(0)), 1.0 / 3.0);
208      }
209      if (node.Symbol is Sine) {
210        return sin(
211          ConvertToAutoDiff(node.GetSubtree(0)));
212      }
213      if (node.Symbol is Cosine) {
214        return cos(
215          ConvertToAutoDiff(node.GetSubtree(0)));
216      }
217      if (node.Symbol is Tangent) {
218        return tan(
219          ConvertToAutoDiff(node.GetSubtree(0)));
220      }
221      if (node.Symbol is Erf) {
222        return erf(
223          ConvertToAutoDiff(node.GetSubtree(0)));
224      }
225      if (node.Symbol is Norm) {
226        return norm(
227          ConvertToAutoDiff(node.GetSubtree(0)));
228      }
229      if (node.Symbol is StartSymbol) {
230        return ConvertToAutoDiff(node.GetSubtree(0));
231      }
232      throw new ConversionException();
233    }
234
235    #region derivations of functions
236    // create function factory for arctangent
237    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
238      eval: Math.Atan,
239      diff: x => 1 / (1 + x * x));
240
241    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
242      eval: Math.Sin,
243      diff: Math.Cos);
244
245    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
246      eval: Math.Cos,
247      diff: x => -Math.Sin(x));
248
249    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
250      eval: Math.Tan,
251      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
252
253    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
254      eval: alglib.errorfunction,
255      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
256
257    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
258      eval: alglib.normaldistribution,
259      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
260
261    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
262      eval: Math.Abs,
263      diff: x => Math.Sign(x)
264      );
265
266    #endregion
267
268
269    public static bool IsCompatible(ISymbolicExpressionTree tree) {
270      var containsUnknownSymbol = (
271        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
272        where
273          !(n.Symbol is Variable) &&
274          !(n.Symbol is BinaryFactorVariable) &&
275          !(n.Symbol is FactorVariable) &&
276          !(n.Symbol is LaggedVariable) &&
277          !(n.Symbol is Constant) &&
278          !(n.Symbol is Addition) &&
279          !(n.Symbol is Subtraction) &&
280          !(n.Symbol is Multiplication) &&
281          !(n.Symbol is Division) &&
282          !(n.Symbol is Logarithm) &&
283          !(n.Symbol is Exponential) &&
284          !(n.Symbol is SquareRoot) &&
285          !(n.Symbol is Square) &&
286          !(n.Symbol is Sine) &&
287          !(n.Symbol is Cosine) &&
288          !(n.Symbol is Tangent) &&
289          !(n.Symbol is Erf) &&
290          !(n.Symbol is Norm) &&
291          !(n.Symbol is StartSymbol) &&
292          !(n.Symbol is Absolute) &&
293          !(n.Symbol is AnalyticQuotient) &&
294          !(n.Symbol is Cube) &&
295          !(n.Symbol is CubeRoot)
296        select n).Any();
297      return !containsUnknownSymbol;
298    }
299    #region exception class
300    [Serializable]
301    public class ConversionException : Exception {
302      public ConversionException() { }
303      public ConversionException(string message) : base(message) { }
304      public ConversionException(string message, Exception inner) : base(message, inner) { }
305      protected ConversionException(
306        SerializationInfo info,
307        StreamingContext context) : base(info, context) {
308      }
309    }
310    #endregion
311  }
312}
Note: See TracBrowser for help on using the repository browser.