Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorUnrollingTreeToAutoDiffTermConverter.cs

Last change on this file was 18234, checked in by pfleck, 3 years ago

#3040 Fixed vector-unrolling AutoDiff conversion.

File size: 19.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Common;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
31  public class VectorUnrollingTreeToAutoDiffTermConverter {
/// <summary>
/// Evaluates a compiled expression.
/// </summary>
/// <param name="vars">Values for the AutoDiff variables of the expression (forwarded as the first argument of the compiled term's Evaluate call).</param>
/// <param name="params">Values for the AutoDiff parameters of the expression (forwarded as the second argument).</param>
/// <returns>The value of the expression.</returns>
public delegate double ParametricFunction(double[] vars, double[] @params);

/// <summary>
/// Differentiates a compiled expression.
/// </summary>
/// <param name="vars">Values for the AutoDiff variables of the expression.</param>
/// <param name="params">Values for the AutoDiff parameters of the expression.</param>
/// <returns>
/// Whatever the compiled term's Differentiate call returns; presumably (gradient w.r.t. vars, function value) - confirm against the AutoDiff API.
/// </returns>
public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
35
36    #region helper class
/// <summary>
/// Immutable key that identifies one input of the converted expression:
/// a variable name, an optional factor value, a lag and an element index for vector variables.
/// Instances are compared by value so they can be used as dictionary keys.
/// </summary>
public class DataForVariable {
  public readonly string variableName;
  public readonly string variableValue; // for factor vars
  public readonly int lag;
  public readonly int index; // for vectors

  public DataForVariable(string varName, string varValue, int lag, int index) {
    this.variableName = varName;
    this.variableValue = varValue;
    this.lag = lag;
    this.index = index;
  }

  /// <summary>
  /// Value equality over all four fields. Strings are compared ordinally; two null strings are equal.
  /// </summary>
  public override bool Equals(object obj) {
    var other = obj as DataForVariable;
    if (ReferenceEquals(other, null)) return false;
    // static string.Equals is null-safe, the instance call on a null field was not
    return string.Equals(other.variableName, variableName) &&
           string.Equals(other.variableValue, variableValue) &&
           other.lag == lag &&
           other.index == index;
  }

  /// <summary>
  /// Hash code consistent with <see cref="Equals(object)"/>; tolerates null strings.
  /// </summary>
  public override int GetHashCode() {
    unchecked {
      // order-sensitive combination so that lag and index cannot cancel each other out
      int hash = 17;
      hash = hash * 31 + (variableName != null ? variableName.GetHashCode() : 0);
      hash = hash * 31 + (variableValue != null ? variableValue.GetHashCode() : 0);
      hash = hash * 31 + lag;
      hash = hash * 31 + index;
      return hash;
    }
  }
}
63    #endregion
64
#region derivations of functions
// Each factory pairs a scalar function (eval) with its first derivative (diff)
// so that AutoDiff can differentiate functions it has no built-in term for.

// arctangent: d/dx atan(x) = 1 / (1 + x^2)
private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
  eval: Math.Atan,
  diff: x => 1 / (1 + x * x));

// sine: d/dx sin(x) = cos(x)
private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
  eval: Math.Sin,
  diff: Math.Cos);

// cosine: d/dx cos(x) = -sin(x)
private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
  eval: Math.Cos,
  diff: x => -Math.Sin(x));

// tangent: d/dx tan(x) = 1 + tan(x)^2
private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
  eval: Math.Tan,
  diff: x => 1 + Math.Tan(x) * Math.Tan(x));

// hyperbolic tangent: d/dx tanh(x) = 1 - tanh(x)^2
private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
  eval: Math.Tanh,
  diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));

// error function: d/dx erf(x) = 2 / sqrt(pi) * exp(-x^2)
private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
  eval: alglib.errorfunction,
  diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));

// standard normal CDF (alglib.normaldistribution integrates the Gaussian density from -inf to x),
// therefore its derivative is the density itself: exp(-x^2 / 2) / sqrt(2 pi).
// The previous expression, -x * exp(-x^2 / 2) / sqrt(2 pi), was the derivative of the density
// and additionally produced NaN (0 * infinity) for large |x|.
private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
  eval: alglib.normaldistribution,
  diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));

// absolute value: the sign function is used as (sub-)derivative, i.e. 0 at x = 0
private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
  eval: Math.Abs,
  diff: x => Math.Sign(x)
  );

// real cube root (also defined for negative x): d/dx cbrt(x) = 1 / (3 * cbrt(x)^2)
private static readonly Func<Term, UnaryFunc> cbrt = UnaryFunc.Factory(
  eval: x => x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
  diff: x => { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
  );

#endregion
106
/// <summary>
/// Converts the expression below the root of <paramref name="tree"/> into a compiled AutoDiff term
/// and exposes it as an evaluation delegate and a gradient delegate.
/// </summary>
/// <param name="tree">Tree to convert; conversion starts at the first subtree of its root.</param>
/// <param name="evaluationTrace">Per-node evaluation results that are handed to the converter instance.</param>
/// <param name="makeVariableWeightsVariable">If true, variable weights become AutoDiff variables instead of fixed factors.</param>
/// <param name="addLinearScalingTerms">If true, the result is wrapped as result * alpha + beta.</param>
/// <param name="parameters">Keys of the AutoDiff parameters, in the same order as the parameter array used for compilation; null on failure.</param>
/// <param name="initialConstants">Initial values collected during conversion; null on failure.</param>
/// <param name="func">Delegate that evaluates the compiled term; null on failure.</param>
/// <param name="func_grad">Delegate that differentiates the compiled term; null on failure.</param>
/// <returns>True on success; false if a <see cref="ConversionException"/> was raised during conversion.</returns>
public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree,
  IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
  out List<DataForVariable> parameters, out double[] initialConstants,
  out ParametricFunction func,
  out ParametricFunctionGradient func_grad) {

  // the converter instance carries the state (variables, parameters, constants) of the recursive conversion
  var converter = new VectorUnrollingTreeToAutoDiffTermConverter(
    evaluationTrace, makeVariableWeightsVariable, addLinearScalingTerms);

  try {
    // the overall result has to consist of exactly one term
    Term rootTerm = converter.ConvertToAutoDiff(tree.Root.GetSubtree(0)).Single();

    // snapshot the dictionary once so that keys and values are guaranteed to line up
    var entries = converter.parameters.ToArray();
    var variableArray = converter.variables.ToArray();
    var parameterArray = entries.Select(entry => entry.Value).ToArray();
    var compiled = rootTerm.Compile(variableArray, parameterArray);

    parameters = entries.Select(entry => entry.Key).ToList();
    initialConstants = converter.initialConstants.ToArray();
    func = (vars, @params) => compiled.Evaluate(vars, @params);
    func_grad = (vars, @params) => compiled.Differentiate(vars, @params);
    return true;
  } catch (ConversionException) {
    parameters = null;
    initialConstants = null;
    func = null;
    func_grad = null;
    return false;
  }
}
136
// per-node evaluation results supplied by the caller; consulted while converting variable nodes
private readonly IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace;
private readonly bool makeVariableWeightsVariable;
private readonly bool addLinearScalingTerms;

// state for recursive transformation of trees
private readonly List<double> initialConstants = new List<double>();
private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
private readonly List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();

/// <summary>
/// Creates a converter with empty conversion state. Private: instances are only created by the static entry point.
/// </summary>
private VectorUnrollingTreeToAutoDiffTermConverter(
  IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
  bool makeVariableWeightsVariable,
  bool addLinearScalingTerms) {
  this.evaluationTrace = evaluationTrace;
  this.makeVariableWeightsVariable = makeVariableWeightsVariable;
  this.addLinearScalingTerms = addLinearScalingTerms;
}
154
/// <summary>
/// Brings all lists in <paramref name="source"/> to a common length: lists that already have the maximum length
/// are passed through, single-element lists are repeated up to the maximum length.
/// </summary>
/// <exception cref="InvalidOperationException">A list has a length that is neither the maximum length nor one.</exception>
private static IEnumerable<IEnumerable<T>> Broadcast<T>(IList<T>[] source) {
  int targetLength = source.Max(list => list.Count);

  // validation happens eagerly, the expansion below is deferred
  foreach (var list in source) {
    bool fits = list.Count == targetLength || list.Count == 1;
    if (!fits)
      throw new InvalidOperationException("Length must match to maxLength or one");
  }

  return Expand();

  IEnumerable<IEnumerable<T>> Expand() {
    foreach (var list in source) {
      if (list.Count == targetLength)
        yield return list;
      else
        yield return Enumerable.Repeat(list[0], targetLength);
    }
  }
}
/// <summary>
/// Lazily transposes a sequence of sequences: the i-th result contains the i-th element of every inner sequence.
/// Enumeration stops as soon as any inner sequence is exhausted. An empty outer sequence yields an empty result.
/// </summary>
/// <exception cref="ArgumentNullException"><paramref name="source"/> is null (raised when enumeration starts).</exception>
public static IEnumerable<IEnumerable<T>> Transpose<T>(IEnumerable<IEnumerable<T>> source) {
  if (source == null) throw new ArgumentNullException(nameof(source));

  var enumerators = new List<IEnumerator<T>>();
  try {
    // collected inside the try block so that already opened enumerators are disposed if a later GetEnumerator throws
    foreach (var inner in source)
      enumerators.Add(inner.GetEnumerator());

    // All(...) over zero enumerators is always true; without this guard the loop below would never end
    if (enumerators.Count == 0) yield break;

    while (enumerators.All(x => x.MoveNext())) {
      yield return enumerators.Select(x => x.Current).ToArray();
    }
  } finally {
    foreach (var enumerator in enumerators)
      enumerator.Dispose();
  }
}
172
/// <summary>
/// Recursively converts <paramref name="node"/> into AutoDiff terms. Vector-valued sub-expressions are unrolled:
/// the returned list holds one term per vector element, a scalar is a list with a single term.
/// Constants, variable weights (optional) and factor weights are registered as AutoDiff variables together with
/// their initial values; dataset variables are registered as AutoDiff parameters.
/// </summary>
/// <param name="node">Root of the sub-expression to convert.</param>
/// <returns>The unrolled terms for <paramref name="node"/>.</returns>
/// <exception cref="ConversionException">The node's symbol is not supported.</exception>
/// <exception cref="InvalidOperationException">
/// Operand lengths cannot be broadcast, a single-argument symbol does not have exactly one subtree,
/// or the expression below the start symbol is not scalar although linear scaling was requested.
/// </exception>
private IList<AutoDiff.Term> ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
  // element-wise left fold over all operands (after broadcasting); a single operand is mapped with singleElementOp
  IList<Term> BinaryOp(Func<Term, Term, Term> binaryOp, Func<Term, Term> singleElementOp, params IList<Term>[] terms) {
    if (terms.Length == 1) return terms[0].Select(singleElementOp).ToList();
    var broadcastedTerms = Broadcast(terms);
    var transposedTerms = Transpose(broadcastedTerms);
    return transposedTerms.Select(term => term.Aggregate(binaryOp)).ToList();
  }
  // element-wise application of a function
  IList<Term> UnaryOp(Func<Term, Term> unaryOp, IList<Term> term) {
    return term.Select(unaryOp).ToList();
  }
  // converts the one and only subtree of node (Single() throws if there is not exactly one)
  IList<Term> ConvertOnlySubtree() {
    return node.Subtrees.Select(ConvertToAutoDiff).Single();
  }
  // converts all subtrees of node in order
  IList<Term>[] ConvertAllSubtrees() {
    return node.Subtrees.Select(ConvertToAutoDiff).ToArray();
  }

  if (node.Symbol is Constant) { // assume scalar constant
    initialConstants.Add(((ConstantTreeNode)node).Value);
    var constantVariable = new AutoDiff.Variable();
    variables.Add(constantVariable);
    return new Term[] { constantVariable };
  }
  if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
    var varNode = (VariableTreeNodeBase)node;
    var factorVarNode = node as BinaryFactorVariableTreeNode;
    // factor variable values are only 0 or 1 and set in x accordingly
    var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;

    // The trace is only needed here (to decide scalar vs. vector and to get the vector length),
    // so it is looked up here instead of for every node; nodes that are missing in the trace
    // no longer make the conversion of non-variable symbols fail.
    var evaluationResult = evaluationTrace[node];
    var pars = evaluationResult.IsVector
      ? Enumerable.Range(0, evaluationResult.Vector.Count).Select(i => FindOrCreateParameter(parameters, varNode.VariableName, varValue, index: i))
      : FindOrCreateParameter(parameters, varNode.VariableName, varValue).ToEnumerable();

    if (makeVariableWeightsVariable) {
      // one shared weight variable for all elements of the variable
      initialConstants.Add(varNode.Weight);
      var w = new AutoDiff.Variable();
      variables.Add(w);
      return pars.Select(par => AutoDiff.TermBuilder.Product(w, par)).ToList();
    } else {
      return pars.Select(par => varNode.Weight * par).ToList();
    }
  }
  if (node.Symbol is FactorVariable) {
    var factorVarNode = (FactorVariableTreeNode)node;
    var products = new List<Term>();
    foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
      var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);

      initialConstants.Add(factorVarNode.GetValue(variableValue));
      var wVar = new AutoDiff.Variable();
      variables.Add(wVar);

      products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    }
    return new[] { AutoDiff.TermBuilder.Sum(products) };
  }
  // not supported (yet): LaggedVariable

  if (node.Symbol is Addition) {
    return BinaryOp((a, b) => a + b, a => a, ConvertAllSubtrees());
  }
  if (node.Symbol is Subtraction) {
    return BinaryOp((a, b) => a - b, a => -a, ConvertAllSubtrees());
  }
  if (node.Symbol is Multiplication) {
    return BinaryOp((a, b) => a * b, a => a, ConvertAllSubtrees());
  }
  if (node.Symbol is Division) {
    return BinaryOp((a, b) => a / b, a => 1.0 / a, ConvertAllSubtrees());
  }
  if (node.Symbol is Absolute) {
    return UnaryOp(abs, ConvertOnlySubtree());
  }
  // not supported (yet): AnalyticQuotient

  if (node.Symbol is Logarithm) {
    return UnaryOp(TermBuilder.Log, ConvertOnlySubtree());
  }
  if (node.Symbol is Exponential) {
    return UnaryOp(TermBuilder.Exp, ConvertOnlySubtree());
  }
  if (node.Symbol is Square) {
    return UnaryOp(t => TermBuilder.Power(t, 2.0), ConvertOnlySubtree());
  }
  if (node.Symbol is SquareRoot) {
    return UnaryOp(t => TermBuilder.Power(t, 0.5), ConvertOnlySubtree());
  }
  if (node.Symbol is Cube) {
    return UnaryOp(t => TermBuilder.Power(t, 3.0), ConvertOnlySubtree());
  }
  if (node.Symbol is CubeRoot) {
    return UnaryOp(cbrt, ConvertOnlySubtree());
  }
  if (node.Symbol is Sine) {
    return UnaryOp(sin, ConvertOnlySubtree());
  }
  if (node.Symbol is Cosine) {
    return UnaryOp(cos, ConvertOnlySubtree());
  }
  if (node.Symbol is Tangent) {
    return UnaryOp(tan, ConvertOnlySubtree());
  }
  if (node.Symbol is HyperbolicTangent) {
    return UnaryOp(tanh, ConvertOnlySubtree());
  }
  if (node.Symbol is Erf) {
    return UnaryOp(erf, ConvertOnlySubtree());
  }
  if (node.Symbol is Norm) {
    return UnaryOp(norm, ConvertOnlySubtree());
  }
  if (node.Symbol is StartSymbol) {
    if (!addLinearScalingTerms) return ConvertToAutoDiff(node.GetSubtree(0));

    // scaling variables α, β are given at the beginning of the parameter vector,
    // therefore they are registered before the subtree is converted
    var alpha = new AutoDiff.Variable();
    var beta = new AutoDiff.Variable();
    variables.Add(beta);
    variables.Add(alpha);
    var t = ConvertToAutoDiff(node.GetSubtree(0));
    // exactly one term is required; an empty result is rejected as well instead of failing on t[0]
    if (t.Count != 1) throw new InvalidOperationException("Tree Result must be scalar value");
    return new[] { t[0] * alpha + beta };
  }

  // aggregations: reduce the unrolled vector to a single term
  if (node.Symbol is Sum) {
    var term = ConvertOnlySubtree();
    return new[] { TermBuilder.Sum(term) };
  }
  if (node.Symbol is Mean) {
    var term = ConvertOnlySubtree();
    return new[] { TermBuilder.Sum(term) / term.Count };
  }
  if (node.Symbol is StandardDeviation) {
    // square root of the mean squared deviation (divides by n, not n - 1)
    var term = ConvertOnlySubtree();
    var mean = TermBuilder.Sum(term) / term.Count;
    var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
    return new[] { TermBuilder.Power(ssd / term.Count, 0.5) };
  }
  if (node.Symbol is Length) {
    var term = ConvertOnlySubtree();
    return new[] { TermBuilder.Constant(term.Count) };
  }
  // not supported (yet): Min, Max

  if (node.Symbol is Variance) {
    // mean squared deviation (divides by n, not n - 1)
    var term = ConvertOnlySubtree();
    var mean = TermBuilder.Sum(term) / term.Count;
    var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
    return new[] { ssd / term.Count };
  }
  // not supported (yet): Skewness, Kurtosis, EuclideanDistance, Covariance

  if (node.Symbol is SubVector) {
    var term = ConvertOnlySubtree();
    var windowedNode = (IWindowedSymbolTreeNode)node;
    // index computation is delegated to the interpreter's helpers
    int startIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Offset, term.Count);
    int endIdx = SymbolicDataAnalysisExpressionTreeVectorInterpreter.ToVectorIdx(windowedNode.Length, term.Count);
    var slices = SymbolicDataAnalysisExpressionTreeVectorInterpreter.GetVectorSlices(startIdx, endIdx, term.Count);

    // copy the selected elements slice by slice; each slice is a (start, count) pair
    var selectedTerms = new List<Term>(capacity: slices.Sum(s => s.Item2));
    foreach (var (start, count) in slices) {
      for (int i = start; i < start + count; i++) {
        selectedTerms.Add(term[i]);
      }
    }
    return selectedTerms;
  }

  throw new ConversionException();
}
377
378
/// <summary>
/// Returns the AutoDiff parameter registered for the given variable / value / lag / index combination
/// and registers a new one if the combination has not been seen before.
/// For factor variables each (variable, value) combination acts as a binary indicator that is needed only once.
/// </summary>
/// <param name="parameters">Registry of already created parameters; a new entry is added on a miss.</param>
/// <param name="varName">Name of the variable.</param>
/// <param name="varValue">Factor value, empty for non-factor variables.</param>
/// <param name="lag">Lag of the variable.</param>
/// <param name="index">Element index for vector variables; the default is -1.</param>
/// <returns>The existing or newly created parameter.</returns>
private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
  string varName, string varValue = "", int lag = 0, int index = -1) {
  var key = new DataForVariable(varName, varValue, lag, index);

  if (parameters.TryGetValue(key, out var existing))
    return existing;

  // not found -> create and remember a new parameter for this combination
  var created = new AutoDiff.Variable();
  parameters.Add(key, created);
  return created;
}
393
/// <summary>
/// Checks whether every node below the root's first subtree (visited in prefix order) uses a symbol from the accepted list.
/// </summary>
/// <param name="tree">Tree to inspect.</param>
/// <returns>True if all symbols are accepted, false as soon as one is not.</returns>
public static bool IsCompatible(ISymbolicExpressionTree tree) {
  // Not accepted here: FactorVariable, LaggedVariable, AnalyticQuotient, Min, Max,
  // Skewness, Kurtosis, EuclideanDistance, Covariance
  return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n => {
    var symbol = n.Symbol;
    return symbol is Variable
        || symbol is BinaryFactorVariable
        || symbol is Constant
        || symbol is Addition
        || symbol is Subtraction
        || symbol is Multiplication
        || symbol is Division
        || symbol is Logarithm
        || symbol is Exponential
        || symbol is SquareRoot
        || symbol is Square
        || symbol is Sine
        || symbol is Cosine
        || symbol is Tangent
        || symbol is HyperbolicTangent
        || symbol is Erf
        || symbol is Norm
        || symbol is StartSymbol
        || symbol is Absolute
        || symbol is Cube
        || symbol is CubeRoot
        || symbol is Sum
        || symbol is Mean
        || symbol is StandardDeviation
        || symbol is Length
        || symbol is Variance
        || symbol is SubVector;
  });
}
437    #region exception class
/// <summary>
/// Signals that a tree could not be converted into an AutoDiff term.
/// </summary>
[Serializable]
public class ConversionException : Exception {
  /// <summary>Creates the exception without a message.</summary>
  public ConversionException()
    : base() { }

  /// <summary>Creates the exception with a message.</summary>
  public ConversionException(string message)
    : base(message) { }

  /// <summary>Creates the exception with a message and the exception that caused it.</summary>
  public ConversionException(string message, Exception inner)
    : base(message, inner) { }

  /// <summary>Deserialization constructor.</summary>
  protected ConversionException(SerializationInfo info, StreamingContext context)
    : base(info, context) { }
}
455    #endregion
456  }
457}
Note: See TracBrowser for help on using the repository browser.