Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorUnrollingTreeToAutoDiffTermConverter.cs @ 17837

Last change on this file since 17837 was 17726, checked in by pfleck, 4 years ago

#3040 Added a constant opt evaluator for vectors that uses the existing AutoDiff library by unrolling all vector operations.

File size: 17.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using AutoDiff;
27using HeuristicLab.Common;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
31  public class VectorUnrollingTreeToAutoDiffTermConverter {
32    public delegate double ParametricFunction(double[] vars, double[] @params);
33
34    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
35
36    #region helper class
37    public class DataForVariable {
38      public readonly string variableName;
39      public readonly string variableValue; // for factor vars
40      public readonly int lag;
41      public readonly int index; // for vectors
42
43      public DataForVariable(string varName, string varValue, int lag, int index) {
44        this.variableName = varName;
45        this.variableValue = varValue;
46        this.lag = lag;
47        this.index = index;
48      }
49
50      public override bool Equals(object obj) {
51        var other = obj as DataForVariable;
52        if (other == null) return false;
53        return other.variableName.Equals(this.variableName) &&
54               other.variableValue.Equals(this.variableValue) &&
55               other.lag == this.lag &&
56               other.index == this.index;
57      }
58
59      public override int GetHashCode() {
60        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag ^ index;
61      }
62    }
63    #endregion
64
65    #region derivations of functions
66    // create function factory for arctangent
67    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
68      eval: Math.Atan,
69      diff: x => 1 / (1 + x * x));
70
71    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
72      eval: Math.Sin,
73      diff: Math.Cos);
74
75    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
76      eval: Math.Cos,
77      diff: x => -Math.Sin(x));
78
79    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
80      eval: Math.Tan,
81      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
82    private static readonly Func<Term, UnaryFunc> tanh = UnaryFunc.Factory(
83      eval: Math.Tanh,
84      diff: x => 1 - Math.Tanh(x) * Math.Tanh(x));
85    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
86      eval: alglib.errorfunction,
87      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
88
89    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
90      eval: alglib.normaldistribution,
91      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
92
93    private static readonly Func<Term, UnaryFunc> abs = UnaryFunc.Factory(
94      eval: Math.Abs,
95      diff: x => Math.Sign(x)
96      );
97
98    private static readonly Func<Term, UnaryFunc> cbrt = UnaryFunc.Factory(
99      eval: x => x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3),
100      diff: x => { var cbrt_x = x < 0 ? -Math.Pow(-x, 1.0 / 3) : Math.Pow(x, 1.0 / 3); return 1.0 / (3 * cbrt_x * cbrt_x); }
101      );
102
103
104
105    #endregion
106
107    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree,
108      IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
109      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
110      out List<DataForVariable> parameters, out double[] initialConstants,
111      out ParametricFunction func,
112      out ParametricFunctionGradient func_grad) {
113
114      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
115      var transformator = new VectorUnrollingTreeToAutoDiffTermConverter(evaluationTrace,
116        makeVariableWeightsVariable, addLinearScalingTerms);
117      Term term;
118      try {
119        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)).Single();
120        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
121        var compiledTerm = term.Compile(transformator.variables.ToArray(),
122          parameterEntries.Select(kvp => kvp.Value).ToArray());
123        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
124        initialConstants = transformator.initialConstants.ToArray();
125        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
126        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
127        return true;
128      } catch (ConversionException) {
129        func = null;
130        func_grad = null;
131        parameters = null;
132        initialConstants = null;
133      }
134      return false;
135    }
136
137    private readonly IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace;
138    // state for recursive transformation of trees
139    private readonly List<double> initialConstants;
140    private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
141    private readonly List<AutoDiff.Variable> variables;
142    private readonly bool makeVariableWeightsVariable;
143    private readonly bool addLinearScalingTerms;
144
145    private VectorUnrollingTreeToAutoDiffTermConverter(IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace,
146      bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
147      this.evaluationTrace = evaluationTrace;
148      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
149      this.addLinearScalingTerms = addLinearScalingTerms;
150      this.initialConstants = new List<double>();
151      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
152      this.variables = new List<AutoDiff.Variable>();
153    }
154
155    private IList<AutoDiff.Term> ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
156      IList<Term> BinaryOp(Func<Term, Term, Term> binaryOp, Func<Term, Term> singleElementOp, params IList<Term>[] terms) {
157        if (terms.Length == 1) return terms[0].Select(singleElementOp).ToList();
158        return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList());
159      }
160      IList<Term> BinaryOp2(Func<Term, Term, Term> binaryOp, params IList<Term>[] terms) {
161        return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList());
162      }
163      IList<Term> UnaryOp(Func<Term, Term> unaryOp, IList<Term> term) {
164        return term.Select(unaryOp).ToList();
165      }
166
167      var evaluationResult = evaluationTrace[node];
168
169      if (node.Symbol is Constant) { // assume scalar constant
170        initialConstants.Add(((ConstantTreeNode)node).Value);
171        var var = new AutoDiff.Variable();
172        variables.Add(var);
173        return new Term[] { var };
174      }
175      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
176        var varNode = node as VariableTreeNodeBase;
177        var factorVarNode = node as BinaryFactorVariableTreeNode;
178        // factor variable values are only 0 or 1 and set in x accordingly
179        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
180        var pars = evaluationResult.IsVector
181          ? Enumerable.Range(0, evaluationResult.Vector.Count).Select(i => FindOrCreateParameter(parameters, varNode.VariableName, varValue, index: i))
182          : FindOrCreateParameter(parameters, varNode.VariableName, varValue).ToEnumerable();
183
184        if (makeVariableWeightsVariable) {
185          initialConstants.Add(varNode.Weight);
186          var w = new AutoDiff.Variable();
187          variables.Add(w);
188          return pars.Select(par => AutoDiff.TermBuilder.Product(w, par)).ToList();
189        } else {
190          return pars.Select(par => varNode.Weight * par).ToList();
191        }
192      }
193      if (node.Symbol is FactorVariable) {
194        var factorVarNode = node as FactorVariableTreeNode;
195        var products = new List<Term>();
196        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
197          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
198
199          initialConstants.Add(factorVarNode.GetValue(variableValue));
200          var wVar = new AutoDiff.Variable();
201          variables.Add(wVar);
202
203          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
204        }
205        return new[] { AutoDiff.TermBuilder.Sum(products) };
206      }
207      //if (node.Symbol is LaggedVariable) {
208      //  var varNode = node as LaggedVariableTreeNode;
209      //  var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
210
211      //  if (makeVariableWeightsVariable) {
212      //    initialConstants.Add(varNode.Weight);
213      //    var w = new AutoDiff.Variable();
214      //    variables.Add(w);
215      //    return AutoDiff.TermBuilder.Product(w, par);
216      //  } else {
217      //    return varNode.Weight * par;
218      //  }
219      //}
220      if (node.Symbol is Addition) {
221        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
222        return BinaryOp((a, b) => a + b, a => a, terms);
223      }
224      if (node.Symbol is Subtraction) {
225        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
226        return BinaryOp((a, b) => a - b, a => -a, terms);
227      }
228      if (node.Symbol is Multiplication) {
229        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
230        return BinaryOp((a, b) => a * b, a => a, terms);
231      }
232      if (node.Symbol is Division) {
233        var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray();
234        return BinaryOp((a, b) => a / b, a => 1.0 / a, terms);
235      }
236      if (node.Symbol is Absolute) {
237        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
238        return UnaryOp(abs, term);
239      }
240      //if (node.Symbol is AnalyticQuotient) {
241      //  var x1 = ConvertToAutoDiff(node.GetSubtree(0));
242      //  var x2 = ConvertToAutoDiff(node.GetSubtree(1));
243      //  return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));
244      //}
245      if (node.Symbol is Logarithm) {
246        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
247        return UnaryOp(TermBuilder.Log, term);
248      }
249      if (node.Symbol is Exponential) {
250        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
251        return UnaryOp(TermBuilder.Exp, term);
252      }
253      if (node.Symbol is Square) {
254        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
255        return UnaryOp(t => TermBuilder.Power(t, 2.0), term);
256      }
257      if (node.Symbol is SquareRoot) {
258        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
259        return UnaryOp(t => TermBuilder.Power(t, 0.5), term);
260      }
261      if (node.Symbol is Cube) {
262        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
263        return UnaryOp(t => TermBuilder.Power(t, 3.0), term);
264      }
265      if (node.Symbol is CubeRoot) {
266        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
267        return UnaryOp(cbrt, term);
268      }
269      if (node.Symbol is Sine) {
270        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
271        return UnaryOp(sin, term);
272      }
273      if (node.Symbol is Cosine) {
274        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
275        return UnaryOp(cos, term);
276      }
277      if (node.Symbol is Tangent) {
278        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
279        return UnaryOp(tan, term);
280      }
281      if (node.Symbol is HyperbolicTangent) {
282        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
283        return UnaryOp(tanh, term);
284      }
285      if (node.Symbol is Erf) {
286        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
287        return UnaryOp(erf, term);
288      }
289      if (node.Symbol is Norm) {
290        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
291        return UnaryOp(norm, term);
292      }
293      if (node.Symbol is StartSymbol) {
294        if (addLinearScalingTerms) {
295          // scaling variables α, β are given at the beginning of the parameter vector
296          var alpha = new AutoDiff.Variable();
297          var beta = new AutoDiff.Variable();
298          variables.Add(beta);
299          variables.Add(alpha);
300          var t = ConvertToAutoDiff(node.GetSubtree(0));
301          if (t.Count > 1) throw new InvalidOperationException("Tree Result must be scalar value");
302          return new[] { t[0] * alpha + beta };
303        } else return ConvertToAutoDiff(node.GetSubtree(0));
304      }
305      if (node.Symbol is Sum) {
306        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
307        return new[] { TermBuilder.Sum(term) };
308      }
309      if (node.Symbol is Mean) {
310        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
311        return new[] { TermBuilder.Sum(term) / term.Count };
312      }
313      if (node.Symbol is StandardDeviation) {
314        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
315        var mean = TermBuilder.Sum(term) / term.Count;
316        var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
317        return new[] { TermBuilder.Power(ssd / term.Count, 0.5) };
318      }
319      if (node.Symbol is Length) {
320        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
321        return new[] { TermBuilder.Constant(term.Count) };
322      }
323      //if (node.Symbol is Min) {
324      //}
325      //if (node.Symbol is Max) {
326      //}
327      if (node.Symbol is Variance) {
328        var term = node.Subtrees.Select(ConvertToAutoDiff).Single();
329        var mean = TermBuilder.Sum(term) / term.Count;
330        var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0)));
331        return new[] { ssd / term.Count };
332      }
333      //if (node.Symbol is Skewness) {
334      //}
335      //if (node.Symbol is Kurtosis) {
336      //}
337      //if (node.Symbol is EuclideanDistance) {
338      //}
339      //if (node.Symbol is Covariance) {
340      //}
341
342
343      throw new ConversionException();
344    }
345
346
347    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
348    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
349    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
350      string varName, string varValue = "", int lag = 0, int index = -1) {
351      var data = new DataForVariable(varName, varValue, lag, index);
352
353      AutoDiff.Variable par = null;
354      if (!parameters.TryGetValue(data, out par)) {
355        // not found -> create new parameter and entries in names and values lists
356        par = new AutoDiff.Variable();
357        parameters.Add(data, par);
358      }
359      return par;
360    }
361
362    public static bool IsCompatible(ISymbolicExpressionTree tree) {
363      var containsUnknownSymbol = (
364        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
365        where
366          !(n.Symbol is Variable) &&
367          !(n.Symbol is BinaryFactorVariable) &&
368          //!(n.Symbol is FactorVariable) &&
369          //!(n.Symbol is LaggedVariable) &&
370          !(n.Symbol is Constant) &&
371          !(n.Symbol is Addition) &&
372          !(n.Symbol is Subtraction) &&
373          !(n.Symbol is Multiplication) &&
374          !(n.Symbol is Division) &&
375          !(n.Symbol is Logarithm) &&
376          !(n.Symbol is Exponential) &&
377          !(n.Symbol is SquareRoot) &&
378          !(n.Symbol is Square) &&
379          !(n.Symbol is Sine) &&
380          !(n.Symbol is Cosine) &&
381          !(n.Symbol is Tangent) &&
382          !(n.Symbol is HyperbolicTangent) &&
383          !(n.Symbol is Erf) &&
384          !(n.Symbol is Norm) &&
385          !(n.Symbol is StartSymbol) &&
386          !(n.Symbol is Absolute) &&
387          //!(n.Symbol is AnalyticQuotient) &&
388          !(n.Symbol is Cube) &&
389          !(n.Symbol is CubeRoot) &&
390          !(n.Symbol is Sum) &&
391          !(n.Symbol is Mean) &&
392          !(n.Symbol is StandardDeviation) &&
393          !(n.Symbol is Length) &&
394          //!(n.Symbol is Min) &&
395          //!(n.Symbol is Max) &&
396          !(n.Symbol is Variance)
397        //!(n.Symbol is Skewness) &&
398        //!(n.Symbol is Kurtosis) &&
399        //!(n.Symbol is EuclideanDistance) &&
400        //!(n.Symbol is Covariance)
401        select n).Any();
402      return !containsUnknownSymbol;
403    }
404    #region exception class
405    [Serializable]
406    public class ConversionException : Exception {
407
408      public ConversionException() {
409      }
410
411      public ConversionException(string message) : base(message) {
412      }
413
414      public ConversionException(string message, Exception inner) : base(message, inner) {
415      }
416
417      protected ConversionException(
418        SerializationInfo info,
419        StreamingContext context) : base(info, context) {
420      }
421    }
422    #endregion
423  }
424}
Note: See TracBrowser for help on using the repository browser.