Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeVectorInterpreter.cs @ 17786

Last change on this file since 17786 was 17786, checked in by pfleck, 3 years ago

#3040 Worked in DiffSharp for constant-opt.

File size: 34.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HEAL.Attic;
31using MathNet.Numerics;
32using MathNet.Numerics.Statistics;
33using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
36  [StorableType("DE68A1D9-5AFC-4DDD-AB62-29F3B8FC28E0")]
37  [Item("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.")]
38  public class SymbolicDataAnalysisExpressionTreeVectorInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
39    [StorableType("2612504E-AD5F-4AE2-B60E-98A5AB59E164")]
40    public enum Aggregation {
41      Mean,
42      Median,
43      Sum,
44      First,
45      L1Norm,
46      L2Norm,
47      NaN,
48      Exception
49    }
50    public static double Aggregate(Aggregation aggregation, DoubleVector vector) {
51      switch (aggregation) {
52        case Aggregation.Mean: return Statistics.Mean(vector);
53        case Aggregation.Median: return Statistics.Median(vector);
54        case Aggregation.Sum: return vector.Sum();
55        case Aggregation.First: return vector.First();
56        case Aggregation.L1Norm: return vector.L1Norm();
57        case Aggregation.L2Norm: return vector.L2Norm();
58        case Aggregation.NaN: return double.NaN;
59        case Aggregation.Exception: throw new InvalidOperationException("Result of the tree is not a scalar.");
60        default: throw new ArgumentOutOfRangeException(nameof(aggregation), aggregation, null);
61      }
62    }
63
64    [StorableType("73DCBB45-916F-4139-8ADC-57BA610A1B66")]
65    public enum VectorLengthStrategy {
66      ExceptionIfDifferent,
67      FillShorterWithNaN,
68      FillShorterWithNeutralElement,
69      CutLonger,
70      ResampleToLonger,
71      ResampleToShorter,
72      CycleShorter
73    }
74
75    #region Implementation VectorLengthStrategy
76    public static (DoubleVector, DoubleVector) ExceptionIfDifferent(DoubleVector lhs, DoubleVector rhs) {
77      if (lhs.Count != rhs.Count)
78        throw new InvalidOperationException($"Vector Lengths incompatible ({lhs.Count} vs. {rhs.Count}");
79      return (lhs, rhs);
80    }
81
82    public static (DoubleVector, DoubleVector) FillShorter(DoubleVector lhs, DoubleVector rhs, double fillElement) {
83      var targetLength = Math.Max(lhs.Count, rhs.Count);
84
85      DoubleVector PadVector(DoubleVector v) {
86        if (v.Count == targetLength) return v;
87        var p = DoubleVector.Build.Dense(targetLength, fillElement);
88        v.CopySubVectorTo(p, 0, 0, v.Count);
89        return p;
90      }
91
92      return (PadVector(lhs), PadVector(rhs));
93    }
94
95    public static (DoubleVector, DoubleVector) CutLonger(DoubleVector lhs, DoubleVector rhs) {
96      var targetLength = Math.Min(lhs.Count, rhs.Count);
97
98      DoubleVector CutVector(DoubleVector v) {
99        if (v.Count == targetLength) return v;
100        return v.SubVector(0, targetLength);
101      }
102
103      return (CutVector(lhs), CutVector(rhs));
104    }
105
106    private static DoubleVector ResampleToLength(DoubleVector v, int targetLength) {
107      if (v.Count == targetLength) return v;
108
109      var indices = Enumerable.Range(0, v.Count).Select(x => (double)x);
110      var interpolation = Interpolate.Linear(indices, v);
111
112      var resampledIndices = Enumerable.Range(0, targetLength).Select(i => (double)i / targetLength * v.Count);
113      var interpolatedValues = resampledIndices.Select(interpolation.Interpolate);
114
115      return DoubleVector.Build.DenseOfEnumerable(interpolatedValues);
116    }
117    public static (DoubleVector, DoubleVector) ResampleToLonger(DoubleVector lhs, DoubleVector rhs) {
118      var maxLength = Math.Max(lhs.Count, rhs.Count);
119      return (ResampleToLength(lhs, maxLength), ResampleToLength(rhs, maxLength));
120    }
121    public static (DoubleVector, DoubleVector) ResampleToShorter(DoubleVector lhs, DoubleVector rhs) {
122      var minLength = Math.Min(lhs.Count, rhs.Count);
123      return (ResampleToLength(lhs, minLength), ResampleToLength(rhs, minLength));
124    }
125
126    public static (DoubleVector, DoubleVector) CycleShorter(DoubleVector lhs, DoubleVector rhs) {
127      var targetLength = Math.Max(lhs.Count, rhs.Count);
128
129      DoubleVector CycleVector(DoubleVector v) {
130        if (v.Count == targetLength) return v;
131        var cycledValues = Enumerable.Range(0, targetLength).Select(i => v[i % v.Count]);
132        return DoubleVector.Build.DenseOfEnumerable(cycledValues);
133      }
134
135      return (CycleVector(lhs), CycleVector(rhs));
136    }
137    #endregion
138
139    public static (DoubleVector lhs, DoubleVector rhs) ApplyVectorLengthStrategy(VectorLengthStrategy strategy, DoubleVector lhs, DoubleVector rhs,
140      double neutralElement = double.NaN) {
141
142      switch (strategy) {
143        case VectorLengthStrategy.ExceptionIfDifferent: return ExceptionIfDifferent(lhs, rhs);
144        case VectorLengthStrategy.FillShorterWithNaN: return FillShorter(lhs, rhs, double.NaN);
145        case VectorLengthStrategy.FillShorterWithNeutralElement: return FillShorter(lhs, rhs, neutralElement);
146        case VectorLengthStrategy.CutLonger: return CutLonger(lhs, rhs);
147        case VectorLengthStrategy.ResampleToLonger: return ResampleToLonger(lhs, rhs);
148        case VectorLengthStrategy.ResampleToShorter: return ResampleToShorter(lhs, rhs);
149        case VectorLengthStrategy.CycleShorter: return CycleShorter(lhs, rhs);
150        default: throw new ArgumentOutOfRangeException(nameof(strategy), strategy, null);
151      }
152    }
153
154    #region Aggregation Symbols
155    private static Type[] AggregationSymbols = new[] {
156      typeof(Sum), typeof(Mean), typeof(Length), typeof(StandardDeviation), typeof(Variance),
157      typeof(EuclideanDistance), typeof(Covariance)
158    };
159    #endregion
160
161    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
162    private const string FinalAggregationParameterName = "FinalAggregation";
163    private const string DifferentVectorLengthStrategyParameterName = "DifferentVectorLengthStrategy";
164
165    public override bool CanChangeName {
166      get { return false; }
167    }
168
169    public override bool CanChangeDescription {
170      get { return false; }
171    }
172
173    #region parameter properties
174    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
175      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
176    }
177    public IFixedValueParameter<EnumValue<Aggregation>> FinalAggregationParameter {
178      get { return (IFixedValueParameter<EnumValue<Aggregation>>)Parameters[FinalAggregationParameterName]; }
179    }
180    public IFixedValueParameter<EnumValue<VectorLengthStrategy>> DifferentVectorLengthStrategyParameter {
181      get { return (IFixedValueParameter<EnumValue<VectorLengthStrategy>>)Parameters[DifferentVectorLengthStrategyParameterName]; }
182    }
183    #endregion
184
185    #region properties
186    public int EvaluatedSolutions {
187      get { return EvaluatedSolutionsParameter.Value.Value; }
188      set { EvaluatedSolutionsParameter.Value.Value = value; }
189    }
190    public Aggregation FinalAggregation {
191      get { return FinalAggregationParameter.Value.Value; }
192      set { FinalAggregationParameter.Value.Value = value; }
193    }
194    public VectorLengthStrategy DifferentVectorLengthStrategy {
195      get { return DifferentVectorLengthStrategyParameter.Value.Value; }
196      set { DifferentVectorLengthStrategyParameter.Value.Value = value; }
197    }
198    #endregion
199
200    [StorableConstructor]
201    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(StorableConstructorFlag _) : base(_) { }
202
203    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(SymbolicDataAnalysisExpressionTreeVectorInterpreter original, Cloner cloner)
204      : base(original, cloner) { }
205
206    public override IDeepCloneable Clone(Cloner cloner) {
207      return new SymbolicDataAnalysisExpressionTreeVectorInterpreter(this, cloner);
208    }
209
210    public SymbolicDataAnalysisExpressionTreeVectorInterpreter()
211      : this("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.") {
212    }
213
214    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(string name, string description)
215      : base(name, description) {
216      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
217      Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
218      Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
219    }
220
221    [StorableHook(HookType.AfterDeserialization)]
222    private void AfterDeserialization() {
223      if (!Parameters.ContainsKey(FinalAggregationParameterName)) {
224        Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
225      }
226      if (!Parameters.ContainsKey(DifferentVectorLengthStrategyParameterName)) {
227        Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
228      }
229    }
230
231    #region IStatefulItem
232    public void InitializeState() {
233      EvaluatedSolutions = 0;
234    }
235
236    public void ClearState() { }
237    #endregion
238
239    private readonly object syncRoot = new object();
240    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
241      lock (syncRoot) {
242        EvaluatedSolutions++; // increment the evaluated solutions counter
243      }
244      var state = PrepareInterpreterState(tree, dataset);
245
246      foreach (var rowEnum in rows) {
247        int row = rowEnum;
248        var result = Evaluate(dataset, ref row, state);
249        if (result.IsScalar)
250          yield return result.Scalar;
251        else if (result.IsVector) {
252          yield return Aggregate(FinalAggregation, result.Vector);
253        } else
254          yield return double.NaN;
255        state.Reset();
256      }
257    }
258
259    public IEnumerable<Dictionary<ISymbolicExpressionTreeNode, EvaluationResult>> GetIntermediateNodeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
260      var state = PrepareInterpreterState(tree, dataset);
261
262      foreach (var rowEnum in rows) {
263        int row = rowEnum;
264        var traceDict = new Dictionary<ISymbolicExpressionTreeNode, EvaluationResult>();
265        var result = Evaluate(dataset, ref row, state, traceDict);
266        traceDict.Add(tree.Root.GetSubtree(0), result); // Add StartSymbol
267        yield return traceDict;
268        state.Reset();
269      }
270    }
271
272    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
273      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
274      int necessaryArgStackSize = 0;
275      foreach (Instruction instr in code) {
276        if (instr.opCode == OpCodes.Variable) {
277          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
278          if (dataset.VariableHasType<double>(variableTreeNode.VariableName))
279            instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
280          else if (dataset.VariableHasType<DoubleVector>(variableTreeNode.VariableName))
281            instr.data = dataset.GetReadOnlyDoubleVectorValues(variableTreeNode.VariableName);
282          else throw new NotSupportedException($"Type of variable {variableTreeNode.VariableName} is not supported.");
283        } else if (instr.opCode == OpCodes.FactorVariable) {
284          var factorTreeNode = instr.dynamicNode as FactorVariableTreeNode;
285          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
286        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
287          var factorTreeNode = instr.dynamicNode as BinaryFactorVariableTreeNode;
288          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
289        } else if (instr.opCode == OpCodes.LagVariable) {
290          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
291          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
292        } else if (instr.opCode == OpCodes.VariableCondition) {
293          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
294          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
295        } else if (instr.opCode == OpCodes.Call) {
296          necessaryArgStackSize += instr.nArguments + 1;
297        }
298      }
299      return new InterpreterState(code, necessaryArgStackSize);
300    }
301
302
303    public struct EvaluationResult {
304      public double Scalar { get; }
305      public bool IsScalar => !double.IsNaN(Scalar);
306
307      public DoubleVector Vector { get; }
308      public bool IsVector => !(Vector.Count == 1 && double.IsNaN(Vector[0]));
309
310      public bool IsNaN => !IsScalar && !IsVector;
311
312      public EvaluationResult(double scalar) {
313        Scalar = scalar;
314        Vector = NaNVector;
315      }
316      public EvaluationResult(DoubleVector vector) {
317        if (vector == null) throw new ArgumentNullException(nameof(vector));
318        Vector = vector;
319        Scalar = double.NaN;
320      }
321
322      public override string ToString() {
323        if (IsScalar) return Scalar.ToString();
324        if (IsVector) return Vector.ToVectorString();
325        return "NaN";
326      }
327
328      private static readonly DoubleVector NaNVector = DoubleVector.Build.Dense(1, double.NaN);
329      public static readonly EvaluationResult NaN = new EvaluationResult(double.NaN);
330    }
331
332    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
333      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
334      Func<double, double, double> ssFunc = null,
335      Func<double, DoubleVector, DoubleVector> svFunc = null,
336      Func<DoubleVector, double, DoubleVector> vsFunc = null,
337      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
338
339      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
340      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
341      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
342      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
343        if (lhs.Vector.Count == rhs.Vector.Count) {
344          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
345        } else {
346          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
347          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
348        }
349      }
350      return EvaluationResult.NaN;
351    }
352
353    private static EvaluationResult FunctionApply(EvaluationResult val,
354      Func<double, double> sFunc = null,
355      Func<DoubleVector, DoubleVector> vFunc = null) {
356      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
357      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
358      return EvaluationResult.NaN;
359    }
360    private static EvaluationResult AggregateApply(EvaluationResult val,
361      Func<double, double> sFunc = null,
362      Func<DoubleVector, double> vFunc = null) {
363      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
364      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
365      return EvaluationResult.NaN;
366    }
367
368    private static EvaluationResult WindowedAggregateApply(EvaluationResult val, WindowedSymbolTreeNode node,
369      Func<double, double> sFunc = null,
370      Func<DoubleVector, double> vFunc = null) {
371
372      var offset = node.Offset;
373      var length = node.Length;
374
375      DoubleVector SubVector(DoubleVector v) {
376        int index = (int)(offset * v.Count);
377        int count = (int)(length * (v.Count - index));
378        return v.SubVector(index, count);
379      };
380
381      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
382      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(SubVector(val.Vector)));
383      return EvaluationResult.NaN;
384    }
385    private static EvaluationResult WindowedFunctionApply(EvaluationResult val, IWindowedSymbolTreeNode node,
386      Func<double, double> sFunc = null,
387      Func<DoubleVector, DoubleVector> vFunc = null) {
388      var offset = node.Offset;
389      var length = node.Length;
390
391      DoubleVector SubVector(DoubleVector v) {
392        int index = (int)(offset * v.Count);
393        int count = (int)(length * (v.Count - index));
394        return v.SubVector(index, count);
395      };
396
397      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
398      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(SubVector(val.Vector)));
399      return EvaluationResult.NaN;
400    }
401
402    private static EvaluationResult AggregateMultipleApply(EvaluationResult lhs, EvaluationResult rhs,
403      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
404      Func<double, double, double> ssFunc = null,
405      Func<double, DoubleVector, double> svFunc = null,
406      Func<DoubleVector, double, double> vsFunc = null,
407      Func<DoubleVector, DoubleVector, double> vvFunc = null) {
408      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
409      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
410      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
411      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
412        if (lhs.Vector.Count == rhs.Vector.Count) {
413          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
414        } else {
415          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
416          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
417        }
418      }
419      return EvaluationResult.NaN;
420    }
421
422    public virtual Type GetNodeType(ISymbolicExpressionTreeNode node) {
423      if (node.DataType != null)
424        return node.DataType;
425
426      if (AggregationSymbols.Contains(node.Symbol.GetType()))
427        return typeof(double);
428
429      var argumentTypes = node.Subtrees.Select(GetNodeType);
430      if (argumentTypes.Any(t => t == typeof(DoubleVector)))
431        return typeof(DoubleVector);
432
433      return typeof(double);
434    }
435
436
437    public virtual EvaluationResult Evaluate(IDataset dataset, ref int row, InterpreterState state,
438      IDictionary<ISymbolicExpressionTreeNode, EvaluationResult> traceDict = null) {
439
440      void TraceEvaluation(Instruction instr, EvaluationResult result) {
441        traceDict?.Add(instr.dynamicNode, result);
442      }
443
444      Instruction currentInstr = state.NextInstruction();
445      switch (currentInstr.opCode) {
446        case OpCodes.Add: {
447            var cur = Evaluate(dataset, ref row, state, traceDict);
448            for (int i = 1; i < currentInstr.nArguments; i++) {
449              var op = Evaluate(dataset, ref row, state, traceDict);
450              cur = ArithmeticApply(cur, op,
451                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
452                (s1, s2) => s1 + s2,
453                (s1, v2) => s1 + v2,
454                (v1, s2) => v1 + s2,
455                (v1, v2) => v1 + v2);
456            }
457            TraceEvaluation(currentInstr, cur);
458            return cur;
459          }
460        case OpCodes.Sub: {
461            var cur = Evaluate(dataset, ref row, state, traceDict);
462            for (int i = 1; i < currentInstr.nArguments; i++) {
463              var op = Evaluate(dataset, ref row, state, traceDict);
464              cur = ArithmeticApply(cur, op,
465                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
466                (s1, s2) => s1 - s2,
467                (s1, v2) => s1 - v2,
468                (v1, s2) => v1 - s2,
469                (v1, v2) => v1 - v2);
470            }
471            if (currentInstr.nArguments == 1)
472              cur = FunctionApply(cur,
473                s => -s,
474                v => -v);
475            TraceEvaluation(currentInstr, cur);
476            return cur;
477          }
478        case OpCodes.Mul: {
479            var cur = Evaluate(dataset, ref row, state, traceDict);
480            for (int i = 1; i < currentInstr.nArguments; i++) {
481              var op = Evaluate(dataset, ref row, state, traceDict);
482              cur = ArithmeticApply(cur, op,
483                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
484                (s1, s2) => s1 * s2,
485                (s1, v2) => s1 * v2,
486                (v1, s2) => v1 * s2,
487                (v1, v2) => v1.PointwiseMultiply(v2));
488            }
489            TraceEvaluation(currentInstr, cur);
490            return cur;
491          }
492        case OpCodes.Div: {
493            var cur = Evaluate(dataset, ref row, state, traceDict);
494            for (int i = 1; i < currentInstr.nArguments; i++) {
495              var op = Evaluate(dataset, ref row, state, traceDict);
496              cur = ArithmeticApply(cur, op,
497                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
498                (s1, s2) => s1 / s2,
499                (s1, v2) => s1 / v2,
500                (v1, s2) => v1 / s2,
501                (v1, v2) => v1 / v2);
502            }
503            if (currentInstr.nArguments == 1)
504              cur = FunctionApply(cur,
505                s => 1 / s,
506                v => 1 / v);
507            TraceEvaluation(currentInstr, cur);
508            return cur;
509          }
510        case OpCodes.Absolute: {
511            var cur = Evaluate(dataset, ref row, state, traceDict);
512            cur = FunctionApply(cur, Math.Abs, DoubleVector.Abs);
513            TraceEvaluation(currentInstr, cur);
514            return cur;
515          }
516        case OpCodes.Tanh: {
517            var cur = Evaluate(dataset, ref row, state, traceDict);
518            cur = FunctionApply(cur, Math.Tanh, DoubleVector.Tanh);
519            TraceEvaluation(currentInstr, cur);
520            return cur;
521          }
522        case OpCodes.Cos: {
523            var cur = Evaluate(dataset, ref row, state, traceDict);
524            cur = FunctionApply(cur, Math.Cos, DoubleVector.Cos);
525            TraceEvaluation(currentInstr, cur);
526            return cur;
527          }
528        case OpCodes.Sin: {
529            var cur = Evaluate(dataset, ref row, state, traceDict);
530            cur = FunctionApply(cur, Math.Sin, DoubleVector.Sin);
531            TraceEvaluation(currentInstr, cur);
532            return cur;
533          }
534        case OpCodes.Tan: {
535            var cur = Evaluate(dataset, ref row, state, traceDict);
536            cur = FunctionApply(cur, Math.Tan, DoubleVector.Tan);
537            TraceEvaluation(currentInstr, cur);
538            return cur;
539          }
540        case OpCodes.Square: {
541            var cur = Evaluate(dataset, ref row, state, traceDict);
542            cur = FunctionApply(cur,
543              s => Math.Pow(s, 2),
544              v => v.PointwisePower(2));
545            TraceEvaluation(currentInstr, cur);
546            return cur;
547          }
548        case OpCodes.Cube: {
549            var cur = Evaluate(dataset, ref row, state, traceDict);
550            cur = FunctionApply(cur,
551              s => Math.Pow(s, 3),
552              v => v.PointwisePower(3));
553            TraceEvaluation(currentInstr, cur);
554            return cur;
555          }
556        case OpCodes.Power: {
557            var x = Evaluate(dataset, ref row, state, traceDict);
558            var y = Evaluate(dataset, ref row, state, traceDict);
559            var cur = ArithmeticApply(x, y,
560              (lhs, rhs) => lhs.Count < rhs.Count
561                ? CutLonger(lhs, rhs)
562                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
563              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
564              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
565              (v1, s2) => v1.PointwisePower(Math.Round(s2)),
566              (v1, v2) => v1.PointwisePower(DoubleVector.Round(v2)));
567            TraceEvaluation(currentInstr, cur);
568            return cur;
569          }
570        case OpCodes.SquareRoot: {
571            var cur = Evaluate(dataset, ref row, state, traceDict);
572            cur = FunctionApply(cur,
573              s => Math.Sqrt(s),
574              v => DoubleVector.Sqrt(v));
575            TraceEvaluation(currentInstr, cur);
576            return cur;
577          }
578        case OpCodes.CubeRoot: {
579            var cur = Evaluate(dataset, ref row, state, traceDict);
580            cur = FunctionApply(cur,
581              s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0),
582              v => v.Map(s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0)));
583            TraceEvaluation(currentInstr, cur);
584            return cur;
585          }
586        case OpCodes.Root: {
587            var x = Evaluate(dataset, ref row, state, traceDict);
588            var y = Evaluate(dataset, ref row, state, traceDict);
589            var cur = ArithmeticApply(x, y,
590              (lhs, rhs) => lhs.Count < rhs.Count
591                ? CutLonger(lhs, rhs)
592                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
593              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
594              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
595              (v1, s2) => v1.PointwisePower(1.0 / Math.Round(s2)),
596              (v1, v2) => v1.PointwisePower(1.0 / DoubleVector.Round(v2)));
597            TraceEvaluation(currentInstr, cur);
598            return cur;
599          }
600        case OpCodes.Exp: {
601            var cur = Evaluate(dataset, ref row, state, traceDict);
602            cur = FunctionApply(cur,
603              s => Math.Exp(s),
604              v => DoubleVector.Exp(v));
605            TraceEvaluation(currentInstr, cur);
606            return cur;
607          }
608        case OpCodes.Log: {
609            var cur = Evaluate(dataset, ref row, state, traceDict);
610            cur = FunctionApply(cur,
611              s => Math.Log(s),
612              v => DoubleVector.Log(v));
613            TraceEvaluation(currentInstr, cur);
614            return cur;
615          }
616        case OpCodes.Sum: {
617            var cur = Evaluate(dataset, ref row, state, traceDict);
618            cur = AggregateApply(cur,
619              s => s,
620              v => v.Sum());
621            TraceEvaluation(currentInstr, cur);
622            return cur;
623          }
624        case OpCodes.Mean: {
625            var cur = Evaluate(dataset, ref row, state, traceDict);
626            cur = AggregateApply(cur,
627              s => s,
628              v => Statistics.Mean(v));
629            TraceEvaluation(currentInstr, cur);
630            return cur;
631          }
632        case OpCodes.StandardDeviation: {
633            var cur = Evaluate(dataset, ref row, state, traceDict);
634            cur = AggregateApply(cur,
635              s => 0,
636              v => Statistics.PopulationStandardDeviation(v));
637            TraceEvaluation(currentInstr, cur);
638            return cur;
639          }
640        case OpCodes.Length: {
641            var cur = Evaluate(dataset, ref row, state, traceDict);
642            cur = AggregateApply(cur,
643              s => 1,
644              v => v.Count);
645            TraceEvaluation(currentInstr, cur);
646            return cur;
647          }
648        case OpCodes.Min: {
649            var cur = Evaluate(dataset, ref row, state, traceDict);
650            cur = AggregateApply(cur,
651              s => s,
652              v => Statistics.Minimum(v));
653            TraceEvaluation(currentInstr, cur);
654            return cur;
655          }
656        case OpCodes.Max: {
657            var cur = Evaluate(dataset, ref row, state, traceDict);
658            cur = AggregateApply(cur,
659              s => s,
660              v => Statistics.Maximum(v));
661            TraceEvaluation(currentInstr, cur);
662            return cur;
663          }
664        case OpCodes.Variance: {
665            var cur = Evaluate(dataset, ref row, state, traceDict);
666            cur = AggregateApply(cur,
667              s => 0,
668              v => Statistics.PopulationVariance(v));
669            TraceEvaluation(currentInstr, cur);
670            return cur;
671          }
672        case OpCodes.Skewness: {
673            var cur = Evaluate(dataset, ref row, state, traceDict);
674            cur = AggregateApply(cur,
675              s => double.NaN,
676              v => Statistics.PopulationSkewness(v));
677            TraceEvaluation(currentInstr, cur);
678            return cur;
679          }
680        case OpCodes.Kurtosis: {
681            var cur = Evaluate(dataset, ref row, state, traceDict);
682            cur = AggregateApply(cur,
683              s => double.NaN,
684              v => Statistics.PopulationKurtosis(v));
685            TraceEvaluation(currentInstr, cur);
686            return cur;
687          }
688        case OpCodes.EuclideanDistance: {
689            var x1 = Evaluate(dataset, ref row, state, traceDict);
690            var x2 = Evaluate(dataset, ref row, state, traceDict);
691            var cur = AggregateMultipleApply(x1, x2,
692              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
693              (s1, s2) => s1 - s2,
694              (s1, v2) => Math.Sqrt((s1 - v2).PointwisePower(2).Sum()),
695              (v1, s2) => Math.Sqrt((v1 - s2).PointwisePower(2).Sum()),
696              (v1, v2) => Math.Sqrt((v1 - v2).PointwisePower(2).Sum()));
697            TraceEvaluation(currentInstr, cur);
698            return cur;
699          }
700        case OpCodes.Covariance: {
701            var x1 = Evaluate(dataset, ref row, state, traceDict);
702            var x2 = Evaluate(dataset, ref row, state, traceDict);
703            var cur = AggregateMultipleApply(x1, x2,
704              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
705              (s1, s2) => 0,
706              (s1, v2) => 0,
707              (v1, s2) => 0,
708              (v1, v2) => Statistics.PopulationCovariance(v1, v2));
709            TraceEvaluation(currentInstr, cur);
710            return cur;
711          }
712        case OpCodes.SubVector: {
713            var cur = Evaluate(dataset, ref row, state, traceDict);
714            return WindowedFunctionApply(cur, (WindowedSymbolTreeNode)currentInstr.dynamicNode,
715              s => s,
716              v => v);
717          }
718        case OpCodes.Variable: {
719            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
720            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
721            if (currentInstr.data is IList<double> doubleList) {
722              var cur = new EvaluationResult(doubleList[row] * variableTreeNode.Weight);
723              TraceEvaluation(currentInstr, cur);
724              return cur;
725            }
726            if (currentInstr.data is IList<DoubleVector> doubleVectorList) {
727              var cur = new EvaluationResult(doubleVectorList[row] * variableTreeNode.Weight);
728              TraceEvaluation(currentInstr, cur);
729              return cur;
730            }
731            throw new NotSupportedException($"Unsupported type of variable: {currentInstr.data.GetType().GetPrettyName()}");
732          }
733        case OpCodes.BinaryFactorVariable: {
734            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
735            var factorVarTreeNode = currentInstr.dynamicNode as BinaryFactorVariableTreeNode;
736            var cur = new EvaluationResult(((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0);
737            TraceEvaluation(currentInstr, cur);
738            return cur;
739          }
740        case OpCodes.FactorVariable: {
741            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
742            var factorVarTreeNode = currentInstr.dynamicNode as FactorVariableTreeNode;
743            var cur = new EvaluationResult(factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]));
744            TraceEvaluation(currentInstr, cur);
745            return cur;
746          }
747        case OpCodes.Constant: {
748            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
749            var cur = new EvaluationResult(constTreeNode.Value);
750            TraceEvaluation(currentInstr, cur);
751            return cur;
752          }
753
754        default:
755          throw new NotSupportedException($"Unsupported OpCode: {currentInstr.opCode}");
756      }
757    }
758  }
759}
Note: See TracBrowser for help on using the repository browser.