Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/ConstantsOptimization/AutoDiffConverter.cs @ 16796

Last change on this file since 16796 was 16507, checked in by mkommend, 6 years ago

#2974: First stable version of new CoOp.

File size: 12.3 KB
RevLine 
[14843]1#region License Information
2/* HeuristicLab
[15583]3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[14843]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[14950]25using System.Runtime.Serialization;
[14843]26using AutoDiff;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.ConstantsOptimization {
  /// <summary>
  /// Converts symbolic expression trees into compiled, parametric AutoDiff terms.
  /// Numeric coefficients of selected nodes become AutoDiff variables (differentiable),
  /// input variables become AutoDiff parameters.
  /// </summary>
  public class AutoDiffConverter {

    /// <summary>
    /// Converts a symbolic expression tree into a parametric AutoDiff term.
    /// </summary>
    /// <param name="tree">The tree that should be converted.</param>
    /// <param name="addLinearScalingTerms">A flag that determines whether linear scaling terms (term * α + β) should be added to the parametric term.</param>
    /// <param name="numericNodes">The nodes that contain numeric coefficients that should be added as variables in the term.</param>
    /// <param name="variableData">The variable information that is used to create parameters in the term.
    /// The parameters of the compiled term appear in the enumeration order of this sequence.</param>
    /// <param name="autoDiffTerm">The resulting parametric AutoDiff term, or <c>null</c> if the conversion failed.</param>
    /// <returns><c>true</c> if the conversion succeeded; <c>false</c> if the tree contains a symbol that cannot be converted.</returns>
    /// <remarks>
    /// The variables of the compiled term are the coefficients of <paramref name="numericNodes"/> in the order in which
    /// the nodes are visited during conversion (prefix order), followed by α and β if linear scaling terms are added.
    /// </remarks>
    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool addLinearScalingTerms,
      IEnumerable<ISymbolicExpressionTreeNode> numericNodes, IEnumerable<VariableData> variableData,
      out IParametricCompiledTerm autoDiffTerm) {
      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
      var transformator = new AutoDiffConverter(numericNodes, variableData);

      try {
        AutoDiff.Term term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
        if (addLinearScalingTerms) {
          // scaling variables α, β are given at the end of the variable vector (after all tree coefficients)
          var alpha = new AutoDiff.Variable();
          var beta = new AutoDiff.Variable();

          term = term * alpha + beta;

          transformator.variables.Add(alpha);
          transformator.variables.Add(beta);
        }
        // parameters are passed in the order of variableData; Dictionary.Values has no guaranteed order
        autoDiffTerm = term.Compile(transformator.variables.ToArray(), transformator.orderedParameters);
        return true;
      } catch (ConversionException) {
        autoDiffTerm = null;
        return false;
      }
    }

    // state for recursive transformation of trees
    // nodes whose numeric coefficients are turned into AutoDiff variables
    private readonly HashSet<ISymbolicExpressionTreeNode> nodesForOptimization;
    // lookup from variable information to the AutoDiff parameter that represents it
    private readonly Dictionary<VariableData, AutoDiff.Variable> parameters;
    // the same parameters, in the enumeration order of the variableData passed to the constructor
    private readonly AutoDiff.Variable[] orderedParameters;
    // AutoDiff variables for optimized coefficients, in the order of their creation
    private readonly List<AutoDiff.Variable> variables;

    private AutoDiffConverter(IEnumerable<ISymbolicExpressionTreeNode> nodesForOptimization, IEnumerable<VariableData> variableData) {
      this.nodesForOptimization = new HashSet<ISymbolicExpressionTreeNode>(nodesForOptimization);
      // materialize once so that the lookup and the ordered array are built from the same sequence
      var data = variableData.ToArray();
      this.parameters = data.ToDictionary(k => k, v => new AutoDiff.Variable());
      this.orderedParameters = data.Select(d => this.parameters[d]).ToArray();
      this.variables = new List<AutoDiff.Variable>();
    }

    // creates a new AutoDiff variable for an optimized coefficient and appends it to the variable vector
    private AutoDiff.Variable CreateOptimizationVariable() {
      var variable = new AutoDiff.Variable();
      variables.Add(variable);
      return variable;
    }

    // returns weight * parameter; the weight is a fresh optimization variable if the node is optimized,
    // otherwise the fixed numeric weight of the node
    private AutoDiff.Term WeightedParameter(ISymbolicExpressionTreeNode node, double weight, AutoDiff.Variable parameter) {
      if (nodesForOptimization.Contains(node)) {
        return AutoDiff.TermBuilder.Product(CreateOptimizationVariable(), parameter);
      }
      return AutoDiff.TermBuilder.Product(weight, parameter);
    }

    // converts all subtrees of the node, preserving their order
    private List<AutoDiff.Term> ConvertSubtrees(ISymbolicExpressionTreeNode node) {
      return node.Subtrees.Select(subtree => ConvertToAutoDiff(subtree)).ToList();
    }

    // recursively converts a node into an AutoDiff term; throws ConversionException for unsupported symbols
    private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
      if (node.Symbol is Constant) {
        // optimized constants become free variables, all others stay fixed numeric terms
        if (nodesForOptimization.Contains(node)) {
          return CreateOptimizationVariable();
        }
        var constantNode = node as ConstantTreeNode;
        return constantNode.Value;
      }
      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
        var varNode = node as VariableTreeNodeBase;
        var factorVarNode = node as BinaryFactorVariableTreeNode;
        // factor variable values are only 0 or 1 and set in x accordingly
        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
        var par = parameters[new VariableData(varNode.VariableName, varValue, 0)];
        return WeightedParameter(node, varNode.Weight, par);
      }
      if (node.Symbol is FactorVariable) {
        var factorVarNode = node as FactorVariableTreeNode;
        // one weighted parameter per value of the factor variable; each weight gets its own optimization variable
        var products = new List<AutoDiff.Term>();
        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
          var par = parameters[new VariableData(factorVarNode.VariableName, variableValue, 0)];
          products.Add(WeightedParameter(node, factorVarNode.GetValue(variableValue), par));
        }
        return AutoDiff.TermBuilder.Sum(products);
      }
      if (node.Symbol is LaggedVariable) {
        var varNode = node as LaggedVariableTreeNode;
        var par = parameters[new VariableData(varNode.VariableName, string.Empty, varNode.Lag)];
        return WeightedParameter(node, varNode.Weight, par);
      }
      if (node.Symbol is Addition) {
        return AutoDiff.TermBuilder.Sum(ConvertSubtrees(node));
      }
      if (node.Symbol is Subtraction) {
        var terms = ConvertSubtrees(node);
        // a single argument is negated; otherwise x0 - x1 - ... - xn is built as x0 + (-x1) + ... + (-xn)
        if (terms.Count == 1) return -terms[0];
        for (int i = 1; i < terms.Count; i++) {
          terms[i] = -terms[i];
        }
        return AutoDiff.TermBuilder.Sum(terms);
      }
      if (node.Symbol is Multiplication) {
        var terms = ConvertSubtrees(node);
        if (terms.Count == 1) return terms[0];
        return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
      }
      if (node.Symbol is Division) {
        var terms = ConvertSubtrees(node);
        // a single argument is inverted; otherwise ((x0 / x1) / x2) / ...
        if (terms.Count == 1) return 1.0 / terms[0];
        return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
      }
      if (node.Symbol is Absolute) {
        return abs(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is AnalyticQuotient) {
        // AQ(x1, x2) = x1 / sqrt(1 + x2²)
        var x1 = ConvertToAutoDiff(node.GetSubtree(0));
        var x2 = ConvertToAutoDiff(node.GetSubtree(1));
        return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
      }
      if (node.Symbol is Logarithm) {
        return AutoDiff.TermBuilder.Log(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Exponential) {
        return AutoDiff.TermBuilder.Exp(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Square) {
        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
      }
      if (node.Symbol is SquareRoot) {
        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
      }
      if (node.Symbol is Cube) {
        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), 3.0);
      }
      if (node.Symbol is CubeRoot) {
        return AutoDiff.TermBuilder.Power(ConvertToAutoDiff(node.GetSubtree(0)), 1.0 / 3.0);
      }
      if (node.Symbol is Sine) {
        return sin(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Cosine) {
        return cos(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Tangent) {
        return tan(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Erf) {
        return erf(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is Norm) {
        return norm(ConvertToAutoDiff(node.GetSubtree(0)));
      }
      if (node.Symbol is StartSymbol) {
        return ConvertToAutoDiff(node.GetSubtree(0));
      }
      throw new ConversionException();
    }

    #region derivations of functions
    // Each factory pairs a function value (eval) with its first derivative (diff) so that
    // AutoDiff can evaluate and differentiate functions it does not provide natively.

    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);

    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));

    // d/dx tan(x) = 1 + tan²(x)
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));

    // d/dx erf(x) = 2 / sqrt(π) * exp(-x²)
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));

    // eval is the standard normal cumulative distribution function Φ(x), so the derivative is
    // the standard normal density φ(x) = exp(-x²/2) / sqrt(2π)
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2.0 * Math.PI));

    // derivative of |x| is sign(x); Math.Sign(0) == 0 is used at the kink
    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
      eval: Math.Abs,
      diff: x => Math.Sign(x));

    #endregion

    /// <summary>
    /// Checks whether every symbol in the tree (below the root) can be converted by <see cref="TryConvertToAutoDiff"/>.
    /// </summary>
    /// <param name="tree">The tree to check.</param>
    /// <returns><c>true</c> if the tree contains only supported symbols; otherwise <c>false</c>.</returns>
    public static bool IsCompatible(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n => {
        var s = n.Symbol;
        return s is Variable ||
               s is BinaryFactorVariable ||
               s is FactorVariable ||
               s is LaggedVariable ||
               s is Constant ||
               s is Addition ||
               s is Subtraction ||
               s is Multiplication ||
               s is Division ||
               s is Logarithm ||
               s is Exponential ||
               s is SquareRoot ||
               s is Square ||
               s is Sine ||
               s is Cosine ||
               s is Tangent ||
               s is Erf ||
               s is Norm ||
               s is StartSymbol ||
               s is Absolute ||
               s is AnalyticQuotient ||
               s is Cube ||
               s is CubeRoot;
      });
    }

    #region exception class
    /// <summary>
    /// Thrown during conversion when a tree contains a symbol that cannot be translated into an AutoDiff term.
    /// </summary>
    [Serializable]
    public class ConversionException : Exception {
      public ConversionException() { }
      public ConversionException(string message) : base(message) { }
      public ConversionException(string message, Exception inner) : base(message, inner) { }
      protected ConversionException(
        SerializationInfo info,
        StreamingContext context) : base(info, context) {
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.