Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToDiffSharpConverter.cs @ 17786

Last change on this file since 17786 was 17786, checked in by pfleck, 3 years ago

#3040 Worked in DiffSharp for constant-opt.

File size: 18.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Runtime.Serialization;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using DiffSharp.Interop.Float64;
28using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
29
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  public class TreeToDiffSharpConverter {
    /// <summary>
    /// Delegate type for a DiffSharp-based parametric function returning a scalar <see cref="D"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): this delegate is not referenced anywhere in this file. The roles below are inferred
    /// from the parameter names only; confirm them against the callers.
    /// </remarks>
    /// <param name="vars">Presumably the scalar values being differentiated/optimized.</param>
    /// <param name="params">Presumably scalar input parameters (data values).</param>
    /// <param name="vectorParams">Presumably vector-valued input parameters.</param>
    public delegate D ParametricFunction(D[] vars, D[] @params, DV[] vectorParams);
34
35    #region helper class
36    public class DataForVariable {
37      public readonly string variableName;
38      public readonly string variableValue; // for factor vars
39      public readonly int lag;
40
41      public DataForVariable(string varName, string varValue, int lag) {
42        this.variableName = varName;
43        this.variableValue = varValue;
44        this.lag = lag;
45      }
46
47      public override bool Equals(object obj) {
48        if (obj is DataForVariable other) {
49          return other.variableName.Equals(this.variableName) &&
50                 other.variableValue.Equals(this.variableValue) &&
51                 other.lag == this.lag;
52        }
53        return false;
54      }
55
56      public override int GetHashCode() {
57        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
58      }
59    }
60
61    public class EvaluationResult {
62      public D Scalar { get; }
63      public bool IsScalar => !ReferenceEquals(Scalar, NanScalar);
64
65      public DV Vector { get; }
66      public bool IsVector => !ReferenceEquals(Vector, NaNVector);
67
68      public bool IsNaN => !IsScalar && !IsVector;
69
70      public EvaluationResult(D scalar) {
71        if (scalar == null) throw new ArgumentNullException(nameof(scalar));
72        Scalar = scalar;
73        Vector = NaNVector;
74      }
75      public EvaluationResult(DV vector) {
76        if (vector == null) throw new ArgumentNullException(nameof(vector));
77        Scalar = NanScalar;
78        Vector = vector;
79      }
80      private EvaluationResult() {
81        Scalar = NanScalar;
82        Vector = NaNVector;
83      }
84
85      private static readonly DV NaNVector = new DV(new[] { double.NaN });
86      private static readonly D NanScalar = new D(double.NaN);
87      public static readonly EvaluationResult NaN = new EvaluationResult();
88    }
89    #endregion
90
91    public static D Evaluate(ISymbolicExpressionTree tree,
92      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
93      DV variables,
94      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {
95
96      var transformator = new TreeToDiffSharpConverter(
97        variables,
98        scalarParameters, vectorsParameters,
99        makeVariableWeightsVariable, addLinearScalingTerms);
100
101      var result = transformator.ConvertNode(tree.Root.GetSubtree(0));
102      if (!result.IsScalar) throw new InvalidOperationException("Result of evaluation is not a scalar.");
103      return result.Scalar;
104    }
105
106    //public static bool TryConvert(ISymbolicExpressionTree tree, IDataset dataset,
107    //  bool makeVariableWeightsVariable, bool addLinearScalingTerms,
108    //  out double[] initialConstants, out List<DataForVariable> scalarParameters, out List<DataForVariable> vectorParameters,
109    //  out D func) {
110
111    //  var transformator = new TreeToDiffSharpConverter(dataset, makeVariableWeightsVariable, addLinearScalingTerms);
112    //  try {
113    //    D term = transformator.ConvertNode(tree.Root.GetSubtree(0));
114    //    initialConstants = transformator.initialConstants.ToArray();
115    //    var scalarParameterEntries = transformator.scalarParameters.ToArray(); // guarantee same order for keys and values
116    //    var vectorParameterEntries = transformator.vectorParameters.ToArray(); // guarantee same order for keys and values
117    //    scalarParameters = scalarParameterEntries.Select(kvp => kvp.Key).ToList();
118    //    vectorParameters = vectorParameterEntries.Select(kvp => kvp.Key).ToList();
119    //    func = term;
120    //    return true;
121    //  } catch (ConversionException) {
122    //    initialConstants = null;
123    //    scalarParameters = null;
124    //    vectorParameters = null;
125    //    func = null;
126    //  }
127    //  return false;
128    //}
129
130    public static List<double> GetInitialConstants(ISymbolicExpressionTree tree,
131      bool makeVariableWeightsVariable, bool addLinearScalingTerms,
132      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters) {
133
134      var transformator = new TreeToDiffSharpConverter( /*dataset, */
135        null,
136        scalarParameters, vectorsParameters,
137        makeVariableWeightsVariable, addLinearScalingTerms);
138
139      transformator.ConvertNode(tree.Root.GetSubtree(0));
140      return transformator.initialConstants;
141    }
142
143    /*private readonly IDataset dataset;*/
144    private readonly IDictionary<string, D> scalarParameters;
145    private readonly IDictionary<string, DV> vectorsParameters;
146    private readonly bool makeVariableWeightsVariable;
147    private readonly bool addLinearScalingTerms;
148
149    private readonly List<double> initialConstants;
150    private readonly DV variables;
151    private int variableIdx;
152    //private readonly Dictionary<DataForVariable, D> scalarParameters;
153    //private readonly Dictionary<DataForVariable, DV> vectorParameters;
154
155    private TreeToDiffSharpConverter(/*IDataset dataset,*/
156      DV variables,
157      IDictionary<string, D> scalarParameters, IDictionary<string, DV> vectorsParameters,
158      bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
159      /*this.dataset = dataset;*/
160      this.scalarParameters = scalarParameters;
161      this.vectorsParameters = vectorsParameters;
162      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
163      this.addLinearScalingTerms = addLinearScalingTerms;
164
165      initialConstants = new List<double>();
166      this.variables = variables;
167      variableIdx = 0;
168      //scalarParameters = new Dictionary<DataForVariable, D>();
169      //vectorParameters = new Dictionary<DataForVariable, DV>();
170    }
171
172    #region Evaluation helpers
173    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
174      /*Func<DV, DV, (DV, DV)> lengthStrategy,*/
175      Func<D, D, D> ssFunc = null,
176      Func<D, DV, DV> svFunc = null,
177      Func<DV, D, DV> vsFunc = null,
178      Func<DV, DV, DV> vvFunc = null) {
179
180      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
181      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
182      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
183      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
184      /* }
185      if (lhs.Vector.Count == rhs.Vector.Count) {
186        return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
187      } else {
188        var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
189        return new EvaluationResult(vvFunc(lhsVector, rhsVector));
190      }*/
191      return EvaluationResult.NaN;
192    }
193
194    private static EvaluationResult FunctionApply(EvaluationResult val,
195      Func<D, D> sFunc = null,
196      Func<DV, DV> vFunc = null) {
197      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
198      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
199      return EvaluationResult.NaN;
200    }
201    private static EvaluationResult AggregateApply(EvaluationResult val,
202      Func<D, D> sFunc = null,
203      Func<DV, D> vFunc = null) {
204      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
205      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
206      return EvaluationResult.NaN;
207    }
208    #endregion
209
210    private EvaluationResult ConvertNode(ISymbolicExpressionTreeNode node) {
211
212      if (node.Symbol is Constant) {
213        // assume scalar constant
214        var constant = ((ConstantTreeNode)node).Value;
215        initialConstants.Add(constant);
216
217        var c = variables?[variableIdx++] ?? constant;
218
219        return new EvaluationResult(c);
220      }
221
222      if (node.Symbol is Variable) {
223        var varNode = node as VariableTreeNodeBase;
224        if (scalarParameters.ContainsKey(varNode.VariableName)) {
225          var par = scalarParameters[varNode.VariableName];
226          if (makeVariableWeightsVariable) {
227            var weight = varNode.Weight;
228            initialConstants.Add(weight);
229            var w = variables?[variableIdx++] ?? weight;
230            return new EvaluationResult(w * par);
231          } else {
232            return new EvaluationResult(varNode.Weight * par);
233          }
234        } else if (vectorsParameters.ContainsKey(varNode.VariableName)) {
235          var par = vectorsParameters[varNode.VariableName];
236          if (makeVariableWeightsVariable) {
237            var weight = varNode.Weight;
238            initialConstants.Add(weight);
239            var w = variables?[variableIdx++] ?? weight;
240            return new EvaluationResult(w * par);
241          } else {
242            return new EvaluationResult(varNode.Weight * par);
243          }
244        }
245      }
246      //if (node.Symbol is FactorVariable) {
247      //  var factorVarNode = node as FactorVariableTreeNode;
248      //  var products = new List<D>();
249      //  foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
250      //    var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
251
252      //    var wVar = new D(factorVarNode.GetValue(variableValue));
253      //    variables.Add(wVar);
254      //    products.Add(wVar * par);
255      //  }
256
257      //  return products.Aggregate((x, y) => x + y);
258      //}
259
260      if (node.Symbol is Addition) {
261        var terms = node.Subtrees.Select(ConvertNode).ToList();
262        return terms.Aggregate((a, b) =>
263          ArithmeticApply(a, b,
264            (s1, s2) => s1 + s2,
265            (s1, v2) => s1 + v2,
266            (v1, s2) => v1 + s2,
267            (v1, v2) => v1 + v2
268          ));
269      }
270      if (node.Symbol is Subtraction) {
271        var terms = node.Subtrees.Select(ConvertNode).ToList();
272        if (terms.Count == 1) return FunctionApply(terms[0],
273            s => -s,
274            v => DV.Neg(v));
275        return terms.Aggregate((a, b) =>
276          ArithmeticApply(a, b,
277            (s1, s2) => s1 - s2,
278            (s1, v2) => s1 - v2,
279            (v1, s2) => v1 - s2,
280            (v1, v2) => v1 - v2
281          ));
282      }
283      if (node.Symbol is Multiplication) {
284        var terms = node.Subtrees.Select(ConvertNode).ToList();
285        return terms.Aggregate((a, b) =>
286          ArithmeticApply(a, b,
287            (s1, s2) => s1 * s2,
288            (s1, v2) => s1 * v2,
289            (v1, s2) => v1 * s2,
290            (v1, v2) => DV.op_DotMultiply(v1, v2)
291          ));
292      }
293      if (node.Symbol is Division) {
294        var terms = node.Subtrees.Select(ConvertNode).ToList();
295        if (terms.Count == 1) return FunctionApply(terms[0],
296          s => 1.0 / s,
297          v => 1.0 / v);
298        return terms.Aggregate((a, b) =>
299          ArithmeticApply(a, b,
300            (s1, s2) => s1 / s2,
301            (s1, v2) => s1 / v2,
302            (v1, s2) => v1 / s2,
303            (v1, v2) => DV.op_DotDivide(v1, v2)
304          ));
305      }
306
307      if (node.Symbol is Absolute) {
308        return FunctionApply(ConvertNode(node.GetSubtree(0)),
309          s => D.Abs(s),
310          v => DV.Abs(v)
311        );
312      }
313
314      if (node.Symbol is Logarithm) {
315        return FunctionApply(ConvertNode(node.GetSubtree(0)),
316          s => D.Log(s),
317          v => DV.Log(v)
318        );
319      }
320      if (node.Symbol is Exponential) {
321        return FunctionApply(ConvertNode(node.GetSubtree(0)),
322          s => D.Pow(Math.E, s),
323          v => DV.Pow(Math.E, v)
324        );
325      }
326      if (node.Symbol is Square) {
327        return FunctionApply(ConvertNode(node.GetSubtree(0)),
328          s => D.Pow(s, 2),
329          v => DV.Pow(v, 2)
330        );
331      }
332      if (node.Symbol is SquareRoot) {
333        return FunctionApply(ConvertNode(node.GetSubtree(0)),
334          s => D.Sqrt(s),
335          v => DV.Sqrt(v)
336        );
337      }
338      if (node.Symbol is Cube) {
339        return FunctionApply(ConvertNode(node.GetSubtree(0)),
340          s => D.Pow(s, 3),
341          v => DV.Pow(v, 3)
342        );
343      }
344      if (node.Symbol is CubeRoot) {
345        return FunctionApply(ConvertNode(node.GetSubtree(0)),
346          s => D.Sign(s) * D.Pow(D.Abs(s), 1.0 / 3.0),
347          v => DV.op_DotMultiply(DV.Sign(v), DV.Pow(DV.Abs(v), 1.0 / 3.0))
348        );
349      }
350
351      if (node.Symbol is Sine) {
352        return FunctionApply(ConvertNode(node.GetSubtree(0)),
353          s => D.Sin(s),
354          v => DV.Sin(v)
355        );
356      }
357      if (node.Symbol is Cosine) {
358        return FunctionApply(ConvertNode(node.GetSubtree(0)),
359          s => D.Cos(s),
360          v => DV.Cos(v)
361        );
362      }
363      if (node.Symbol is Tangent) {
364        return FunctionApply(ConvertNode(node.GetSubtree(0)),
365          s => D.Tan(s),
366          v => DV.Tan(v)
367        );
368      }
369      if (node.Symbol is HyperbolicTangent) {
370        return FunctionApply(ConvertNode(node.GetSubtree(0)),
371          s => D.Tanh(s),
372          v => DV.Tanh(v)
373        );
374      }
375
376      if (node.Symbol is Sum) {
377        return AggregateApply(ConvertNode(node.GetSubtree(0)),
378          s => s,
379          v => DV.Sum(v)
380        );
381      }
382      if (node.Symbol is Mean) {
383        return AggregateApply(ConvertNode(node.GetSubtree(0)),
384          s => s,
385          v => DV.Mean(v)
386        );
387      }
388      if (node.Symbol is StandardDeviation) {
389        return AggregateApply(ConvertNode(node.GetSubtree(0)),
390          s => 0,
391          v => DV.StandardDev(v) //TODO: use pop-stdev instead
392        );
393      }
394      if (node.Symbol is Length) {
395        return AggregateApply(ConvertNode(node.GetSubtree(0)),
396          s => 1,
397          v => DV.Sum(v) / DV.Mean(v) // TODO: no length?
398        );
399      }
400      if (node.Symbol is Min) {
401        return AggregateApply(ConvertNode(node.GetSubtree(0)),
402          s => s,
403          v => DV.Min(v)
404        );
405      }
406      if (node.Symbol is Max) {
407        return AggregateApply(ConvertNode(node.GetSubtree(0)),
408          s => s,
409          v => DV.Max(v)
410        );
411      }
412      if (node.Symbol is Variance) {
413        return AggregateApply(ConvertNode(node.GetSubtree(0)),
414          s => s,
415          v => DV.Variance(v)
416        );
417      }
418      //if (node.Symbol is Skewness) {
419      //}
420      //if (node.Symbol is Kurtosis) {
421      //}
422      //if (node.Symbol is EuclideanDistance) {
423      //}
424      //if (node.Symbol is Covariance) {
425      //}
426
427      if (node.Symbol is StartSymbol) {
428        if (addLinearScalingTerms) {
429          // scaling variables α, β are given at the beginning of the parameter vector
430          initialConstants.Add(0.0);
431          initialConstants.Add(1.0);
432          var beta = variables?[variableIdx++] ?? 0.0;
433          var alpha = variables?[variableIdx++] ?? 1.0;
434          var t = ConvertNode(node.GetSubtree(0));
435          if (!t.IsScalar) throw new InvalidOperationException("Must be a scalar result");
436          return new EvaluationResult(t.Scalar * alpha + beta);
437        } else return ConvertNode(node.GetSubtree(0));
438      }
439
440      throw new ConversionException();
441    }
442
443    public static bool IsCompatible(ISymbolicExpressionTree tree) {
444      var containsUnknownSymbol = (
445        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
446        where
447          !(n.Symbol is Variable) &&
448          //!(n.Symbol is BinaryFactorVariable) &&
449          //!(n.Symbol is FactorVariable) &&
450          //!(n.Symbol is LaggedVariable) &&
451          !(n.Symbol is Constant) &&
452          !(n.Symbol is Addition) &&
453          !(n.Symbol is Subtraction) &&
454          !(n.Symbol is Multiplication) &&
455          !(n.Symbol is Division) &&
456          !(n.Symbol is Logarithm) &&
457          !(n.Symbol is Exponential) &&
458          !(n.Symbol is SquareRoot) &&
459          !(n.Symbol is Square) &&
460          !(n.Symbol is Sine) &&
461          !(n.Symbol is Cosine) &&
462          !(n.Symbol is Tangent) &&
463          !(n.Symbol is HyperbolicTangent) &&
464          //!(n.Symbol is Erf) &&
465          //!(n.Symbol is Norm) &&
466          !(n.Symbol is StartSymbol) &&
467          !(n.Symbol is Absolute) &&
468          //!(n.Symbol is AnalyticQuotient) &&
469          !(n.Symbol is Cube) &&
470          !(n.Symbol is CubeRoot) &&
471          !(n.Symbol is Sum) &&
472          !(n.Symbol is Mean) &&
473          !(n.Symbol is StandardDeviation) &&
474          //!(n.Symbol is Length) &&
475          !(n.Symbol is Min) &&
476          !(n.Symbol is Max) &&
477          !(n.Symbol is Variance)
478        //!(n.Symbol is Skewness) &&
479        //!(n.Symbol is Kurtosis) &&
480        //!(n.Symbol is EuclideanDistance) &&
481        //!(n.Symbol is Covariance)
482        select n).Any();
483      return !containsUnknownSymbol;
484    }
485
486    #region exception class
487    [Serializable]
488    public class ConversionException : Exception {
489
490      public ConversionException() {
491      }
492
493      public ConversionException(string message) : base(message) {
494      }
495
496      public ConversionException(string message, Exception inner) : base(message, inner) {
497      }
498
499      protected ConversionException(
500        SerializationInfo info,
501        StreamingContext context) : base(info, context) {
502      }
503    }
504    #endregion
505  }
506}
Note: See TracBrowser for help on using the repository browser.