Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/28/20 10:31:57 (4 years ago)
Author:
pfleck
Message:

#3040 Added separate Interpreter for vector that reuse the existing symbols instead of creating explicit vector symbols.

File:
1 copied

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeVectorInterpreter.cs

    r17448 r17455  
    2020#endregion
    2121
    22 using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
    23 
    2422using System;
    2523using System.Collections.Generic;
     24using HeuristicLab.Analysis;
    2625using HeuristicLab.Common;
    2726using HeuristicLab.Core;
     
    3029using HeuristicLab.Parameters;
    3130using HEAL.Attic;
     31using MathNet.Numerics.LinearAlgebra;
    3232using MathNet.Numerics.Statistics;
    3333
     34using DoubleVector = MathNet.Numerics.LinearAlgebra.Vector<double>;
     35
    3436namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    35   [StorableType("FB94F333-B32A-44FB-A561-CBDE76693D20")]
    36   [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
    37   public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
    38     ISymbolicDataAnalysisExpressionTreeInterpreter {
    39     private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    40     private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
     37  [StorableType("DE68A1D9-5AFC-4DDD-AB62-29F3B8FC28E0")]
     38  [Item("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.")]
     39  public class SymbolicDataAnalysisExpressionTreeVectorInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
     40
    4141    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    4242
     
    5050
    5151    #region parameter properties
    52     public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
    53       get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    54     }
    55 
    5652    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
    5753      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
     
    6056
    6157    #region properties
    62     public bool CheckExpressionsWithIntervalArithmetic {
    63       get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
    64       set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    65     }
    66 
    6758    public int EvaluatedSolutions {
    6859      get { return EvaluatedSolutionsParameter.Value.Value; }
     
    7263
    7364    [StorableConstructor]
    74     protected SymbolicDataAnalysisExpressionTreeInterpreter(StorableConstructorFlag _) : base(_) { }
    75 
    76     protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original,
    77       Cloner cloner)
     65    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(StorableConstructorFlag _) : base(_) { }
     66
     67    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(SymbolicDataAnalysisExpressionTreeVectorInterpreter original, Cloner cloner)
    7868      : base(original, cloner) { }
    7969
    8070    public override IDeepCloneable Clone(Cloner cloner) {
    81       return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    82     }
    83 
    84     public SymbolicDataAnalysisExpressionTreeInterpreter()
    85       : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    86       Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
     71      return new SymbolicDataAnalysisExpressionTreeVectorInterpreter(this, cloner);
     72    }
     73
     74    public SymbolicDataAnalysisExpressionTreeVectorInterpreter()
     75      : base("SymbolicDataAnalysisExpressionTreeVectorInterpreter", "Interpreter for symbolic expression trees including vector arithmetic.") {
    8776      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    8877    }
    8978
    90     protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
     79    protected SymbolicDataAnalysisExpressionTreeVectorInterpreter(string name, string description)
    9180      : base(name, description) {
    92       Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    9381      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    9482    }
     
    9684    [StorableHook(HookType.AfterDeserialization)]
    9785    private void AfterDeserialization() {
    98       var evaluatedSolutions = new IntValue(0);
    99       var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
    100       if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
    101         var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    102         evaluatedSolutions = evaluatedSolutionsParameter.Value;
    103         Parameters.Remove(EvaluatedSolutionsParameterName);
    104       }
    105       Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", evaluatedSolutions));
    106       if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
    107         var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    108         Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
    109         checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
    110       }
    111       Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
     86
    11287    }
    11388
     
    12196
    12297    private readonly object syncRoot = new object();
    123     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset,
    124       IEnumerable<int> rows) {
    125       if (CheckExpressionsWithIntervalArithmetic) {
    126         throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
    127       }
    128 
     98    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
    12999      lock (syncRoot) {
    130100        EvaluatedSolutions++; // increment the evaluated solutions counter
     
    134104      foreach (var rowEnum in rows) {
    135105        int row = rowEnum;
    136         yield return Evaluate(dataset, ref row, state);
     106        var result = Evaluate(dataset, ref row, state);
     107        if (!result.IsScalar)
     108          throw new InvalidOperationException("Result of the tree is not a scalar.");
     109        yield return result.Scalar;
    137110        state.Reset();
    138111      }
     
    152125          var factorTreeNode = instr.dynamicNode as BinaryFactorVariableTreeNode;
    153126          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
    154         } else if (instr.opCode == OpCodes.VectorVariable) {
    155           var vectorVariableTreeNode = (VectorVariableTreeNode)instr.dynamicNode;
    156           instr.data = dataset.GetReadOnlyDoubleVectorValues(vectorVariableTreeNode.VariableName);
    157127        } else if (instr.opCode == OpCodes.LagVariable) {
    158128          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
     
    168138    }
    169139
    170     public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
     140
     141    public struct EvaluationResult {
     142      public double Scalar { get; }
     143      public bool IsScalar => !double.IsNaN(Scalar);
     144
     145      public DoubleVector Vector { get; }
     146      public bool IsVector => !(Vector.Count == 1 && double.IsNaN(Vector[0]));
     147
     148      public bool IsNaN => !IsScalar && !IsVector;
     149
     150      public EvaluationResult(double scalar) {
     151        Scalar = scalar;
     152        Vector = NaNVector;
     153      }
     154      public EvaluationResult(DoubleVector vector) {
     155        Vector = vector;
     156        Scalar = double.NaN;
     157      }
     158
     159      public override string ToString() {
     160        if (IsScalar) return Scalar.ToString();
     161        if (IsVector) return Vector.ToVectorString();
     162        return "NaN";
     163      }
     164
     165      public static readonly EvaluationResult NaN = new EvaluationResult(double.NaN);
     166      private static readonly DoubleVector NaNVector = DoubleVector.Build.Dense(1, double.NaN);
     167    }
     168
     169    private static EvaluationResult ArithmeticApply(EvaluationResult lhs, EvaluationResult rhs,
     170      Func<double, double, double> ssFunc = null,
     171      Func<double, DoubleVector, DoubleVector> svFunc = null,
     172      Func<DoubleVector, double, DoubleVector> vsFunc = null,
     173      Func<DoubleVector, DoubleVector, DoubleVector> vvFunc = null) {
     174      if (lhs.IsScalar && rhs.IsScalar && ssFunc != null) return new EvaluationResult(ssFunc(lhs.Scalar, rhs.Scalar));
     175      if (lhs.IsScalar && rhs.IsVector && svFunc != null) return new EvaluationResult(svFunc(lhs.Scalar, rhs.Vector));
     176      if (lhs.IsVector && rhs.IsScalar && vsFunc != null) return new EvaluationResult(vsFunc(lhs.Vector, rhs.Scalar));
     177      if (lhs.IsVector && rhs.IsVector && vvFunc != null) return new EvaluationResult(vvFunc(lhs.Vector, rhs.Vector));
     178      throw new NotSupportedException($"Unsupported combination of argument types: ({lhs}) / ({rhs})");
     179    }
     180
     181    private static EvaluationResult FunctionApply(EvaluationResult val,
     182      Func<double, double> sFunc = null,
     183      Func<DoubleVector, DoubleVector> vFunc = null) {
     184      if (val.IsScalar && sFunc != null) return new EvaluationResult(sFunc(val.Scalar));
     185      if (val.IsVector && vFunc != null) return new EvaluationResult(vFunc(val.Vector));
     186      throw new NotSupportedException($"Unsupported argument type ({val})");
     187    }
     188
     189    public virtual EvaluationResult Evaluate(IDataset dataset, ref int row, InterpreterState state) {
    171190      Instruction currentInstr = state.NextInstruction();
    172191      switch (currentInstr.opCode) {
    173192        case OpCodes.Add: {
    174             double s = Evaluate(dataset, ref row, state);
     193            var cur = Evaluate(dataset, ref row, state);
    175194            for (int i = 1; i < currentInstr.nArguments; i++) {
    176               s += Evaluate(dataset, ref row, state);
     195              var op = Evaluate(dataset, ref row, state);
     196              cur = ArithmeticApply(cur, op,
     197                (s1, s2) => s1 + s2,
     198                (s1, v2) => s1 + v2,
     199                (v1, s2) => v1 + s2,
     200                (v1, v2) => v1 + v2);
    177201            }
    178             return s;
     202            return cur;
    179203          }
    180204        case OpCodes.Sub: {
    181             double s = Evaluate(dataset, ref row, state);
     205            var cur = Evaluate(dataset, ref row, state);
    182206            for (int i = 1; i < currentInstr.nArguments; i++) {
    183               s -= Evaluate(dataset, ref row, state);
     207              var op = Evaluate(dataset, ref row, state);
     208              cur = ArithmeticApply(cur, op,
     209                (s1, s2) => s1 - s2,
     210                (s1, v2) => s1 - v2,
     211                (v1, s2) => v1 - s2,
     212                (v1, v2) => v1 - v2);
    184213            }
    185             if (currentInstr.nArguments == 1) { s = -s; }
    186             return s;
     214            return cur;
    187215          }
    188216        case OpCodes.Mul: {
    189             double p = Evaluate(dataset, ref row, state);
     217            var cur = Evaluate(dataset, ref row, state);
    190218            for (int i = 1; i < currentInstr.nArguments; i++) {
    191               p *= Evaluate(dataset, ref row, state);
     219              var op = Evaluate(dataset, ref row, state);
     220              cur = ArithmeticApply(cur, op,
     221                (s1, s2) => s1 * s2,
     222                (s1, v2) => s1 * v2,
     223                (v1, s2) => v1 * s2,
     224                (v1, v2) => v1.PointwiseMultiply(v2));
    192225            }
    193             return p;
     226            return cur;
    194227          }
    195228        case OpCodes.Div: {
    196             double p = Evaluate(dataset, ref row, state);
     229            var cur = Evaluate(dataset, ref row, state);
    197230            for (int i = 1; i < currentInstr.nArguments; i++) {
    198               p /= Evaluate(dataset, ref row, state);
     231              var op = Evaluate(dataset, ref row, state);
     232              cur = ArithmeticApply(cur, op,
     233                (s1, s2) => s1 / s2,
     234                (s1, v2) => s1 / v2,
     235                (v1, s2) => v1 / s2,
     236                (v1, v2) => v1 / v2);
    199237            }
    200             if (currentInstr.nArguments == 1) { p = 1.0 / p; }
    201             return p;
    202           }
    203         case OpCodes.Average: {
    204             double sum = Evaluate(dataset, ref row, state);
    205             for (int i = 1; i < currentInstr.nArguments; i++) {
    206               sum += Evaluate(dataset, ref row, state);
    207             }
    208             return sum / currentInstr.nArguments;
     238            return cur;
    209239          }
    210240        case OpCodes.Absolute: {
    211             return Math.Abs(Evaluate(dataset, ref row, state));
     241            var cur = Evaluate(dataset, ref row, state);
     242            return FunctionApply(cur, Math.Abs, DoubleVector.Abs);
    212243          }
    213244        case OpCodes.Tanh: {
    214             return Math.Tanh(Evaluate(dataset, ref row, state));
     245            var cur = Evaluate(dataset, ref row, state);
     246            return FunctionApply(cur, Math.Tanh, DoubleVector.Tanh);
    215247          }
    216248        case OpCodes.Cos: {
    217             return Math.Cos(Evaluate(dataset, ref row, state));
     249            var cur = Evaluate(dataset, ref row, state);
     250            return FunctionApply(cur, Math.Cos, DoubleVector.Cos);
    218251          }
    219252        case OpCodes.Sin: {
    220             return Math.Sin(Evaluate(dataset, ref row, state));
     253            var cur = Evaluate(dataset, ref row, state);
     254            return FunctionApply(cur, Math.Sin, DoubleVector.Sin);
    221255          }
    222256        case OpCodes.Tan: {
    223             return Math.Tan(Evaluate(dataset, ref row, state));
     257            var cur = Evaluate(dataset, ref row, state);
     258            return FunctionApply(cur, Math.Tan, DoubleVector.Tan);
    224259          }
    225260        case OpCodes.Square: {
    226             return Math.Pow(Evaluate(dataset, ref row, state), 2);
     261            var cur = Evaluate(dataset, ref row, state);
     262            return FunctionApply(cur,
     263              s => Math.Pow(s, 2),
     264              v => v.PointwisePower(2));
    227265          }
    228266        case OpCodes.Cube: {
    229             return Math.Pow(Evaluate(dataset, ref row, state), 3);
     267            var cur = Evaluate(dataset, ref row, state);
     268            return FunctionApply(cur,
     269              s => Math.Pow(s, 3),
     270              v => v.PointwisePower(3));
    230271          }
    231272        case OpCodes.Power: {
    232             double x = Evaluate(dataset, ref row, state);
    233             double y = Math.Round(Evaluate(dataset, ref row, state));
    234             return Math.Pow(x, y);
     273            var x = Evaluate(dataset, ref row, state);
     274            var y = Evaluate(dataset, ref row, state);
     275            return ArithmeticApply(x, y,
     276              (s1, s2) => Math.Pow(s1, Math.Round(s2)),
     277              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(DoubleVector.Round(v2)),
     278              (v1, s2) => v1.PointwisePower(Math.Round(s2)),
     279              (v1, v2) => v1.PointwisePower(DoubleVector.Round(v2)));
    235280          }
    236281        case OpCodes.SquareRoot: {
    237             return Math.Sqrt(Evaluate(dataset, ref row, state));
     282            var cur = Evaluate(dataset, ref row, state);
     283            return FunctionApply(cur,
     284              s => Math.Sqrt(s),
     285              v => DoubleVector.Sqrt(v));
    238286          }
    239287        case OpCodes.CubeRoot: {
    240             var arg = Evaluate(dataset, ref row, state);
    241             return arg < 0 ? -Math.Pow(-arg, 1.0 / 3.0) : Math.Pow(arg, 1.0 / 3.0);
     288            var cur = Evaluate(dataset, ref row, state);
     289            return FunctionApply(cur,
     290              s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0),
     291              v => v.Map(s => s < 0 ? -Math.Pow(-s, 1.0 / 3.0) : Math.Pow(s, 1.0 / 3.0)));
    242292          }
    243293        case OpCodes.Root: {
    244             double x = Evaluate(dataset, ref row, state);
    245             double y = Math.Round(Evaluate(dataset, ref row, state));
    246             return Math.Pow(x, 1 / y);
     294            var x = Evaluate(dataset, ref row, state);
     295            var y = Evaluate(dataset, ref row, state);
     296            return ArithmeticApply(x, y,
     297              (s1, s2) => Math.Pow(s1, 1.0 / Math.Round(s2)),
     298              (s1, v2) => DoubleVector.Build.Dense(v2.Count, s1).PointwisePower(1.0 / DoubleVector.Round(v2)),
     299              (v1, s2) => v1.PointwisePower(1.0 / Math.Round(s2)),
     300              (v1, v2) => v1.PointwisePower(1.0 / DoubleVector.Round(v2)));
    247301          }
    248302        case OpCodes.Exp: {
    249             return Math.Exp(Evaluate(dataset, ref row, state));
     303            var cur = Evaluate(dataset, ref row, state);
     304            return FunctionApply(cur,
     305              s => Math.Exp(s),
     306              v => DoubleVector.Exp(v));
    250307          }
    251308        case OpCodes.Log: {
    252             return Math.Log(Evaluate(dataset, ref row, state));
    253           }
    254         case OpCodes.Gamma: {
    255             var x = Evaluate(dataset, ref row, state);
    256             if (double.IsNaN(x)) { return double.NaN; } else { return alglib.gammafunction(x); }
    257           }
    258         case OpCodes.Psi: {
    259             var x = Evaluate(dataset, ref row, state);
    260             if (double.IsNaN(x)) return double.NaN;
    261             else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) return double.NaN;
    262             return alglib.psi(x);
    263           }
    264         case OpCodes.Dawson: {
    265             var x = Evaluate(dataset, ref row, state);
    266             if (double.IsNaN(x)) { return double.NaN; }
    267             return alglib.dawsonintegral(x);
    268           }
    269         case OpCodes.ExponentialIntegralEi: {
    270             var x = Evaluate(dataset, ref row, state);
    271             if (double.IsNaN(x)) { return double.NaN; }
    272             return alglib.exponentialintegralei(x);
    273           }
    274         case OpCodes.SineIntegral: {
    275             double si, ci;
    276             var x = Evaluate(dataset, ref row, state);
    277             if (double.IsNaN(x)) return double.NaN;
    278             else {
    279               alglib.sinecosineintegrals(x, out si, out ci);
    280               return si;
    281             }
    282           }
    283         case OpCodes.CosineIntegral: {
    284             double si, ci;
    285             var x = Evaluate(dataset, ref row, state);
    286             if (double.IsNaN(x)) return double.NaN;
    287             else {
    288               alglib.sinecosineintegrals(x, out si, out ci);
    289               return ci;
    290             }
    291           }
    292         case OpCodes.HyperbolicSineIntegral: {
    293             double shi, chi;
    294             var x = Evaluate(dataset, ref row, state);
    295             if (double.IsNaN(x)) return double.NaN;
    296             else {
    297               alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
    298               return shi;
    299             }
    300           }
    301         case OpCodes.HyperbolicCosineIntegral: {
    302             double shi, chi;
    303             var x = Evaluate(dataset, ref row, state);
    304             if (double.IsNaN(x)) return double.NaN;
    305             else {
    306               alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
    307               return chi;
    308             }
    309           }
    310         case OpCodes.FresnelCosineIntegral: {
    311             double c = 0, s = 0;
    312             var x = Evaluate(dataset, ref row, state);
    313             if (double.IsNaN(x)) return double.NaN;
    314             else {
    315               alglib.fresnelintegral(x, ref c, ref s);
    316               return c;
    317             }
    318           }
    319         case OpCodes.FresnelSineIntegral: {
    320             double c = 0, s = 0;
    321             var x = Evaluate(dataset, ref row, state);
    322             if (double.IsNaN(x)) return double.NaN;
    323             else {
    324               alglib.fresnelintegral(x, ref c, ref s);
    325               return s;
    326             }
    327           }
    328         case OpCodes.AiryA: {
    329             double ai, aip, bi, bip;
    330             var x = Evaluate(dataset, ref row, state);
    331             if (double.IsNaN(x)) return double.NaN;
    332             else {
    333               alglib.airy(x, out ai, out aip, out bi, out bip);
    334               return ai;
    335             }
    336           }
    337         case OpCodes.AiryB: {
    338             double ai, aip, bi, bip;
    339             var x = Evaluate(dataset, ref row, state);
    340             if (double.IsNaN(x)) return double.NaN;
    341             else {
    342               alglib.airy(x, out ai, out aip, out bi, out bip);
    343               return bi;
    344             }
    345           }
    346         case OpCodes.Norm: {
    347             var x = Evaluate(dataset, ref row, state);
    348             if (double.IsNaN(x)) return double.NaN;
    349             else return alglib.normaldistribution(x);
    350           }
    351         case OpCodes.Erf: {
    352             var x = Evaluate(dataset, ref row, state);
    353             if (double.IsNaN(x)) return double.NaN;
    354             else return alglib.errorfunction(x);
    355           }
    356         case OpCodes.Bessel: {
    357             var x = Evaluate(dataset, ref row, state);
    358             if (double.IsNaN(x)) return double.NaN;
    359             else return alglib.besseli0(x);
    360           }
    361 
    362         case OpCodes.AnalyticQuotient: {
    363             var x1 = Evaluate(dataset, ref row, state);
    364             var x2 = Evaluate(dataset, ref row, state);
    365             return x1 / Math.Pow(1 + x2 * x2, 0.5);
    366           }
    367         case OpCodes.IfThenElse: {
    368             double condition = Evaluate(dataset, ref row, state);
    369             double result;
    370             if (condition > 0.0) {
    371               result = Evaluate(dataset, ref row, state); state.SkipInstructions();
    372             } else {
    373               state.SkipInstructions(); result = Evaluate(dataset, ref row, state);
    374             }
    375             return result;
    376           }
    377         case OpCodes.AND: {
    378             double result = Evaluate(dataset, ref row, state);
    379             for (int i = 1; i < currentInstr.nArguments; i++) {
    380               if (result > 0.0) result = Evaluate(dataset, ref row, state);
    381               else {
    382                 state.SkipInstructions();
    383               }
    384             }
    385             return result > 0.0 ? 1.0 : -1.0;
    386           }
    387         case OpCodes.OR: {
    388             double result = Evaluate(dataset, ref row, state);
    389             for (int i = 1; i < currentInstr.nArguments; i++) {
    390               if (result <= 0.0) result = Evaluate(dataset, ref row, state);
    391               else {
    392                 state.SkipInstructions();
    393               }
    394             }
    395             return result > 0.0 ? 1.0 : -1.0;
    396           }
    397         case OpCodes.NOT: {
    398             return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
    399           }
    400         case OpCodes.XOR: {
    401             //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
    402             // this is equal to a consecutive execution of binary XOR operations.
    403             int positiveSignals = 0;
    404             for (int i = 0; i < currentInstr.nArguments; i++) {
    405               if (Evaluate(dataset, ref row, state) > 0.0) { positiveSignals++; }
    406             }
    407             return positiveSignals % 2 != 0 ? 1.0 : -1.0;
    408           }
    409         case OpCodes.GT: {
    410             double x = Evaluate(dataset, ref row, state);
    411             double y = Evaluate(dataset, ref row, state);
    412             if (x > y) { return 1.0; } else { return -1.0; }
    413           }
    414         case OpCodes.LT: {
    415             double x = Evaluate(dataset, ref row, state);
    416             double y = Evaluate(dataset, ref row, state);
    417             if (x < y) { return 1.0; } else { return -1.0; }
    418           }
    419         case OpCodes.TimeLag: {
    420             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    421             row += timeLagTreeNode.Lag;
    422             double result = Evaluate(dataset, ref row, state);
    423             row -= timeLagTreeNode.Lag;
    424             return result;
    425           }
    426         case OpCodes.Integral: {
    427             int savedPc = state.ProgramCounter;
    428             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    429             double sum = 0.0;
    430             for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
    431               row += Math.Sign(timeLagTreeNode.Lag);
    432               sum += Evaluate(dataset, ref row, state);
    433               state.ProgramCounter = savedPc;
    434             }
    435             row -= timeLagTreeNode.Lag;
    436             sum += Evaluate(dataset, ref row, state);
    437             return sum;
    438           }
    439 
    440         //mkommend: derivate calculation taken from:
    441         //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
    442         //one sided smooth differentiatior, N = 4
    443         // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
    444         case OpCodes.Derivative: {
    445             int savedPc = state.ProgramCounter;
    446             double f_0 = Evaluate(dataset, ref row, state); row--;
    447             state.ProgramCounter = savedPc;
    448             double f_1 = Evaluate(dataset, ref row, state); row -= 2;
    449             state.ProgramCounter = savedPc;
    450             double f_3 = Evaluate(dataset, ref row, state); row--;
    451             state.ProgramCounter = savedPc;
    452             double f_4 = Evaluate(dataset, ref row, state);
    453             row += 4;
    454 
    455             return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
    456           }
    457         case OpCodes.Call: {
    458             // evaluate sub-trees
    459             double[] argValues = new double[currentInstr.nArguments];
    460             for (int i = 0; i < currentInstr.nArguments; i++) {
    461               argValues[i] = Evaluate(dataset, ref row, state);
    462             }
    463             // push on argument values on stack
    464             state.CreateStackFrame(argValues);
    465 
    466             // save the pc
    467             int savedPc = state.ProgramCounter;
    468             // set pc to start of function 
    469             state.ProgramCounter = (ushort)currentInstr.data;
    470             // evaluate the function
    471             double v = Evaluate(dataset, ref row, state);
    472 
    473             // delete the stack frame
    474             state.RemoveStackFrame();
    475 
    476             // restore the pc => evaluation will continue at point after my subtrees 
    477             state.ProgramCounter = savedPc;
    478             return v;
    479           }
    480         case OpCodes.Arg: {
    481             return state.GetStackFrameValue((ushort)currentInstr.data);
     309            var cur = Evaluate(dataset, ref row, state);
     310            return FunctionApply(cur,
     311              s => Math.Log(s),
     312              v => DoubleVector.Log(v));
    482313          }
    483314        case OpCodes.Variable: {
    484             if (row < 0 || row >= dataset.Rows) return double.NaN;
     315            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
    485316            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
    486             return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
     317            if (currentInstr.data is IList<double> doubleList)
     318              return new EvaluationResult(doubleList[row] * variableTreeNode.Weight);
     319            if (currentInstr.data is IList<DoubleVector> doubleVectorList)
     320              return new EvaluationResult(doubleVectorList[row] * variableTreeNode.Weight);
     321            throw new NotSupportedException($"Unsupported type of variable: {currentInstr.data.GetType().GetPrettyName()}");
    487322          }
    488323        case OpCodes.BinaryFactorVariable: {
    489             if (row < 0 || row >= dataset.Rows) return double.NaN;
     324            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
    490325            var factorVarTreeNode = currentInstr.dynamicNode as BinaryFactorVariableTreeNode;
    491             return ((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0;
     326            return new EvaluationResult(((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0);
    492327          }
    493328        case OpCodes.FactorVariable: {
    494             if (row < 0 || row >= dataset.Rows) return double.NaN;
     329            if (row < 0 || row >= dataset.Rows) return EvaluationResult.NaN;
    495330            var factorVarTreeNode = currentInstr.dynamicNode as FactorVariableTreeNode;
    496             return factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]);
    497           }
    498         case OpCodes.LagVariable: {
    499             var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
    500             int actualRow = row + laggedVariableTreeNode.Lag;
    501             if (actualRow < 0 || actualRow >= dataset.Rows) { return double.NaN; }
    502             return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
     331            return new EvaluationResult(factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]));
    503332          }
    504333        case OpCodes.Constant: {
    505334            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
    506             return constTreeNode.Value;
    507           }
    508 
    509         //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
    510         //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
    511         case OpCodes.VariableCondition: {
    512             if (row < 0 || row >= dataset.Rows) return double.NaN;
    513             var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
    514             if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
    515               double variableValue = ((IList<double>)currentInstr.data)[row];
    516               double x = variableValue - variableConditionTreeNode.Threshold;
    517               double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    518 
    519               double trueBranch = Evaluate(dataset, ref row, state);
    520               double falseBranch = Evaluate(dataset, ref row, state);
    521 
    522               return trueBranch * p + falseBranch * (1 - p);
    523             } else {
    524               // strict threshold
    525               double variableValue = ((IList<double>)currentInstr.data)[row];
    526               if (variableValue <= variableConditionTreeNode.Threshold) {
    527                 var left = Evaluate(dataset, ref row, state);
    528                 state.SkipInstructions();
    529                 return left;
    530               } else {
    531                 state.SkipInstructions();
    532                 return Evaluate(dataset, ref row, state);
    533               }
    534             }
    535           }
    536 
    537         case OpCodes.VectorSum: {
    538             DoubleVector v = VectorEvaluate(dataset, ref row, state);
    539             return v.Sum();
    540           }
    541         case OpCodes.VectorMean: {
    542             DoubleVector v = VectorEvaluate(dataset, ref row, state);
    543             return v.Mean();
     335            return new EvaluationResult(constTreeNode.Value);
    544336          }
    545337
    546338        default:
    547           throw new NotSupportedException();
    548       }
    549     }
    550 
    551     public virtual DoubleVector VectorEvaluate(IDataset dataset, ref int row, InterpreterState state) {
    552       Instruction currentInstr = state.NextInstruction();
    553       switch (currentInstr.opCode) {
    554         case OpCodes.VectorAdd: {
    555             DoubleVector s = VectorEvaluate(dataset, ref row, state);
    556             for (int i = 1; i < currentInstr.nArguments; i++) {
    557               s += VectorEvaluate(dataset, ref row, state);
    558             }
    559             return s;
    560           }
    561         case OpCodes.VectorSub: {
    562             DoubleVector s = VectorEvaluate(dataset, ref row, state);
    563             for (int i = 1; i < currentInstr.nArguments; i++) {
    564               s -= VectorEvaluate(dataset, ref row, state);
    565             }
    566             return s;
    567           }
    568         case OpCodes.VectorMul: {
    569             DoubleVector s = VectorEvaluate(dataset, ref row, state);
    570             for (int i = 1; i < currentInstr.nArguments; i++) {
    571               s = s.PointwiseMultiply(VectorEvaluate(dataset, ref row, state));
    572             }
    573             return s;
    574           }
    575         case OpCodes.VectorDiv: {
    576             DoubleVector s = VectorEvaluate(dataset, ref row, state);
    577             for (int i = 1; i < currentInstr.nArguments; i++) {
    578               s /= VectorEvaluate(dataset, ref row, state);
    579             }
    580             return s;
    581           }
    582 
    583         case OpCodes.VectorVariable: {
    584             if (row < 0 || row >= dataset.Rows) return DoubleVector.Build.Dense(new[] { double.NaN });
    585             var vectorVarTreeNode = currentInstr.dynamicNode as VectorVariableTreeNode;
    586             return ((IList<DoubleVector>)currentInstr.data)[row] * vectorVarTreeNode.Weight;
    587           }
    588         default:
    589           throw new NotSupportedException();
     339          throw new NotSupportedException($"Unsupported OpCode: {currentInstr.opCode}");
    590340      }
    591341    }
Note: See TracChangeset for help on using the changeset viewer.