1 | #region License Information
|
---|
2 | /* HeuristicLab
|
---|
3 | * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
4 | *
|
---|
5 | * This file is part of HeuristicLab.
|
---|
6 | *
|
---|
7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
8 | * it under the terms of the GNU General Public License as published by
|
---|
9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
10 | * (at your option) any later version.
|
---|
11 | *
|
---|
12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
15 | * GNU General Public License for more details.
|
---|
16 | *
|
---|
17 | * You should have received a copy of the GNU General Public License
|
---|
18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
19 | */
|
---|
20 | #endregion
|
---|
21 |
|
---|
22 | using System;
|
---|
23 | using System.Collections.Generic;
|
---|
24 | using System.Linq;
|
---|
25 | using HeuristicLab.Common;
|
---|
26 | using HeuristicLab.Core;
|
---|
27 | using HeuristicLab.Data;
|
---|
28 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
29 | using HeuristicLab.Parameters;
|
---|
30 | using HEAL.Attic;
|
---|
31 | using MathNet.Numerics.Statistics;
|
---|
32 |
|
---|
33 | using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
|
---|
34 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter for symbolic expression trees whose nodes may evaluate to a scalar, to a vector
  /// (<c>MathNet.Numerics.LinearAlgebra.Vector&lt;double&gt;</c>) or to NaN.
  /// Arithmetic between vectors is applied element-wise, scalar/vector combinations apply the scalar
  /// to every element, and aggregation symbols (Sum, Mean, Variance, ...) reduce vectors to scalars.
  /// </summary>
  [StorableType("DE68A1D9-5AFC-4DDD-AB62-29F3B8FC28E0")]
  [Item("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.")]
  public class SymbolicDataAnalysisExpressionTreeVectorInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    /// <summary>
    /// Strategy used by <see cref="GetSymbolicExpressionTreeValues"/> to reduce a vector-valued
    /// tree result to the single value returned for a row.
    /// </summary>
    [StorableType("2612504E-AD5F-4AE2-B60E-98A5AB59E164")]
    public enum Aggregation {
      /// <summary>Mean of the vector elements.</summary>
      Mean,
      /// <summary>Median of the vector elements.</summary>
      Median,
      /// <summary>Sum of the vector elements.</summary>
      Sum,
      /// <summary>First vector element (NaN for an empty vector).</summary>
      First,
      /// <summary>Always NaN.</summary>
      NaN,
      /// <summary>Throw an <see cref="InvalidOperationException"/>.</summary>
      Exception
    }

    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string FinalAggregationParameterName = "FinalAggregation";

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    public IFixedValueParameter<EnumValue<Aggregation>> FinalAggregationParameter {
      get { return (IFixedValueParameter<EnumValue<Aggregation>>)Parameters[FinalAggregationParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Number of trees evaluated through <see cref="GetSymbolicExpressionTreeValues"/>;
    /// reset to 0 by <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    /// <summary>How a vector-valued root result is reduced to a single value per row.</summary>
    public Aggregation FinalAggregation {
      get { return FinalAggregationParameter.Value.Value; }
      set { FinalAggregationParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(StorableConstructorFlag _) : base(_) { }

    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(SymbolicDataAnalysisExpressionTreeVectorInterpreter original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeVectorInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeVectorInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.") {
    }

    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
      Parameters.Add(CreateFinalAggregationParameter());
    }

    /// <summary>
    /// Creates the <c>FinalAggregation</c> parameter with its default (<see cref="Aggregation.Mean"/>).
    /// Shared by the constructor and <see cref="AfterDeserialization"/> so both stay in sync.
    /// </summary>
    private static FixedValueParameter<EnumValue<Aggregation>> CreateFinalAggregationParameter() {
      return new FixedValueParameter<EnumValue<Aggregation>>(FinalAggregationParameterName, "If root node of the expression tree results in a Vector it is aggregated according to this parameter", new EnumValue<Aggregation>(Aggregation.Mean));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Objects persisted before the FinalAggregation parameter existed do not contain it.
      if (!Parameters.ContainsKey(FinalAggregationParameterName)) {
        Parameters.Add(CreateFinalAggregationParameter());
      }
    }

    #region IStatefulItem
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    public void ClearState() { }
    #endregion

    private readonly object syncRoot = new object();

    /// <summary>
    /// Evaluates <paramref name="tree"/> for each row in <paramref name="rows"/> and yields one value per row.
    /// </summary>
    /// <remarks>
    /// Scalar results are yielded as-is, vector results are reduced according to <see cref="FinalAggregation"/>,
    /// and NaN results yield <see cref="double.NaN"/>. Because this is an iterator, the counter increment and
    /// the tree compilation happen only when enumeration of the returned sequence starts.
    /// </remarks>
    /// <exception cref="InvalidOperationException">
    /// A row evaluates to a vector while <see cref="FinalAggregation"/> is <see cref="Aggregation.Exception"/>.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);
      // The aggregation mode is loop-invariant; read the parameter once instead of once per row.
      var aggregation = FinalAggregation;

      foreach (var rowEnum in rows) {
        int row = rowEnum;
        var result = Evaluate(dataset, ref row, state);
        if (result.IsScalar) {
          yield return result.Scalar;
        } else if (result.IsVector) {
          yield return AggregateFinalResult(result.Vector, aggregation);
        } else {
          yield return double.NaN;
        }
        state.Reset();
      }
    }

    /// <summary>Reduces a vector-valued root result to a single value according to <paramref name="aggregation"/>.</summary>
    /// <exception cref="InvalidOperationException"><paramref name="aggregation"/> is <see cref="Aggregation.Exception"/>.</exception>
    private static double AggregateFinalResult(DoubleVector vector, Aggregation aggregation) {
      switch (aggregation) {
        case Aggregation.Mean:
          return vector.Mean();
        case Aggregation.Median:
          return Statistics.Median(vector);
        case Aggregation.Sum:
          return vector.Sum();
        case Aggregation.First:
          // Enumerable.First() throws on an empty sequence; report NaN instead.
          return vector.Count > 0 ? vector[0] : double.NaN;
        case Aggregation.Exception:
          throw new InvalidOperationException("Result of the tree is not a scalar.");
        default:
          return double.NaN;
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and attaches the dataset column each variable-like instruction reads
    /// to <c>Instruction.data</c>, so the evaluation loop does not look columns up per row.
    /// </summary>
    /// <exception cref="NotSupportedException">A <c>Variable</c> refers to a column that is neither double- nor vector-typed.</exception>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          if (dataset.VariableHasType<double>(variableTreeNode.VariableName))
            instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
          else if (dataset.VariableHasType<DoubleVector>(variableTreeNode.VariableName))
            instr.data = dataset.GetReadOnlyDoubleVectorValues(variableTreeNode.VariableName);
          else throw new NotSupportedException($"Type of variable {variableTreeNode.VariableName} is not supported.");
        } else if (instr.opCode == OpCodes.FactorVariable) {
          var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
          var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }


    /// <summary>
    /// Result of evaluating a node: a scalar, a vector, or NaN (neither).
    /// NaN is encoded as <c>Scalar == NaN</c> together with a one-element vector holding NaN.
    /// </summary>
    public struct EvaluationResult {
      public double Scalar { get; }
      /// <summary>True when <see cref="Scalar"/> is not NaN; a scalar NaN is therefore reported as a NaN result.</summary>
      public bool IsScalar => !double.IsNaN(Scalar);

      public DoubleVector Vector { get; }
      /// <summary>
      /// False when <see cref="Vector"/> is a one-element vector containing NaN. Such a vector is
      /// indistinguishable from the NaN sentinel, even when it was produced by a real computation.
      /// </summary>
      public bool IsVector => !(Vector.Count == 1 && double.IsNaN(Vector[0]));

      public bool IsNaN => !IsScalar && !IsVector;

      public EvaluationResult(double scalar) {
        Scalar = scalar;
        Vector = NaNVector;
      }
      public EvaluationResult(DoubleVector vector) {
        if (vector == null) throw new ArgumentNullException(nameof(vector));
        Vector = vector;
        Scalar = double.NaN;
      }

      public override string ToString() {
        if (IsScalar) return Scalar.ToString();
        if (IsVector) return Vector.ToVectorString();
        return "NaN";
      }

      // Static initializers run in textual order: NaNVector must stay declared before NaN,
      // because constructing NaN reads NaNVector.
      private static readonly DoubleVector NaNVector = DoubleVector.Build.Dense(1, double.NaN);
      public static readonly EvaluationResult NaN = new EvaluationResult(double.NaN);
    }

    /// <summary>
    /// Applies the binary operation matching the operand kinds (scalar/vector combinations).
    /// Returns <see cref="EvaluationResult.NaN"/> if either operand is NaN or no function is given for that combination.
    /// </summary>
    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
      Func<double, double, double> ssFunc = null,
      Func<double, DoubleVector, DoubleVector> svFunc = null,
      Func<DoubleVector, double, DoubleVector> vsFunc = null,
      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>Applies a unary function that maps a scalar to a scalar and a vector to a vector; NaN otherwise.</summary>
    private static EvaluationResult FunctionApply(EvaluationResult val,
      Func<double, double> sFunc = null,
      Func<DoubleVector, DoubleVector> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>Applies an aggregation that maps a scalar or a whole vector to a scalar; NaN otherwise.</summary>
    private static EvaluationResult AggregateApply(EvaluationResult val,
      Func<double, double> sFunc = null,
      Func<DoubleVector, double> vFunc = null) {
      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Like the non-windowed AggregateApply, but a vector is first restricted to the window described by
    /// <paramref name="node"/>. <c>Offset</c> is a fraction of the vector length giving the start index;
    /// <c>Length</c> is a fraction of the remaining elements giving the count. Both are truncated to int.
    /// </summary>
    private static EvaluationResult AggregateApply(EvaluationResult val, WindowedSymbolTreeNode node,
      Func<double, double> sFunc = null,
      Func<DoubleVector, double> vFunc = null) {

      var offset = node.Offset;
      var length = node.Length;

      DoubleVector SubVector(DoubleVector v) {
        int index = (int)(offset * v.Count);
        int count = (int)(length * (v.Count - index));
        return v.SubVector(index, count);
      }

      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(SubVector(val.Vector)));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Applies a binary aggregation producing a scalar, chosen by the operand kinds.
    /// Returns <see cref="EvaluationResult.NaN"/> if no function is given for that combination.
    /// </summary>
    private static EvaluationResult AggregateMultipleApply(EvaluationResult lhs, EvaluationResult rhs,
      Func<double, double, double> ssFunc = null,
      Func<double, DoubleVector, double> svFunc = null,
      Func<DoubleVector, double, double> vsFunc = null,
      Func<DoubleVector, DoubleVector, double> vvFunc = null) {
      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
      return EvaluationResult.NaN;
    }

    /// <summary>
    /// Evaluates a single subtree for each row and yields the raw <see cref="EvaluationResult"/>.
    /// No final aggregation is applied, and <see cref="EvaluatedSolutions"/> is not incremented.
    /// </summary>
    public virtual IEnumerable<EvaluationResult> EvaluateNode(ISymbolicExpressionTreeNode node, IDataset dataset, IEnumerable<int> rows) {
      // Wrap the subtree in ProgramRoot -> Start so it can be compiled like a complete tree.
      // A deep clone is added instead of the caller's node: attaching the original node to a temporary
      // parent would detach it from the caller's tree.
      // NOTE(review): this relies on AddSubtree assigning Parent; confirm against SymbolicExpressionTreeNode.
      var startNode = new StartSymbol().CreateTreeNode();
      startNode.AddSubtree((ISymbolicExpressionTreeNode)node.Clone());
      var programNode = new ProgramRootSymbol().CreateTreeNode();
      programNode.AddSubtree(startNode);
      var tree = new SymbolicExpressionTree(programNode);

      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        int row = rowEnum;
        var result = Evaluate(dataset, ref row, state);
        yield return result;
        state.Reset();
      }
    }

    /// <summary>
    /// Recursively evaluates the next instruction of <paramref name="state"/> (prefix order) for one row.
    /// </summary>
    /// <remarks>
    /// Variable-like instructions return NaN for rows outside <c>[0, dataset.Rows)</c>.
    /// Operand combinations without a defined operation evaluate to NaN.
    /// </remarks>
    /// <exception cref="NotSupportedException">The instruction's opcode is not handled by this interpreter.</exception>
    public virtual EvaluationResult Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 + s2,
                (s1, v2) => s1 + v2,
                (v1, s2) => v1 + s2,
                (v1, v2) => v1 + v2);
            }
            return cur;
          }
        case OpCodes.Sub: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 - s2,
                (s1, v2) => s1 - v2,
                (v1, s2) => v1 - s2,
                (v1, v2) => v1 - v2);
            }
            return cur;
          }
        case OpCodes.Mul: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 * s2,
                (s1, v2) => s1 * v2,
                (v1, s2) => v1 * s2,
                (v1, v2) => v1.PointwiseMultiply(v2));
            }
            return cur;
          }
        case OpCodes.Div: {
            var cur = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              var op = Evaluate(dataset, ref row, state);
              cur = ArithmeticApply(cur, op,
                (s1, s2) => s1 / s2,
                (s1, v2) => s1 / v2,
                (v1, s2) => v1 / s2,
                (v1, v2) => v1 / v2);
            }
            return cur;
          }
        case OpCodes.Absolute: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Abs, DoubleVector.Abs);
          }
        case OpCodes.Tanh: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Tanh, DoubleVector.Tanh);
          }
        case OpCodes.Cos: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Cos, DoubleVector.Cos);
          }
        case OpCodes.Sin: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Sin, DoubleVector.Sin);
          }
        case OpCodes.Tan: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur, Math.Tan, DoubleVector.Tan);
          }
        case OpCodes.Square: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Pow(s, 2),
              v => v.PointwisePower(2));
          }
        case OpCodes.Cube: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Pow(s, 3),
              v => v.PointwisePower(3));
          }
        case OpCodes.Power: {
            // The exponent is rounded with Math.Round before exponentiation.
            var x = Evaluate(dataset, ref row, state);
            var y = Evaluate(dataset, ref row, state);
            return ArithmeticApply(x, y,
              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
              (v1, s2) => v1.PointwisePower(Math.Round(s2)),
              (v1, v2) => v1.PointwisePower(DoubleVector.Round(v2)));
          }
        case OpCodes.SquareRoot: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Sqrt(s),
              v => DoubleVector.Sqrt(v));
          }
        case OpCodes.CubeRoot: {
            // Real cube root: odd symmetry keeps negative inputs defined (Math.Pow returns NaN for them).
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0),
              v => v.Map(s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0)));
          }
        case OpCodes.Root: {
            // x^(1 / round(y)): the root degree is rounded with Math.Round.
            var x = Evaluate(dataset, ref row, state);
            var y = Evaluate(dataset, ref row, state);
            return ArithmeticApply(x, y,
              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
              (v1, s2) => v1.PointwisePower(1.0 / Math.Round(s2)),
              (v1, v2) => v1.PointwisePower(1.0 / DoubleVector.Round(v2)));
          }
        case OpCodes.Exp: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Exp(s),
              v => DoubleVector.Exp(v));
          }
        case OpCodes.Log: {
            var cur = Evaluate(dataset, ref row, state);
            return FunctionApply(cur,
              s => Math.Log(s),
              v => DoubleVector.Log(v));
          }
        case OpCodes.Sum: {
            // Sum is the only aggregation here that restricts the vector to the node's window.
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur, (WindowedSymbolTreeNode)currentInstr.dynamicNode,
              s => s,
              v => v.Sum());
          }
        case OpCodes.Mean: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => s,
              v => v.Mean());
          }
        case OpCodes.StandardDeviation: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => 0,
              v => Statistics.PopulationStandardDeviation(v));
          }
        case OpCodes.Length: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => 1,
              v => v.Count);
          }
        case OpCodes.Min: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => s,
              v => Statistics.Minimum(v));
          }
        case OpCodes.Max: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => s,
              v => Statistics.Maximum(v));
          }
        case OpCodes.Variance: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => 0,
              v => Statistics.PopulationVariance(v));
          }
        case OpCodes.Skewness: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => double.NaN,
              v => Statistics.PopulationSkewness(v));
          }
        case OpCodes.Kurtosis: {
            var cur = Evaluate(dataset, ref row, state);
            return AggregateApply(cur,
              s => double.NaN,
              v => Statistics.PopulationKurtosis(v));
          }
        case OpCodes.EuclideanDistance: {
            // Defined only for two vectors of equal length; any other operand combination yields NaN.
            var x1 = Evaluate(dataset, ref row, state);
            var x2 = Evaluate(dataset, ref row, state);
            return AggregateMultipleApply(x1, x2,
              vvFunc: (v1, v2) => v1.Count == v2.Count ? Math.Sqrt((v1 - v2).PointwisePower(2).Sum()) : double.NaN);
          }
        case OpCodes.Covariance: {
            // Defined only for two vectors of equal length; any other operand combination yields NaN.
            var x1 = Evaluate(dataset, ref row, state);
            var x2 = Evaluate(dataset, ref row, state);
            return AggregateMultipleApply(x1, x2,
              vvFunc: (v1, v2) => v1.Count == v2.Count ? Statistics.PopulationCovariance(v1, v2) : double.NaN);
          }
        case OpCodes.Variable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            // instr.data was attached by PrepareInterpreterState as either a double or a vector column.
            if (currentInstr.data is IList<double> doubleList)
              return new EvaluationResult(doubleList[row] * variableTreeNode.Weight);
            if (currentInstr.data is IList<DoubleVector> doubleVectorList)
              return new EvaluationResult(doubleVectorList[row] * variableTreeNode.Weight);
            throw new NotSupportedException($"Unsupported type of variable: {currentInstr.data.GetType().GetPrettyName()}");
          }
        case OpCodes.BinaryFactorVariable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var factorVarTreeNode = (BinaryFactorVariableTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0);
          }
        case OpCodes.FactorVariable: {
            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
            var factorVarTreeNode = (FactorVariableTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]));
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return new EvaluationResult(constTreeNode.Value);
          }

        default:
          throw new NotSupportedException($"Unsupported OpCode: {currentInstr.opCode}");
      }
    }
  }
}