Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToDiffSharpConverter.cs @ 18239

Last change on this file since 18239 was 18239, checked in by pfleck, 2 years ago

#3040 Updated to newer TensorFlow.NET version.

  • Removed IL Merge from TensorFlow.NET.
  • Temporarily removed DiffSharp.
  • Changed to a locally built Attic with a specific Protobuf version that is compatible with TensorFlow.NET. (Also adapted other versions of nuget dependencies.)
File size: 18.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22#if INCLUDE_DIFFSHARP
23
24using System;
25using System.Collections.Generic;
26using System.Linq;
27using System.Runtime.Serialization;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using DiffSharp.Interop.Float64;
30using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
31
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
34  public class TreeToDiffSharpConverter {
35    public delegate D ParametricFunction(D[] vars, D[] @params, DV[] vectorParams);
36
37    #region helper class
/// <summary>
/// Key describing one input of a converted tree: a variable name, a variable
/// value (only meaningful for factor variables, otherwise it may be null) and a lag.
/// Instances compare by value so they can be used as dictionary keys.
/// </summary>
public class DataForVariable {
  public readonly string variableName;
  public readonly string variableValue; // for factor vars
  public readonly int lag;

  /// <summary>Stores the three components unchanged; no validation is performed.</summary>
  /// <param name="varName">Name of the variable.</param>
  /// <param name="varValue">Value for factor variables; may be null.</param>
  /// <param name="lag">Lag associated with the variable.</param>
  public DataForVariable(string varName, string varValue, int lag) {
    this.variableName = varName;
    this.variableValue = varValue;
    this.lag = lag;
  }

  /// <summary>
  /// Value equality over name, value and lag.
  /// </summary>
  /// <returns>
  /// True when <paramref name="obj"/> is a <see cref="DataForVariable"/> with equal
  /// components; false otherwise (including when it is null).
  /// </returns>
  public override bool Equals(object obj) {
    if (obj is DataForVariable other) {
      // The static string.Equals is ordinal (like the instance method) but
      // tolerates null on either side, so a missing variableValue no longer
      // causes a NullReferenceException.
      return string.Equals(other.variableName, variableName) &&
             string.Equals(other.variableValue, variableValue) &&
             other.lag == lag;
    }
    return false;
  }

  /// <summary>
  /// Hash code consistent with <see cref="Equals(object)"/>; a null component contributes 0.
  /// </summary>
  public override int GetHashCode() {
    var nameHash = variableName?.GetHashCode() ?? 0;
    var valueHash = variableValue?.GetHashCode() ?? 0;
    return nameHash ^ valueHash ^ lag;
  }
}
62
/// <summary>
/// Result of converting a tree node: it holds either a scalar (D), a vector (DV),
/// or neither (the shared <see cref="NaN"/> instance).
/// Which case applies is decided by reference comparison against two private
/// sentinel objects, not by inspecting numeric values.
/// </summary>
public class EvaluationResult {
  // Sentinels marking "no vector" / "no scalar". They are declared before NaN
  // because NaN's constructor reads them during static initialization.
  private static readonly DV MissingVector = new DV(new[] { double.NaN });
  private static readonly D MissingScalar = new D(double.NaN);

  /// <summary>The single result that is neither a scalar nor a vector.</summary>
  public static readonly EvaluationResult NaN = new EvaluationResult();

  /// <summary>The scalar value, or the internal sentinel when this is not a scalar.</summary>
  public D Scalar { get; }

  /// <summary>The vector value, or the internal sentinel when this is not a vector.</summary>
  public DV Vector { get; }

  public bool IsScalar {
    get { return !ReferenceEquals(Scalar, MissingScalar); }
  }

  public bool IsVector {
    get { return !ReferenceEquals(Vector, MissingVector); }
  }

  public bool IsNaN {
    get { return !(IsScalar || IsVector); }
  }

  // Starts out with both slots set to their sentinels; used for NaN and as
  // the common base of the public constructors.
  private EvaluationResult() {
    Scalar = MissingScalar;
    Vector = MissingVector;
  }

  /// <summary>Wraps a scalar value.</summary>
  /// <exception cref="ArgumentNullException">When <paramref name="scalar"/> is null.</exception>
  public EvaluationResult(D scalar) : this() {
    if (scalar == null) {
      throw new ArgumentNullException(nameof(scalar));
    }
    Scalar = scalar;
  }

  /// <summary>Wraps a vector value.</summary>
  /// <exception cref="ArgumentNullException">When <paramref name="vector"/> is null.</exception>
  public EvaluationResult(DV vector) : this() {
    if (vector == null) {
      throw new ArgumentNullException(nameof(vector));
    }
    Vector = vector;
  }
}
91    #endregion
92
/// <summary>
/// Converts the tree (starting at the first subtree of its root) into a DiffSharp
/// expression and returns its scalar value.
/// </summary>
/// <param name="tree">Tree to convert.</param>
/// <param name="makeVariableWeightsVariable">When true, variable weights are taken from <paramref name="variables"/> as well.</param>
/// <param name="addLinearScalingTerms">When true, the start symbol adds an offset and a scale factor taken from <paramref name="variables"/>.</param>
/// <param name="variables">Values substituted, in visiting order, for the tree's numeric parameters.</param>
/// <param name="scalarParameters">Scalar inputs keyed by variable name.</param>
/// <param name="vectorsParameters">Vector inputs keyed by variable name.</param>
/// <returns>The scalar result of the converted tree.</returns>
/// <exception cref="InvalidOperationException">When the tree does not evaluate to a scalar.</exception>
public static D Evaluate(ISymbolicExpressionTree tree,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
  DV variables,
  IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {

  var converter = new TreeToDiffSharpConverter(
    variables,
    scalarParameters, vectorsParameters,
    makeVariableWeightsVariable, addLinearScalingTerms);

  var startNode = tree.Root.GetSubtree(0);
  var converted = converter.ConvertNode(startNode);

  if (converted.IsScalar) {
    return converted.Scalar;
  }
  throw new InvalidOperationException("Result of evaluation is not a scalar.");
}
107
108    //public static bool TryConvert(ISymbolicExpressionTree tree, IDataset dataset,
109    //  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
110    //  out double[] initialConstants, out List<DataForVariable> scalarParameters, out List<DataForVariable> vectorParameters,
111    //  out D func) {
112
113    //  var transformator = new TreeToDiffSharpConverter(dataset, makeVariableWeightsVariable, addLinearScalingTerms);
114    //  try {
115    //    D term = transformator.ConvertNode(tree.Root.GetSubtree(0));
116    //    initialConstants = transformator.initialConstants.ToArray();
117    //    var scalarParameterEntries = transformator.scalarParameters.ToArray(); // guarantee same order for keys and values
118    //    var vectorParameterEntries = transformator.vectorParameters.ToArray(); // guarantee same order for keys and values
119    //    scalarParameters = scalarParameterEntries.Select(kvp => kvp.Key).ToList();
120    //    vectorParameters = vectorParameterEntries.Select(kvp => kvp.Key).ToList();
121    //    func = term;
122    //    return true;
123    //  } catch (ConversionException) {
124    //    initialConstants = null;
125    //    scalarParameters = null;
126    //    vectorParameters = null;
127    //    func = null;
128    //  }
129    //  return false;
130    //}
131
/// <summary>
/// Walks the tree once without substitution values and returns the numeric
/// parameters collected on the way (constants, and depending on the flags
/// variable weights and linear scaling terms), in visiting order.
/// </summary>
/// <param name="tree">Tree to inspect; conversion starts at the first subtree of its root.</param>
/// <param name="makeVariableWeightsVariable">When true, variable weights are collected as well.</param>
/// <param name="addLinearScalingTerms">When true, the start symbol contributes 0.0 and 1.0.</param>
/// <param name="scalarParameters">Scalar inputs keyed by variable name.</param>
/// <param name="vectorsParameters">Vector inputs keyed by variable name.</param>
/// <returns>The converter's internal list of collected values.</returns>
public static List<double> GetInitialConstants(ISymbolicExpressionTree tree,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
  IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {

  // No variables vector: the conversion then falls back to the values stored in the tree.
  var collector = new TreeToDiffSharpConverter(
    variables: null,
    scalarParameters: scalarParameters,
    vectorsParameters: vectorsParameters,
    makeVariableWeightsVariable: makeVariableWeightsVariable,
    addLinearScalingTerms: addLinearScalingTerms);

  var startNode = tree.Root.GetSubtree(0);
  collector.ConvertNode(startNode);

  return collector.initialConstants;
}
144
145    /*private readonly IDataset dataset;*/
// Inputs of the conversion, looked up by variable name.
private readonly IDictionary<string, D> scalarParameters;
private readonly IDictionary<string, DV> vectorsParameters;

// Conversion options.
private readonly bool makeVariableWeightsVariable;
private readonly bool addLinearScalingTerms;

// Values stored in the tree (constants, weights, scaling terms), in visiting order.
private readonly List<double> initialConstants;
// Optional replacement values for those numeric parameters; may be null.
private readonly DV variables;
// Index of the next entry of 'variables' to consume.
private int variableIdx;

/// <summary>
/// Creates a converter for a single pass over a tree. Instances are only
/// created by the static entry points of this class.
/// </summary>
/// <param name="variables">Replacement values for the tree's numeric parameters, or null to use the values stored in the tree.</param>
/// <param name="scalarParameters">Scalar inputs keyed by variable name.</param>
/// <param name="vectorsParameters">Vector inputs keyed by variable name.</param>
/// <param name="makeVariableWeightsVariable">Whether variable weights are treated as numeric parameters.</param>
/// <param name="addLinearScalingTerms">Whether the start symbol adds linear scaling terms.</param>
private TreeToDiffSharpConverter(
  DV variables,
  IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters,
  bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
  this.variables = variables;
  this.variableIdx = 0;
  this.initialConstants = new List<double>();

  this.scalarParameters = scalarParameters;
  this.vectorsParameters = vectorsParameters;

  this.makeVariableWeightsVariable = makeVariableWeightsVariable;
  this.addLinearScalingTerms = addLinearScalingTerms;
}
173
174    #region Evaluation helpers
/// <summary>
/// Applies a binary operation, choosing the overload that matches the
/// scalar/vector kinds of the two operands.
/// </summary>
/// <param name="lhs">Left operand.</param>
/// <param name="rhs">Right operand.</param>
/// <param name="ssFunc">Operation for scalar and scalar.</param>
/// <param name="svFunc">Operation for scalar and vector.</param>
/// <param name="vsFunc">Operation for vector and scalar.</param>
/// <param name="vvFunc">Operation for vector and vector.</param>
/// <returns>
/// The wrapped result, or <c>EvaluationResult.NaN</c> when an operand is NaN or
/// no function was supplied for the operand combination.
/// </returns>
private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
  Func<D, D, D> ssFunc = null,
  Func<D, DV, DV> svFunc = null,
  Func<DV, D, DV> vsFunc = null,
  Func<DV, DV, DV> vvFunc = null) {

  if (lhs.IsScalar) {
    if (rhs.IsScalar && ssFunc != null) {
      return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
    }
    if (rhs.IsVector && svFunc != null) {
      return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
    }
  }

  if (lhs.IsVector) {
    if (rhs.IsScalar && vsFunc != null) {
      return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
    }
    // Vectors are passed on as they are; no length adjustment is done here.
    if (rhs.IsVector && vvFunc != null) {
      return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
    }
  }

  return EvaluationResult.NaN;
}
195
/// <summary>
/// Applies a kind-preserving unary function: scalar to scalar, or vector to vector.
/// </summary>
/// <param name="val">Operand.</param>
/// <param name="sFunc">Function used when the operand is a scalar.</param>
/// <param name="vFunc">Function used when the operand is a vector.</param>
/// <returns>
/// The wrapped result, or <c>EvaluationResult.NaN</c> when the operand is NaN or
/// the matching function is missing.
/// </returns>
private static EvaluationResult FunctionApply(EvaluationResult val,
  Func<D, D> sFunc = null,
  Func<DV, DV> vFunc = null) {
  return val.IsScalar && sFunc != null ? new EvaluationResult(sFunc(val.Scalar))
       : val.IsVector && vFunc != null ? new EvaluationResult(vFunc(val.Vector))
       : EvaluationResult.NaN;
}
/// <summary>
/// Applies a unary function that always yields a scalar: a vector operand is
/// reduced by <paramref name="vFunc"/>, a scalar operand is mapped by <paramref name="sFunc"/>.
/// </summary>
/// <param name="val">Operand.</param>
/// <param name="sFunc">Function used when the operand is a scalar.</param>
/// <param name="vFunc">Reduction used when the operand is a vector.</param>
/// <returns>
/// The wrapped scalar result, or <c>EvaluationResult.NaN</c> when the operand is
/// NaN or the matching function is missing.
/// </returns>
private static EvaluationResult AggregateApply(EvaluationResult val,
  Func<D, D> sFunc = null,
  Func<DV, D> vFunc = null) {
  if (val.IsScalar && sFunc != null) {
    D mapped = sFunc(val.Scalar);
    return new EvaluationResult(mapped);
  }

  if (val.IsVector && vFunc != null) {
    D reduced = vFunc(val.Vector);
    return new EvaluationResult(reduced);
  }

  return EvaluationResult.NaN;
}
210    #endregion
211
/// <summary>
/// Recursively converts a tree node into a DiffSharp scalar or vector expression.
/// </summary>
/// <remarks>
/// Side effects: every numeric parameter met on the way (constants, variable
/// weights when makeVariableWeightsVariable is set, the two linear scaling
/// terms when addLinearScalingTerms is set) is appended to initialConstants,
/// and, when a variables vector was supplied, the next entry of that vector
/// is consumed in its place.
/// </remarks>
/// <param name="node">Node to convert.</param>
/// <returns>
/// The converted expression. Operations whose operand kinds do not match any
/// supplied function yield <c>EvaluationResult.NaN</c>.
/// </returns>
/// <exception cref="ConversionException">
/// When the node's symbol is not supported or a variable is found in neither parameter dictionary.
/// </exception>
/// <exception cref="InvalidOperationException">
/// When linear scaling is requested but the scaled subtree is not a scalar.
/// </exception>
private EvaluationResult ConvertNode(ISymbolicExpressionTreeNode node) {

  if (node.Symbol is Constant) {
    // assume scalar constant
    var constant = ((ConstantTreeNode)node).Value;
    initialConstants.Add(constant);

    // Use the supplied replacement value if there is one, else the stored constant.
    var c = variables?[variableIdx++] ?? constant;

    return new EvaluationResult(c);
  }

  if (node.Symbol is Variable) {
    var varNode = (VariableTreeNodeBase)node;

    // Scalar inputs take precedence over vector inputs of the same name.
    if (scalarParameters.TryGetValue(varNode.VariableName, out var scalarPar)) {
      if (makeVariableWeightsVariable) {
        var weight = varNode.Weight;
        initialConstants.Add(weight);
        var w = variables?[variableIdx++] ?? weight;
        return new EvaluationResult(w * scalarPar);
      }
      return new EvaluationResult(varNode.Weight * scalarPar);
    }

    if (vectorsParameters.TryGetValue(varNode.VariableName, out var vectorPar)) {
      if (makeVariableWeightsVariable) {
        var weight = varNode.Weight;
        initialConstants.Add(weight);
        var w = variables?[variableIdx++] ?? weight;
        return new EvaluationResult(w * vectorPar);
      }
      return new EvaluationResult(varNode.Weight * vectorPar);
    }

    // Previously this fell through to the generic, message-less exception at
    // the end of the method; fail here with the offending name instead.
    throw new ConversionException(
      $"Variable '{varNode.VariableName}' is neither a scalar nor a vector parameter.");
  }

  if (node.Symbol is Addition) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 + s2,
        (s1, v2) => s1 + v2,
        (v1, s2) => v1 + s2,
        (v1, v2) => v1 + v2
      ));
  }
  if (node.Symbol is Subtraction) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    // A single operand means unary negation.
    if (terms.Count == 1) return FunctionApply(terms[0],
        s => D.Neg(s),
        v => DV.Neg(v));
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 - s2,
        (s1, v2) => s1 - v2,
        (v1, s2) => v1 - s2,
        (v1, v2) => v1 - v2
      ));
  }
  if (node.Symbol is Multiplication) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 * s2,
        (s1, v2) => s1 * v2,
        (v1, s2) => v1 * s2,
        (v1, v2) => DV.PointwiseMultiply(v1, v2)
      ));
  }
  if (node.Symbol is Division) {
    var terms = node.Subtrees.Select(ConvertNode).ToList();
    // A single operand means the reciprocal.
    if (terms.Count == 1) return FunctionApply(terms[0],
      s => 1.0 / s,
      v => 1.0 / v);
    return terms.Aggregate((a, b) =>
      ArithmeticApply(a, b,
        (s1, s2) => s1 / s2,
        (s1, v2) => s1 / v2,
        (v1, s2) => v1 / s2,
        (v1, v2) => DV.PointwiseDivision(v1, v2)
      ));
  }

  if (node.Symbol is Absolute) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Abs(s),
      v => DV.Abs(v)
    );
  }

  if (node.Symbol is Logarithm) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Log(s),
      v => DV.Log(v)
    );
  }
  if (node.Symbol is Exponential) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(Math.E, s),
      v => DV.Pow(Math.E, v)
    );
  }
  if (node.Symbol is Square) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(s, 2),
      v => DV.Pow(v, 2)
    );
  }
  if (node.Symbol is SquareRoot) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sqrt(s),
      v => DV.Sqrt(v)
    );
  }
  if (node.Symbol is Cube) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Pow(s, 3),
      v => DV.Pow(v, 3)
    );
  }
  if (node.Symbol is CubeRoot) {
    // sign(x) * |x|^(1/3), so negative inputs get a real cube root.
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sign(s) * D.Pow(D.Abs(s), 1.0 / 3.0),
      v => DV.PointwiseMultiply(DV.Sign(v), DV.Pow(DV.Abs(v), 1.0 / 3.0))
    );
  }

  if (node.Symbol is Sine) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Sin(s),
      v => DV.Sin(v)
    );
  }
  if (node.Symbol is Cosine) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Cos(s),
      v => DV.Cos(v)
    );
  }
  if (node.Symbol is Tangent) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Tan(s),
      v => DV.Tan(v)
    );
  }
  if (node.Symbol is HyperbolicTangent) {
    return FunctionApply(ConvertNode(node.GetSubtree(0)),
      s => D.Tanh(s),
      v => DV.Tanh(v)
    );
  }

  // Aggregations: a vector is reduced to a scalar; a scalar operand is
  // treated like a single-element vector.
  if (node.Symbol is Sum) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Sum(v)
    );
  }
  if (node.Symbol is Mean) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Mean(v)
    );
  }
  if (node.Symbol is StandardDeviation) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => 0,
      v => DV.StandardDev(v) //TODO: use pop-stdev instead
    );
  }
  if (node.Symbol is Length) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => 1,
      // TODO: no length? Derived as sum / mean, which is undefined when the mean is zero.
      v => DV.Sum(v) / DV.Mean(v)
    );
  }
  if (node.Symbol is Min) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Min(v)
    );
  }
  if (node.Symbol is Max) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      s => s,
      v => DV.Max(v)
    );
  }
  if (node.Symbol is Variance) {
    return AggregateApply(ConvertNode(node.GetSubtree(0)),
      // A single value has no spread: 0, consistent with StandardDeviation
      // above (this used to return the scalar itself).
      s => 0,
      v => DV.Variance(v)
    );
  }
  // Skewness, Kurtosis, EuclideanDistance and Covariance are not supported yet.

  if (node.Symbol is StartSymbol) {
    if (addLinearScalingTerms) {
      // scaling variables α, β are given at the beginning of the parameter vector
      // (order: offset β with default 0.0, then factor α with default 1.0)
      initialConstants.Add(0.0);
      initialConstants.Add(1.0);
      var beta = variables?[variableIdx++] ?? 0.0;
      var alpha = variables?[variableIdx++] ?? 1.0;
      var t = ConvertNode(node.GetSubtree(0));
      if (!t.IsScalar) throw new InvalidOperationException("Must be a scalar result");
      return new EvaluationResult(t.Scalar * alpha + beta);
    } else return ConvertNode(node.GetSubtree(0));
  }

  throw new ConversionException(
    $"Symbol of type '{node.Symbol?.GetType().Name ?? "null"}' is not supported.");
}
444
/// <summary>
/// Checks whether every node below the tree's root (visited in prefix order,
/// starting at the root's first subtree) uses a symbol from the accepted list.
/// </summary>
/// <param name="tree">Tree to check.</param>
/// <returns>True when no node with a symbol outside the list is found.</returns>
public static bool IsCompatible(ISymbolicExpressionTree tree) {
  // Not accepted (disabled in the list): BinaryFactorVariable, FactorVariable,
  // LaggedVariable, Erf, Norm, AnalyticQuotient, Length, Skewness, Kurtosis,
  // EuclideanDistance, Covariance.
  var nodes = tree.Root.GetSubtree(0).IterateNodesPrefix();
  return nodes.All(n =>
    // leaves and structure
    n.Symbol is Variable ||
    n.Symbol is Constant ||
    n.Symbol is StartSymbol ||
    // arithmetic
    n.Symbol is Addition ||
    n.Symbol is Subtraction ||
    n.Symbol is Multiplication ||
    n.Symbol is Division ||
    // unary functions
    n.Symbol is Absolute ||
    n.Symbol is Logarithm ||
    n.Symbol is Exponential ||
    n.Symbol is SquareRoot ||
    n.Symbol is Square ||
    n.Symbol is Cube ||
    n.Symbol is CubeRoot ||
    n.Symbol is Sine ||
    n.Symbol is Cosine ||
    n.Symbol is Tangent ||
    n.Symbol is HyperbolicTangent ||
    // aggregations
    n.Symbol is Sum ||
    n.Symbol is Mean ||
    n.Symbol is StandardDeviation ||
    n.Symbol is Min ||
    n.Symbol is Max ||
    n.Symbol is Variance);
}
487
488    #region exception class
/// <summary>
/// Thrown when a symbolic expression tree cannot be converted.
/// Provides the standard set of exception constructors.
/// </summary>
[Serializable]
public class ConversionException : Exception {
  /// <summary>Creates the exception with the default message.</summary>
  public ConversionException() { }

  /// <summary>Creates the exception with the given message.</summary>
  public ConversionException(string message)
    : base(message) { }

  /// <summary>Creates the exception with the given message and inner exception.</summary>
  public ConversionException(string message, Exception inner)
    : base(message, inner) { }

  /// <summary>Deserialization constructor.</summary>
  protected ConversionException(SerializationInfo info, StreamingContext context)
    : base(info, context) { }
}
506    #endregion
507  }
508}
509
510#endif
Note: See TracBrowser for help on using the repository browser.