Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeVectorInterpreter.cs @ 17721

Last change on this file since 17721 was 17721, checked in by pfleck, 4 years ago

#3040 First draft of different-vector-length strategies (cut, fill, resample, cycle, ...)

File size: 29.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HEAL.Attic;
31using MathNet.Numerics;
32using MathNet.Numerics.Statistics;
33using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
36  [StorableType("DE68A1D9-5AFC-4DDD-AB62-29F3B8FC28E0")]
37  [Item("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.")]
38  public class SymbolicDataAnalysisExpressionTreeVectorInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
39    [StorableType("2612504E-AD5F-4AE2-B60E-98A5AB59E164")]
40    public enum Aggregation {
41      Mean,
42      Median,
43      Sum,
44      First,
45      L1Norm,
46      L2Norm,
47      NaN,
48      Exception
49    }
50    public static double Aggregate(Aggregation aggregation, DoubleVector vector) {
51      switch (aggregation) {
52        case Aggregation.Mean: return Statistics.Mean(vector);
53        case Aggregation.Median: return Statistics.Median(vector);
54        case Aggregation.Sum: return vector.Sum();
55        case Aggregation.First: return vector.First();
56        case Aggregation.L1Norm: return vector.L1Norm();
57        case Aggregation.L2Norm: return vector.L2Norm();
58        case Aggregation.NaN: return double.NaN;
59        case Aggregation.Exception: throw new InvalidOperationException("Result of the tree is not a scalar.");
60        default: throw new ArgumentOutOfRangeException(nameof(aggregation), aggregation, null);
61      }
62    }
63
64    [StorableType("73DCBB45-916F-4139-8ADC-57BA610A1B66")]
65    public enum VectorLengthStrategy {
66      ExceptionIfDifferent,
67      FillShorterWithNaN,
68      FillShorterWithNeutralElement,
69      CutLonger,
70      ResampleToLonger,
71      ResampleToShorter,
72      CycleShorter
73    }
74
75    #region Implementation VectorLengthStrategy
76    public static (DoubleVector, DoubleVector) ExceptionIfDifferent(DoubleVector lhs, DoubleVector rhs) {
77      if (lhs.Count != rhs.Count)
78        throw new InvalidOperationException($"Vector Lengths incompatible ({lhs.Count} vs. {rhs.Count}");
79      return (lhs, rhs);
80    }
81
82    public static (DoubleVector, DoubleVector) FillShorter(DoubleVector lhs, DoubleVector rhs, double fillElement) {
83      var targetLength = Math.Max(lhs.Count, rhs.Count);
84
85      DoubleVector PadVector(DoubleVector v) {
86        if (v.Count == targetLength) return v;
87        var p = DoubleVector.Build.Dense(targetLength, fillElement);
88        v.CopySubVectorTo(p, 0, 0, v.Count);
89        return p;
90      }
91
92      return (PadVector(lhs), PadVector(rhs));
93    }
94
95    public static (DoubleVector, DoubleVector) CutLonger(DoubleVector lhs, DoubleVector rhs) {
96      var targetLength = Math.Min(lhs.Count, rhs.Count);
97
98      DoubleVector CutVector(DoubleVector v) {
99        if (v.Count == targetLength) return v;
100        return v.SubVector(0, targetLength);
101      }
102
103      return (CutVector(lhs), CutVector(rhs));
104    }
105
106    private static DoubleVector ResampleToLength(DoubleVector v, int targetLength) {
107      if (v.Count == targetLength) return v;
108
109      var indices = Enumerable.Range(0, v.Count).Select(x => (double)x);
110      var interpolation = Interpolate.Linear(indices, v);
111
112      var resampledIndices = Enumerable.Range(0, targetLength).Select(i => (double)i / targetLength * v.Count);
113      var interpolatedValues = resampledIndices.Select(interpolation.Interpolate);
114
115      return DoubleVector.Build.DenseOfEnumerable(interpolatedValues);
116    }
117    public static (DoubleVector, DoubleVector) ResampleToLonger(DoubleVector lhs, DoubleVector rhs) {
118      var maxLength = Math.Max(lhs.Count, rhs.Count);
119      return (ResampleToLength(lhs, maxLength), ResampleToLength(rhs, maxLength));
120    }
121    public static (DoubleVector, DoubleVector) ResampleToShorter(DoubleVector lhs, DoubleVector rhs) {
122      var minLength = Math.Min(lhs.Count, rhs.Count);
123      return (ResampleToLength(lhs, minLength), ResampleToLength(rhs, minLength));
124    }
125
126    public static (DoubleVector, DoubleVector) CycleShorter(DoubleVector lhs, DoubleVector rhs) {
127      var targetLength = Math.Max(lhs.Count, rhs.Count);
128
129      DoubleVector CycleVector(DoubleVector v) {
130        if (v.Count == targetLength) return v;
131        var cycledValues = Enumerable.Range(0, targetLength).Select(i => v[i % v.Count]);
132        return DoubleVector.Build.DenseOfEnumerable(cycledValues);
133      }
134
135      return (CycleVector(lhs), CycleVector(rhs));
136    }
137    #endregion
138
139    public static (DoubleVector lhs, DoubleVector rhs) ApplyVectorLengthStrategy(VectorLengthStrategy strategy, DoubleVector lhs, DoubleVector rhs,
140      double neutralElement = double.NaN) {
141
142      switch (strategy) {
143        case VectorLengthStrategy.ExceptionIfDifferent: return ExceptionIfDifferent(lhs, rhs);
144        case VectorLengthStrategy.FillShorterWithNaN: return FillShorter(lhs, rhs, double.NaN);
145        case VectorLengthStrategy.FillShorterWithNeutralElement: return FillShorter(lhs, rhs, neutralElement);
146        case VectorLengthStrategy.CutLonger: return CutLonger(lhs, rhs);
147        case VectorLengthStrategy.ResampleToLonger: return ResampleToLonger(lhs, rhs);
148        case VectorLengthStrategy.ResampleToShorter: return ResampleToShorter(lhs, rhs);
149        case VectorLengthStrategy.CycleShorter: return CycleShorter(lhs, rhs);
150        default: throw new ArgumentOutOfRangeException(nameof(strategy), strategy, null);
151      }
152    }
153
154    #region Aggregation Symbols
155    private static Type[] AggregationSymbols = new[] {
156      typeof(Sum), typeof(Mean), typeof(Length), typeof(StandardDeviation), typeof(Variance),
157      typeof(EuclideanDistance), typeof(Covariance)
158    };
159    #endregion
160
161    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
162    private const string FinalAggregationParameterName = "FinalAggregation";
163    private const string DifferentVectorLengthStrategyParameterName = "DifferentVectorLengthStrategy";
164
165    public override bool CanChangeName {
166      get { return false; }
167    }
168
169    public override bool CanChangeDescription {
170      get { return false; }
171    }
172
173    #region parameter properties
174    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
175      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
176    }
177    public IFixedValueParameter<EnumValue<Aggregation>> FinalAggregationParameter {
178      get { return (IFixedValueParameter<EnumValue<Aggregation>>)Parameters[FinalAggregationParameterName]; }
179    }
180    public IFixedValueParameter<EnumValue<VectorLengthStrategy>> DifferentVectorLengthStrategyParameter {
181      get { return (IFixedValueParameter<EnumValue<VectorLengthStrategy>>)Parameters[DifferentVectorLengthStrategyParameterName]; }
182    }
183    #endregion
184
185    #region properties
186    public int EvaluatedSolutions {
187      get { return EvaluatedSolutionsParameter.Value.Value; }
188      set { EvaluatedSolutionsParameter.Value.Value = value; }
189    }
190    public Aggregation FinalAggregation {
191      get { return FinalAggregationParameter.Value.Value; }
192      set { FinalAggregationParameter.Value.Value = value; }
193    }
194    public VectorLengthStrategy DifferentVectorLengthStrategy {
195      get { return DifferentVectorLengthStrategyParameter.Value.Value; }
196      set { DifferentVectorLengthStrategyParameter.Value.Value = value; }
197    }
198    #endregion
199
200    [StorableConstructor]
201    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(StorableConstructorFlag _) : base(_) { }
202
203    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(SymbolicDataAnalysisExpressionTreeVectorInterpreter original, Cloner cloner)
204      : base(original, cloner) { }
205
206    public override IDeepCloneable Clone(Cloner cloner) {
207      return new SymbolicDataAnalysisExpressionTreeVectorInterpreter(this, cloner);
208    }
209
210    public SymbolicDataAnalysisExpressionTreeVectorInterpreter()
211      : this("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.") {
212    }
213
214    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(string name, string description)
215      : base(name, description) {
216      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
217      Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
218      Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
219    }
220
221    [StorableHook(HookType.AfterDeserialization)]
222    private void AfterDeserialization() {
223      if (!Parameters.ContainsKey(FinalAggregationParameterName)) {
224        Parameters.Add(new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean)));
225      }
226      if (!Parameters.ContainsKey(DifferentVectorLengthStrategyParameterName)) {
227        Parameters.Add(new FixedValueParameter<EnumValue<VectorLengthStrategy>>(DifferentVectorLengthStrategyParameterName, "", new EnumValue<VectorLengthStrategy>(VectorLengthStrategy.ExceptionIfDifferent)));
228      }
229    }
230
231    #region IStatefulItem
232    public void InitializeState() {
233      EvaluatedSolutions = 0;
234    }
235
236    public void ClearState() { }
237    #endregion
238
239    private readonly object syncRoot = new object();
240    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
241      lock (syncRoot) {
242        EvaluatedSolutions++; // increment the evaluated solutions counter
243      }
244      var state = PrepareInterpreterState(tree, dataset);
245
246      foreach (var rowEnum in rows) {
247        int row = rowEnum;
248        var result = Evaluate(dataset, ref row, state);
249        if (result.IsScalar)
250          yield return result.Scalar;
251        else if (result.IsVector) {
252          yield return Aggregate(FinalAggregation, result.Vector);
253        } else
254          yield return double.NaN;
255        state.Reset();
256      }
257    }
258
259    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
260      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
261      int necessaryArgStackSize = 0;
262      foreach (Instruction instr in code) {
263        if (instr.opCode == OpCodes.Variable) {
264          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
265          if (dataset.VariableHasType<double>(variableTreeNode.VariableName))
266            instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
267          else if (dataset.VariableHasType<DoubleVector>(variableTreeNode.VariableName))
268            instr.data = dataset.GetReadOnlyDoubleVectorValues(variableTreeNode.VariableName);
269          else throw new NotSupportedException($"Type of variable {variableTreeNode.VariableName} is not supported.");
270        } else if (instr.opCode == OpCodes.FactorVariable) {
271          var factorTreeNode = instr.dynamicNode as FactorVariableTreeNode;
272          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
273        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
274          var factorTreeNode = instr.dynamicNode as BinaryFactorVariableTreeNode;
275          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
276        } else if (instr.opCode == OpCodes.LagVariable) {
277          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
278          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
279        } else if (instr.opCode == OpCodes.VariableCondition) {
280          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
281          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
282        } else if (instr.opCode == OpCodes.Call) {
283          necessaryArgStackSize += instr.nArguments + 1;
284        }
285      }
286      return new InterpreterState(code, necessaryArgStackSize);
287    }
288
289
290    public struct EvaluationResult {
291      public double Scalar { get; }
292      public bool IsScalar => !double.IsNaN(Scalar);
293
294      public DoubleVector Vector { get; }
295      public bool IsVector => !(Vector.Count == 1 && double.IsNaN(Vector[0]));
296
297      public bool IsNaN => !IsScalar && !IsVector;
298
299      public EvaluationResult(double scalar) {
300        Scalar = scalar;
301        Vector = NaNVector;
302      }
303      public EvaluationResult(DoubleVector vector) {
304        if (vector == null) throw new ArgumentNullException(nameof(vector));
305        Vector = vector;
306        Scalar = double.NaN;
307      }
308
309      public override string ToString() {
310        if (IsScalar) return Scalar.ToString();
311        if (IsVector) return Vector.ToVectorString();
312        return "NaN";
313      }
314
315      private static readonly DoubleVector NaNVector = DoubleVector.Build.Dense(1, double.NaN);
316      public static readonly EvaluationResult NaN = new EvaluationResult(double.NaN);
317    }
318
319    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
320      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
321      Func<double, double, double> ssFunc = null,
322      Func<double, DoubleVector, DoubleVector> svFunc = null,
323      Func<DoubleVector, double, DoubleVector> vsFunc = null,
324      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
325
326      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
327      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
328      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
329      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
330        if (lhs.Vector.Count == rhs.Vector.Count) {
331          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
332        } else {
333          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
334          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
335        }
336      }
337      return EvaluationResult.NaN;
338    }
339
340    private static EvaluationResult FunctionApply(EvaluationResult val,
341      Func<double, double> sFunc = null,
342      Func<DoubleVector, DoubleVector> vFunc = null) {
343      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
344      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
345      return EvaluationResult.NaN;
346    }
347    private static EvaluationResult AggregateApply(EvaluationResult val,
348      Func<double, double> sFunc = null,
349      Func<DoubleVector, double> vFunc = null) {
350      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
351      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
352      return EvaluationResult.NaN;
353    }
354
355    private static EvaluationResult AggregateApply(EvaluationResult val, WindowedSymbolTreeNode node,
356      Func<double, double> sFunc = null,
357      Func<DoubleVector, double> vFunc = null) {
358
359      var offset = node.Offset;
360      var length = node.Length;
361
362      DoubleVector SubVector(DoubleVector v) {
363        int index = (int)(offset * v.Count);
364        int count = (int)(length * (v.Count - index));
365        return v.SubVector(index, count);
366      };
367
368      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
369      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(SubVector(val.Vector)));
370      return EvaluationResult.NaN;
371    }
372    private static EvaluationResult AggregateMultipleApply(EvaluationResult lhs, EvaluationResult rhs,
373      Func<DoubleVector, DoubleVector, (DoubleVector, DoubleVector)> lengthStrategy,
374      Func<double, double, double> ssFunc = null,
375      Func<double, DoubleVector, double> svFunc = null,
376      Func<DoubleVector, double, double> vsFunc = null,
377      Func<DoubleVector, DoubleVector, double> vvFunc = null) {
378      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
379      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
380      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
381      if (lhs.IsVector && rhs.IsVector && vvFunc != null) {
382        if (lhs.Vector.Count == rhs.Vector.Count) {
383          return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
384        } else {
385          var (lhsVector, rhsVector) = lengthStrategy(lhs.Vector, rhs.Vector);
386          return new EvaluationResult(vvFunc(lhsVector, rhsVector));
387        }
388      }
389      return EvaluationResult.NaN;
390    }
391
392    public virtual Type GetNodeType(ISymbolicExpressionTreeNode node) {
393      if (node.DataType != null)
394        return node.DataType;
395
396      if (AggregationSymbols.Contains(node.Symbol.GetType()))
397        return typeof(double);
398
399      var argumentTypes = node.Subtrees.Select(GetNodeType);
400      if (argumentTypes.Any(t => t == typeof(DoubleVector)))
401        return typeof(DoubleVector);
402
403      return typeof(double);
404    }
405
406    public virtual EvaluationResult Evaluate(IDataset dataset, ref int row, InterpreterState state) {
407      Instruction currentInstr = state.NextInstruction();
408      switch (currentInstr.opCode) {
409        case OpCodes.Add: {
410            var cur = Evaluate(dataset, ref row, state);
411            for (int i = 1; i < currentInstr.nArguments; i++) {
412              var op = Evaluate(dataset, ref row, state);
413              cur = ArithmeticApply(cur, op,
414                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
415                (s1, s2) => s1 + s2,
416                (s1, v2) => s1 + v2,
417                (v1, s2) => v1 + s2,
418                (v1, v2) => v1 + v2);
419            }
420            return cur;
421          }
422        case OpCodes.Sub: {
423            var cur = Evaluate(dataset, ref row, state);
424            for (int i = 1; i < currentInstr.nArguments; i++) {
425              var op = Evaluate(dataset, ref row, state);
426              cur = ArithmeticApply(cur, op,
427                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
428                (s1, s2) => s1 - s2,
429                (s1, v2) => s1 - v2,
430                (v1, s2) => v1 - s2,
431                (v1, v2) => v1 - v2);
432            }
433            return cur;
434          }
435        case OpCodes.Mul: {
436            var cur = Evaluate(dataset, ref row, state);
437            for (int i = 1; i < currentInstr.nArguments; i++) {
438              var op = Evaluate(dataset, ref row, state);
439              cur = ArithmeticApply(cur, op,
440                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
441                (s1, s2) => s1 * s2,
442                (s1, v2) => s1 * v2,
443                (v1, s2) => v1 * s2,
444                (v1, v2) => v1.PointwiseMultiply(v2));
445            }
446            return cur;
447          }
448        case OpCodes.Div: {
449            var cur = Evaluate(dataset, ref row, state);
450            for (int i = 1; i < currentInstr.nArguments; i++) {
451              var op = Evaluate(dataset, ref row, state);
452              cur = ArithmeticApply(cur, op,
453                (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
454                (s1, s2) => s1 / s2,
455                (s1, v2) => s1 / v2,
456                (v1, s2) => v1 / s2,
457                (v1, v2) => v1 / v2);
458            }
459            return cur;
460          }
461        case OpCodes.Absolute: {
462            var cur = Evaluate(dataset, ref row, state);
463            return FunctionApply(cur, Math.Abs, DoubleVector.Abs);
464          }
465        case OpCodes.Tanh: {
466            var cur = Evaluate(dataset, ref row, state);
467            return FunctionApply(cur, Math.Tanh, DoubleVector.Tanh);
468          }
469        case OpCodes.Cos: {
470            var cur = Evaluate(dataset, ref row, state);
471            return FunctionApply(cur, Math.Cos, DoubleVector.Cos);
472          }
473        case OpCodes.Sin: {
474            var cur = Evaluate(dataset, ref row, state);
475            return FunctionApply(cur, Math.Sin, DoubleVector.Sin);
476          }
477        case OpCodes.Tan: {
478            var cur = Evaluate(dataset, ref row, state);
479            return FunctionApply(cur, Math.Tan, DoubleVector.Tan);
480          }
481        case OpCodes.Square: {
482            var cur = Evaluate(dataset, ref row, state);
483            return FunctionApply(cur,
484              s => Math.Pow(s, 2),
485              v => v.PointwisePower(2));
486          }
487        case OpCodes.Cube: {
488            var cur = Evaluate(dataset, ref row, state);
489            return FunctionApply(cur,
490              s => Math.Pow(s, 3),
491              v => v.PointwisePower(3));
492          }
493        case OpCodes.Power: {
494            var x = Evaluate(dataset, ref row, state);
495            var y = Evaluate(dataset, ref row, state);
496            return ArithmeticApply(x, y,
497              (lhs, rhs) => lhs.Count < rhs.Count
498                ? CutLonger(lhs, rhs)
499                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
500              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
501              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
502              (v1, s2) => v1.PointwisePower(Math.Round(s2)),
503              (v1, v2) => v1.PointwisePower(DoubleVector.Round(v2)));
504          }
505        case OpCodes.SquareRoot: {
506            var cur = Evaluate(dataset, ref row, state);
507            return FunctionApply(cur,
508              s => Math.Sqrt(s),
509              v => DoubleVector.Sqrt(v));
510          }
511        case OpCodes.CubeRoot: {
512            var cur = Evaluate(dataset, ref row, state);
513            return FunctionApply(cur,
514              s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0),
515              v => v.Map(s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0)));
516          }
517        case OpCodes.Root: {
518            var x = Evaluate(dataset, ref row, state);
519            var y = Evaluate(dataset, ref row, state);
520            return ArithmeticApply(x, y,
521              (lhs, rhs) => lhs.Count < rhs.Count
522                ? CutLonger(lhs, rhs)
523                : ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 1.0),
524              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
525              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
526              (v1, s2) => v1.PointwisePower(1.0 / Math.Round(s2)),
527              (v1, v2) => v1.PointwisePower(1.0 / DoubleVector.Round(v2)));
528          }
529        case OpCodes.Exp: {
530            var cur = Evaluate(dataset, ref row, state);
531            return FunctionApply(cur,
532              s => Math.Exp(s),
533              v => DoubleVector.Exp(v));
534          }
535        case OpCodes.Log: {
536            var cur = Evaluate(dataset, ref row, state);
537            return FunctionApply(cur,
538              s => Math.Log(s),
539              v => DoubleVector.Log(v));
540          }
541        case OpCodes.Sum: {
542            var cur = Evaluate(dataset, ref row, state);
543            return AggregateApply(cur, (WindowedSymbolTreeNode)currentInstr.dynamicNode,
544              s => s,
545              v => v.Sum());
546          }
547        case OpCodes.Mean: {
548            var cur = Evaluate(dataset, ref row, state);
549            return AggregateApply(cur,
550              s => s,
551              v => Statistics.Mean(v));
552          }
553        case OpCodes.StandardDeviation: {
554            var cur = Evaluate(dataset, ref row, state);
555            return AggregateApply(cur,
556              s => 0,
557              v => Statistics.PopulationStandardDeviation(v));
558          }
559        case OpCodes.Length: {
560            var cur = Evaluate(dataset, ref row, state);
561            return AggregateApply(cur,
562              s => 1,
563              v => v.Count);
564          }
565        case OpCodes.Min: {
566            var cur = Evaluate(dataset, ref row, state);
567            return AggregateApply(cur,
568              s => s,
569              v => Statistics.Minimum(v));
570          }
571        case OpCodes.Max: {
572            var cur = Evaluate(dataset, ref row, state);
573            return AggregateApply(cur,
574              s => s,
575              v => Statistics.Maximum(v));
576          }
577        case OpCodes.Variance: {
578            var cur = Evaluate(dataset, ref row, state);
579            return AggregateApply(cur,
580              s => 0,
581              v => Statistics.PopulationVariance(v));
582          }
583        case OpCodes.Skewness: {
584            var cur = Evaluate(dataset, ref row, state);
585            return AggregateApply(cur,
586              s => double.NaN,
587              v => Statistics.PopulationSkewness(v));
588          }
589        case OpCodes.Kurtosis: {
590            var cur = Evaluate(dataset, ref row, state);
591            return AggregateApply(cur,
592              s => double.NaN,
593              v => Statistics.PopulationKurtosis(v));
594          }
595        case OpCodes.EuclideanDistance: {
596            var x1 = Evaluate(dataset, ref row, state);
597            var x2 = Evaluate(dataset, ref row, state);
598            return AggregateMultipleApply(x1, x2,
599              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
600              (s1, s2) => s1 - s2,
601              (s1, v2) => Math.Sqrt((s1 - v2).PointwisePower(2).Sum()),
602              (v1, s2) => Math.Sqrt((v1 - s2).PointwisePower(2).Sum()),
603              (v1, v2) => Math.Sqrt((v1 - v2).PointwisePower(2).Sum()));
604          }
605        case OpCodes.Covariance: {
606            var x1 = Evaluate(dataset, ref row, state);
607            var x2 = Evaluate(dataset, ref row, state);
608            return AggregateMultipleApply(x1, x2,
609              (lhs, rhs) => ApplyVectorLengthStrategy(DifferentVectorLengthStrategy, lhs, rhs, 0.0),
610              (s1, s2) => 0,
611              (s1, v2) => 0,
612              (v1, s2) => 0,
613              (v1, v2) => Statistics.PopulationCovariance(v1, v2));
614          }
615        case OpCodes.Variable: {
616            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
617            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
618            if (currentInstr.data is IList<double> doubleList)
619              return new EvaluationResult(doubleList[row] * variableTreeNode.Weight);
620            if (currentInstr.data is IList<DoubleVector> doubleVectorList)
621              return new EvaluationResult(doubleVectorList[row] * variableTreeNode.Weight);
622            throw new NotSupportedException($"Unsupported type of variable: {currentInstr.data.GetType().GetPrettyName()}");
623          }
624        case OpCodes.BinaryFactorVariable: {
625            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
626            var factorVarTreeNode = currentInstr.dynamicNode as BinaryFactorVariableTreeNode;
627            return new EvaluationResult(((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0);
628          }
629        case OpCodes.FactorVariable: {
630            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
631            var factorVarTreeNode = currentInstr.dynamicNode as FactorVariableTreeNode;
632            return new EvaluationResult(factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]));
633          }
634        case OpCodes.Constant: {
635            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
636            return new EvaluationResult(constTreeNode.Value);
637          }
638
639        default:
640          throw new NotSupportedException($"Unsupported OpCode: {currentInstr.opCode}");
641      }
642    }
643  }
644}
Note: See TracBrowser for help on using the repository browser.