Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 17448

Last change on this file since 17448 was 17448, checked in by pfleck, 4 years ago

#3040 Replaced own Vector with MathNet.Numerics Vector.

File size: 26.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
23
24using System;
25using System.Collections.Generic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HEAL.Attic;
32using MathNet.Numerics.Statistics;
33
34namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
35  [StorableType("FB94F333-B32A-44FB-A561-CBDE76693D20")]
36  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
37  public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
38    ISymbolicDataAnalysisExpressionTreeInterpreter {
39    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
40    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
41    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
42
43    public override bool CanChangeName {
44      get { return false; }
45    }
46
47    public override bool CanChangeDescription {
48      get { return false; }
49    }
50
51    #region parameter properties
52    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
53      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
54    }
55
56    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
57      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
58    }
59    #endregion
60
61    #region properties
62    public bool CheckExpressionsWithIntervalArithmetic {
63      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
64      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
65    }
66
67    public int EvaluatedSolutions {
68      get { return EvaluatedSolutionsParameter.Value.Value; }
69      set { EvaluatedSolutionsParameter.Value.Value = value; }
70    }
71    #endregion
72
73    [StorableConstructor]
74    protected SymbolicDataAnalysisExpressionTreeInterpreter(StorableConstructorFlag _) : base(_) { }
75
76    protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original,
77      Cloner cloner)
78      : base(original, cloner) { }
79
80    public override IDeepCloneable Clone(Cloner cloner) {
81      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
82    }
83
84    public SymbolicDataAnalysisExpressionTreeInterpreter()
85      : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
86      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
87      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
88    }
89
90    protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
91      : base(name, description) {
92      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
93      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
94    }
95
96    [StorableHook(HookType.AfterDeserialization)]
97    private void AfterDeserialization() {
98      var evaluatedSolutions = new IntValue(0);
99      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
100      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
101        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
102        evaluatedSolutions = evaluatedSolutionsParameter.Value;
103        Parameters.Remove(EvaluatedSolutionsParameterName);
104      }
105      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", evaluatedSolutions));
106      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
107        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
108        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
109        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
110      }
111      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
112    }
113
114    #region IStatefulItem
115    public void InitializeState() {
116      EvaluatedSolutions = 0;
117    }
118
119    public void ClearState() { }
120    #endregion
121
122    private readonly object syncRoot = new object();
123    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset,
124      IEnumerable<int> rows) {
125      if (CheckExpressionsWithIntervalArithmetic) {
126        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
127      }
128
129      lock (syncRoot) {
130        EvaluatedSolutions++; // increment the evaluated solutions counter
131      }
132      var state = PrepareInterpreterState(tree, dataset);
133
134      foreach (var rowEnum in rows) {
135        int row = rowEnum;
136        yield return Evaluate(dataset, ref row, state);
137        state.Reset();
138      }
139    }
140
141    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
142      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
143      int necessaryArgStackSize = 0;
144      foreach (Instruction instr in code) {
145        if (instr.opCode == OpCodes.Variable) {
146          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
147          instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
148        } else if (instr.opCode == OpCodes.FactorVariable) {
149          var factorTreeNode = instr.dynamicNode as FactorVariableTreeNode;
150          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
151        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
152          var factorTreeNode = instr.dynamicNode as BinaryFactorVariableTreeNode;
153          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
154        } else if (instr.opCode == OpCodes.VectorVariable) {
155          var vectorVariableTreeNode = (VectorVariableTreeNode)instr.dynamicNode;
156          instr.data = dataset.GetReadOnlyDoubleVectorValues(vectorVariableTreeNode.VariableName);
157        } else if (instr.opCode == OpCodes.LagVariable) {
158          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
159          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
160        } else if (instr.opCode == OpCodes.VariableCondition) {
161          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
162          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
163        } else if (instr.opCode == OpCodes.Call) {
164          necessaryArgStackSize += instr.nArguments + 1;
165        }
166      }
167      return new InterpreterState(code, necessaryArgStackSize);
168    }
169
170    public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
171      Instruction currentInstr = state.NextInstruction();
172      switch (currentInstr.opCode) {
173        case OpCodes.Add: {
174            double s = Evaluate(dataset, ref row, state);
175            for (int i = 1; i < currentInstr.nArguments; i++) {
176              s += Evaluate(dataset, ref row, state);
177            }
178            return s;
179          }
180        case OpCodes.Sub: {
181            double s = Evaluate(dataset, ref row, state);
182            for (int i = 1; i < currentInstr.nArguments; i++) {
183              s -= Evaluate(dataset, ref row, state);
184            }
185            if (currentInstr.nArguments == 1) { s = -s; }
186            return s;
187          }
188        case OpCodes.Mul: {
189            double p = Evaluate(dataset, ref row, state);
190            for (int i = 1; i < currentInstr.nArguments; i++) {
191              p *= Evaluate(dataset, ref row, state);
192            }
193            return p;
194          }
195        case OpCodes.Div: {
196            double p = Evaluate(dataset, ref row, state);
197            for (int i = 1; i < currentInstr.nArguments; i++) {
198              p /= Evaluate(dataset, ref row, state);
199            }
200            if (currentInstr.nArguments == 1) { p = 1.0 / p; }
201            return p;
202          }
203        case OpCodes.Average: {
204            double sum = Evaluate(dataset, ref row, state);
205            for (int i = 1; i < currentInstr.nArguments; i++) {
206              sum += Evaluate(dataset, ref row, state);
207            }
208            return sum / currentInstr.nArguments;
209          }
210        case OpCodes.Absolute: {
211            return Math.Abs(Evaluate(dataset, ref row, state));
212          }
213        case OpCodes.Tanh: {
214            return Math.Tanh(Evaluate(dataset, ref row, state));
215          }
216        case OpCodes.Cos: {
217            return Math.Cos(Evaluate(dataset, ref row, state));
218          }
219        case OpCodes.Sin: {
220            return Math.Sin(Evaluate(dataset, ref row, state));
221          }
222        case OpCodes.Tan: {
223            return Math.Tan(Evaluate(dataset, ref row, state));
224          }
225        case OpCodes.Square: {
226            return Math.Pow(Evaluate(dataset, ref row, state), 2);
227          }
228        case OpCodes.Cube: {
229            return Math.Pow(Evaluate(dataset, ref row, state), 3);
230          }
231        case OpCodes.Power: {
232            double x = Evaluate(dataset, ref row, state);
233            double y = Math.Round(Evaluate(dataset, ref row, state));
234            return Math.Pow(x, y);
235          }
236        case OpCodes.SquareRoot: {
237            return Math.Sqrt(Evaluate(dataset, ref row, state));
238          }
239        case OpCodes.CubeRoot: {
240            var arg = Evaluate(dataset, ref row, state);
241            return arg < 0 ? -Math.Pow(-arg, 1.0 / 3.0) : Math.Pow(arg, 1.0 / 3.0);
242          }
243        case OpCodes.Root: {
244            double x = Evaluate(dataset, ref row, state);
245            double y = Math.Round(Evaluate(dataset, ref row, state));
246            return Math.Pow(x, 1 / y);
247          }
248        case OpCodes.Exp: {
249            return Math.Exp(Evaluate(dataset, ref row, state));
250          }
251        case OpCodes.Log: {
252            return Math.Log(Evaluate(dataset, ref row, state));
253          }
254        case OpCodes.Gamma: {
255            var x = Evaluate(dataset, ref row, state);
256            if (double.IsNaN(x)) { return double.NaN; } else { return alglib.gammafunction(x); }
257          }
258        case OpCodes.Psi: {
259            var x = Evaluate(dataset, ref row, state);
260            if (double.IsNaN(x)) return double.NaN;
261            else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) return double.NaN;
262            return alglib.psi(x);
263          }
264        case OpCodes.Dawson: {
265            var x = Evaluate(dataset, ref row, state);
266            if (double.IsNaN(x)) { return double.NaN; }
267            return alglib.dawsonintegral(x);
268          }
269        case OpCodes.ExponentialIntegralEi: {
270            var x = Evaluate(dataset, ref row, state);
271            if (double.IsNaN(x)) { return double.NaN; }
272            return alglib.exponentialintegralei(x);
273          }
274        case OpCodes.SineIntegral: {
275            double si, ci;
276            var x = Evaluate(dataset, ref row, state);
277            if (double.IsNaN(x)) return double.NaN;
278            else {
279              alglib.sinecosineintegrals(x, out si, out ci);
280              return si;
281            }
282          }
283        case OpCodes.CosineIntegral: {
284            double si, ci;
285            var x = Evaluate(dataset, ref row, state);
286            if (double.IsNaN(x)) return double.NaN;
287            else {
288              alglib.sinecosineintegrals(x, out si, out ci);
289              return ci;
290            }
291          }
292        case OpCodes.HyperbolicSineIntegral: {
293            double shi, chi;
294            var x = Evaluate(dataset, ref row, state);
295            if (double.IsNaN(x)) return double.NaN;
296            else {
297              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
298              return shi;
299            }
300          }
301        case OpCodes.HyperbolicCosineIntegral: {
302            double shi, chi;
303            var x = Evaluate(dataset, ref row, state);
304            if (double.IsNaN(x)) return double.NaN;
305            else {
306              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
307              return chi;
308            }
309          }
310        case OpCodes.FresnelCosineIntegral: {
311            double c = 0, s = 0;
312            var x = Evaluate(dataset, ref row, state);
313            if (double.IsNaN(x)) return double.NaN;
314            else {
315              alglib.fresnelintegral(x, ref c, ref s);
316              return c;
317            }
318          }
319        case OpCodes.FresnelSineIntegral: {
320            double c = 0, s = 0;
321            var x = Evaluate(dataset, ref row, state);
322            if (double.IsNaN(x)) return double.NaN;
323            else {
324              alglib.fresnelintegral(x, ref c, ref s);
325              return s;
326            }
327          }
328        case OpCodes.AiryA: {
329            double ai, aip, bi, bip;
330            var x = Evaluate(dataset, ref row, state);
331            if (double.IsNaN(x)) return double.NaN;
332            else {
333              alglib.airy(x, out ai, out aip, out bi, out bip);
334              return ai;
335            }
336          }
337        case OpCodes.AiryB: {
338            double ai, aip, bi, bip;
339            var x = Evaluate(dataset, ref row, state);
340            if (double.IsNaN(x)) return double.NaN;
341            else {
342              alglib.airy(x, out ai, out aip, out bi, out bip);
343              return bi;
344            }
345          }
346        case OpCodes.Norm: {
347            var x = Evaluate(dataset, ref row, state);
348            if (double.IsNaN(x)) return double.NaN;
349            else return alglib.normaldistribution(x);
350          }
351        case OpCodes.Erf: {
352            var x = Evaluate(dataset, ref row, state);
353            if (double.IsNaN(x)) return double.NaN;
354            else return alglib.errorfunction(x);
355          }
356        case OpCodes.Bessel: {
357            var x = Evaluate(dataset, ref row, state);
358            if (double.IsNaN(x)) return double.NaN;
359            else return alglib.besseli0(x);
360          }
361
362        case OpCodes.AnalyticQuotient: {
363            var x1 = Evaluate(dataset, ref row, state);
364            var x2 = Evaluate(dataset, ref row, state);
365            return x1 / Math.Pow(1 + x2 * x2, 0.5);
366          }
367        case OpCodes.IfThenElse: {
368            double condition = Evaluate(dataset, ref row, state);
369            double result;
370            if (condition > 0.0) {
371              result = Evaluate(dataset, ref row, state); state.SkipInstructions();
372            } else {
373              state.SkipInstructions(); result = Evaluate(dataset, ref row, state);
374            }
375            return result;
376          }
377        case OpCodes.AND: {
378            double result = Evaluate(dataset, ref row, state);
379            for (int i = 1; i < currentInstr.nArguments; i++) {
380              if (result > 0.0) result = Evaluate(dataset, ref row, state);
381              else {
382                state.SkipInstructions();
383              }
384            }
385            return result > 0.0 ? 1.0 : -1.0;
386          }
387        case OpCodes.OR: {
388            double result = Evaluate(dataset, ref row, state);
389            for (int i = 1; i < currentInstr.nArguments; i++) {
390              if (result <= 0.0) result = Evaluate(dataset, ref row, state);
391              else {
392                state.SkipInstructions();
393              }
394            }
395            return result > 0.0 ? 1.0 : -1.0;
396          }
397        case OpCodes.NOT: {
398            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
399          }
400        case OpCodes.XOR: {
401            //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
402            // this is equal to a consecutive execution of binary XOR operations.
403            int positiveSignals = 0;
404            for (int i = 0; i < currentInstr.nArguments; i++) {
405              if (Evaluate(dataset, ref row, state) > 0.0) { positiveSignals++; }
406            }
407            return positiveSignals % 2 != 0 ? 1.0 : -1.0;
408          }
409        case OpCodes.GT: {
410            double x = Evaluate(dataset, ref row, state);
411            double y = Evaluate(dataset, ref row, state);
412            if (x > y) { return 1.0; } else { return -1.0; }
413          }
414        case OpCodes.LT: {
415            double x = Evaluate(dataset, ref row, state);
416            double y = Evaluate(dataset, ref row, state);
417            if (x < y) { return 1.0; } else { return -1.0; }
418          }
419        case OpCodes.TimeLag: {
420            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
421            row += timeLagTreeNode.Lag;
422            double result = Evaluate(dataset, ref row, state);
423            row -= timeLagTreeNode.Lag;
424            return result;
425          }
426        case OpCodes.Integral: {
427            int savedPc = state.ProgramCounter;
428            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
429            double sum = 0.0;
430            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
431              row += Math.Sign(timeLagTreeNode.Lag);
432              sum += Evaluate(dataset, ref row, state);
433              state.ProgramCounter = savedPc;
434            }
435            row -= timeLagTreeNode.Lag;
436            sum += Evaluate(dataset, ref row, state);
437            return sum;
438          }
439
440        //mkommend: derivate calculation taken from:
441        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
442        //one sided smooth differentiatior, N = 4
443        // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
444        case OpCodes.Derivative: {
445            int savedPc = state.ProgramCounter;
446            double f_0 = Evaluate(dataset, ref row, state); row--;
447            state.ProgramCounter = savedPc;
448            double f_1 = Evaluate(dataset, ref row, state); row -= 2;
449            state.ProgramCounter = savedPc;
450            double f_3 = Evaluate(dataset, ref row, state); row--;
451            state.ProgramCounter = savedPc;
452            double f_4 = Evaluate(dataset, ref row, state);
453            row += 4;
454
455            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
456          }
457        case OpCodes.Call: {
458            // evaluate sub-trees
459            double[] argValues = new double[currentInstr.nArguments];
460            for (int i = 0; i < currentInstr.nArguments; i++) {
461              argValues[i] = Evaluate(dataset, ref row, state);
462            }
463            // push on argument values on stack
464            state.CreateStackFrame(argValues);
465
466            // save the pc
467            int savedPc = state.ProgramCounter;
468            // set pc to start of function 
469            state.ProgramCounter = (ushort)currentInstr.data;
470            // evaluate the function
471            double v = Evaluate(dataset, ref row, state);
472
473            // delete the stack frame
474            state.RemoveStackFrame();
475
476            // restore the pc => evaluation will continue at point after my subtrees 
477            state.ProgramCounter = savedPc;
478            return v;
479          }
480        case OpCodes.Arg: {
481            return state.GetStackFrameValue((ushort)currentInstr.data);
482          }
483        case OpCodes.Variable: {
484            if (row < 0 || row >= dataset.Rows) return double.NaN;
485            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
486            return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
487          }
488        case OpCodes.BinaryFactorVariable: {
489            if (row < 0 || row >= dataset.Rows) return double.NaN;
490            var factorVarTreeNode = currentInstr.dynamicNode as BinaryFactorVariableTreeNode;
491            return ((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0;
492          }
493        case OpCodes.FactorVariable: {
494            if (row < 0 || row >= dataset.Rows) return double.NaN;
495            var factorVarTreeNode = currentInstr.dynamicNode as FactorVariableTreeNode;
496            return factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]);
497          }
498        case OpCodes.LagVariable: {
499            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
500            int actualRow = row + laggedVariableTreeNode.Lag;
501            if (actualRow < 0 || actualRow >= dataset.Rows) { return double.NaN; }
502            return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
503          }
504        case OpCodes.Constant: {
505            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
506            return constTreeNode.Value;
507          }
508
509        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
510        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
511        case OpCodes.VariableCondition: {
512            if (row < 0 || row >= dataset.Rows) return double.NaN;
513            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
514            if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
515              double variableValue = ((IList<double>)currentInstr.data)[row];
516              double x = variableValue - variableConditionTreeNode.Threshold;
517              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
518
519              double trueBranch = Evaluate(dataset, ref row, state);
520              double falseBranch = Evaluate(dataset, ref row, state);
521
522              return trueBranch * p + falseBranch * (1 - p);
523            } else {
524              // strict threshold
525              double variableValue = ((IList<double>)currentInstr.data)[row];
526              if (variableValue <= variableConditionTreeNode.Threshold) {
527                var left = Evaluate(dataset, ref row, state);
528                state.SkipInstructions();
529                return left;
530              } else {
531                state.SkipInstructions();
532                return Evaluate(dataset, ref row, state);
533              }
534            }
535          }
536
537        case OpCodes.VectorSum: {
538            DoubleVector v = VectorEvaluate(dataset, ref row, state);
539            return v.Sum();
540          }
541        case OpCodes.VectorMean: {
542            DoubleVector v = VectorEvaluate(dataset, ref row, state);
543            return v.Mean();
544          }
545
546        default:
547          throw new NotSupportedException();
548      }
549    }
550
551    public virtual DoubleVector VectorEvaluate(IDataset dataset, ref int row, InterpreterState state) {
552      Instruction currentInstr = state.NextInstruction();
553      switch (currentInstr.opCode) {
554        case OpCodes.VectorAdd: {
555            DoubleVector s = VectorEvaluate(dataset, ref row, state);
556            for (int i = 1; i < currentInstr.nArguments; i++) {
557              s += VectorEvaluate(dataset, ref row, state);
558            }
559            return s;
560          }
561        case OpCodes.VectorSub: {
562            DoubleVector s = VectorEvaluate(dataset, ref row, state);
563            for (int i = 1; i < currentInstr.nArguments; i++) {
564              s -= VectorEvaluate(dataset, ref row, state);
565            }
566            return s;
567          }
568        case OpCodes.VectorMul: {
569            DoubleVector s = VectorEvaluate(dataset, ref row, state);
570            for (int i = 1; i < currentInstr.nArguments; i++) {
571              s = s.PointwiseMultiply(VectorEvaluate(dataset, ref row, state));
572            }
573            return s;
574          }
575        case OpCodes.VectorDiv: {
576            DoubleVector s = VectorEvaluate(dataset, ref row, state);
577            for (int i = 1; i < currentInstr.nArguments; i++) {
578              s /= VectorEvaluate(dataset, ref row, state);
579            }
580            return s;
581          }
582
583        case OpCodes.VectorVariable: {
584            if (row < 0 || row >= dataset.Rows) return DoubleVector.Build.Dense(new[] { double.NaN });
585            var vectorVarTreeNode = currentInstr.dynamicNode as VectorVariableTreeNode;
586            return ((IList<DoubleVector>)currentInstr.data)[row] * vectorVarTreeNode.Weight;
587          }
588        default:
589          throw new NotSupportedException();
590      }
591    }
592  }
593}
Note: See TracBrowser for help on using the repository browser.