Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/12/12 10:31:56 (12 years ago)
Author:
mkommend
Message:

#1081: Improved performance of time series prognosis.

File:
1 copied

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SymbolicTimeSeriesPrognosisExpressionTreeInterpreter.cs

    r7949 r7989  
    2222using System;
    2323using System.Collections.Generic;
     24using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     
    2829using HeuristicLab.Parameters;
    2930using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    30 using System.Linq;
    3131
    3232namespace HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis {
    3333  [StorableClass]
    3434  [Item("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
    35   public sealed class SymbolicTimeSeriesPrognosisInterpreter : ParameterizedNamedItem, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
    36     private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    37     #region private classes
    38     private class InterpreterState {
    39       private double[] argumentStack;
    40       private int argumentStackPointer;
    41       private Instruction[] code;
    42       private int pc;
    43       public int ProgramCounter {
    44         get { return pc; }
    45         set { pc = value; }
    46       }
    47       internal InterpreterState(Instruction[] code, int argumentStackSize) {
    48         this.code = code;
    49         this.pc = 0;
    50         if (argumentStackSize > 0) {
    51           this.argumentStack = new double[argumentStackSize];
     35  public sealed class SymbolicTimeSeriesPrognosisExpressionTreeInterpreter : SymbolicDataAnalysisExpressionTreeInterpreter, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
     36    private const string TargetVariableParameterName = "TargetVariable";
     37
     38    public IFixedValueParameter<StringValue> TargetVariableParameter {
     39      get { return (IFixedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     40    }
     41
     42    public string TargetVariable {
     43      get { return TargetVariableParameter.Value.Value; }
     44      set { TargetVariableParameter.Value.Value = value; }
     45    }
     46
     47    [ThreadStatic]
     48    private static double[] targetVariableCache;
     49    [ThreadStatic]
     50    private static List<int> invalidateCacheIndexes;
     51
     52    [StorableConstructor]
     53    private SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
     54    private SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(SymbolicTimeSeriesPrognosisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
     55    public override IDeepCloneable Clone(Cloner cloner) {
     56      return new SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(this, cloner);
     57    }
     58
     59    public SymbolicTimeSeriesPrognosisExpressionTreeInterpreter()
     60      : base("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
     61      Parameters.Add(new FixedValueParameter<StringValue>(TargetVariableParameterName));
     62      TargetVariableParameter.Hidden = true;
     63    }
     64
     65    // for each row several (=#horizon) future predictions
     66    public IEnumerable<IEnumerable<double>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon) {
     67      return GetSymbolicExpressionTreeValues(tree, dataset, rows, rows.Select(row => horizon));
     68    }
     69
     70    public IEnumerable<IEnumerable<double>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, IEnumerable<int> horizons) {
     71      if (CheckExpressionsWithIntervalArithmetic.Value)
     72        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
     73      if (targetVariableCache == null || targetVariableCache.GetLength(0) < dataset.Rows)
     74        targetVariableCache = dataset.GetDoubleValues(TargetVariable).ToArray();
     75      if (invalidateCacheIndexes == null)
     76        invalidateCacheIndexes = new List<int>(10);
     77
     78      string targetVariable = TargetVariable;
     79      EvaluatedSolutions.Value++; // increment the evaluated solutions counter
     80      var state = PrepareInterpreterState(tree, dataset, targetVariableCache);
     81      var rowsEnumerator = rows.GetEnumerator();
     82      var horizonsEnumerator = horizons.GetEnumerator();
     83
     84      // produce a n-step forecast for all rows
     85      while (rowsEnumerator.MoveNext() & horizonsEnumerator.MoveNext()) {
     86        int row = rowsEnumerator.Current;
     87        int horizon = horizonsEnumerator.Current;
     88
     89        double[] vProgs = new double[horizon];
     90        for (int i = 0; i < horizon; i++) {
     91          int localRow = i + row; // create a local variable for the ref parameter
     92          vProgs[i] = Evaluate(dataset, ref localRow, state);
     93          targetVariableCache[localRow] = vProgs[i];
     94          invalidateCacheIndexes.Add(localRow);
     95          state.Reset();
    5296        }
    53         this.argumentStackPointer = 0;
     97
     98        yield return vProgs;
     99
     100        int j = 0;
     101        foreach (var targetValue in dataset.GetDoubleValues(TargetVariable, invalidateCacheIndexes)) {
     102          targetVariableCache[invalidateCacheIndexes[j]] = targetValue;
     103          j++;
     104        }
     105        invalidateCacheIndexes.Clear();
    54106      }
    55107
    56       internal void Reset() {
    57         this.pc = 0;
    58         this.argumentStackPointer = 0;
    59       }
    60 
    61       internal Instruction NextInstruction() {
    62         return code[pc++];
    63       }
    64       private void Push(double val) {
    65         argumentStack[argumentStackPointer++] = val;
    66       }
    67       private double Pop() {
    68         return argumentStack[--argumentStackPointer];
    69       }
    70 
    71       internal void CreateStackFrame(double[] argValues) {
    72         // push in reverse order to make indexing easier
    73         for (int i = argValues.Length - 1; i >= 0; i--) {
    74           argumentStack[argumentStackPointer++] = argValues[i];
    75         }
    76         Push(argValues.Length);
    77       }
    78 
    79       internal void RemoveStackFrame() {
    80         int size = (int)Pop();
    81         argumentStackPointer -= size;
    82       }
    83 
    84       internal double GetStackFrameValue(ushort index) {
    85         // layout of stack:
    86         // [0]   <- argumentStackPointer
    87         // [StackFrameSize = N + 1]
    88         // [Arg0] <- argumentStackPointer - 2 - 0
    89         // [Arg1] <- argumentStackPointer - 2 - 1
    90         // [...]
    91         // [ArgN] <- argumentStackPointer - 2 - N
    92         // <Begin of stack frame>
    93         return argumentStack[argumentStackPointer - index - 2];
    94       }
    95     }
    96     private class OpCodes {
    97       public const byte Add = 1;
    98       public const byte Sub = 2;
    99       public const byte Mul = 3;
    100       public const byte Div = 4;
    101 
    102       public const byte Sin = 5;
    103       public const byte Cos = 6;
    104       public const byte Tan = 7;
    105 
    106       public const byte Log = 8;
    107       public const byte Exp = 9;
    108 
    109       public const byte IfThenElse = 10;
    110 
    111       public const byte GT = 11;
    112       public const byte LT = 12;
    113 
    114       public const byte AND = 13;
    115       public const byte OR = 14;
    116       public const byte NOT = 15;
    117 
    118 
    119       public const byte Average = 16;
    120 
    121       public const byte Call = 17;
    122 
    123       public const byte Variable = 18;
    124       public const byte LagVariable = 19;
    125       public const byte Constant = 20;
    126       public const byte Arg = 21;
    127 
    128       public const byte Power = 22;
    129       public const byte Root = 23;
    130       public const byte TimeLag = 24;
    131       public const byte Integral = 25;
    132       public const byte Derivative = 26;
    133 
    134       public const byte VariableCondition = 27;
    135     }
    136     #endregion
    137 
    138     private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
    139       { typeof(Addition), OpCodes.Add },
    140       { typeof(Subtraction), OpCodes.Sub },
    141       { typeof(Multiplication), OpCodes.Mul },
    142       { typeof(Division), OpCodes.Div },
    143       { typeof(Sine), OpCodes.Sin },
    144       { typeof(Cosine), OpCodes.Cos },
    145       { typeof(Tangent), OpCodes.Tan },
    146       { typeof(Logarithm), OpCodes.Log },
    147       { typeof(Exponential), OpCodes.Exp },
    148       { typeof(IfThenElse), OpCodes.IfThenElse },
    149       { typeof(GreaterThan), OpCodes.GT },
    150       { typeof(LessThan), OpCodes.LT },
    151       { typeof(And), OpCodes.AND },
    152       { typeof(Or), OpCodes.OR },
    153       { typeof(Not), OpCodes.NOT},
    154       { typeof(Average), OpCodes.Average},
    155       { typeof(InvokeFunction), OpCodes.Call },
    156       { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
    157       { typeof(LaggedVariable), OpCodes.LagVariable },
    158       { typeof(Constant), OpCodes.Constant },
    159       { typeof(Argument), OpCodes.Arg },
    160       { typeof(Power),OpCodes.Power},
    161       { typeof(Root),OpCodes.Root},
    162       { typeof(TimeLag), OpCodes.TimeLag},
    163       { typeof(Integral), OpCodes.Integral},
    164       { typeof(Derivative), OpCodes.Derivative},
    165       { typeof(VariableCondition),OpCodes.VariableCondition}
    166     };
    167 
    168     public override bool CanChangeName {
    169       get { return false; }
    170     }
    171     public override bool CanChangeDescription {
    172       get { return false; }
     108      if (rowsEnumerator.MoveNext() || horizonsEnumerator.MoveNext())
     109        throw new ArgumentException("Number of elements in rows and horizon enumerations doesn't match.");
    173110    }
    174111
    175     #region parameter properties
    176     public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
    177       get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    178     }
    179     #endregion
    180 
    181     #region properties
    182     public BoolValue CheckExpressionsWithIntervalArithmetic {
    183       get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
    184       set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    185     }
    186 
    187     [Storable]
    188     private readonly string[] targetVariables;
    189     #endregion
    190 
    191 
    192     [StorableConstructor]
    193     private SymbolicTimeSeriesPrognosisInterpreter(bool deserializing) : base(deserializing) { }
    194     private SymbolicTimeSeriesPrognosisInterpreter(SymbolicTimeSeriesPrognosisInterpreter original, Cloner cloner)
    195       : base(original, cloner) {
    196       this.targetVariables = original.targetVariables;
    197     }
    198     public override IDeepCloneable Clone(Cloner cloner) {
    199       return new SymbolicTimeSeriesPrognosisInterpreter(this, cloner);
    200     }
    201 
    202     public SymbolicTimeSeriesPrognosisInterpreter(string[] targetVariables)
    203       : base("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    204       Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    205       this.targetVariables = targetVariables;
    206     }
    207 
    208     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
    209       throw new NotSupportedException();
    210     }
    211 
    212     // for each row for each target variable one prognosis (=enumerable of future values)
    213     public IEnumerable<IEnumerable<IEnumerable<double>>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon) {
    214       if (CheckExpressionsWithIntervalArithmetic.Value)
    215         throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
    216       var compiler = new SymbolicExpressionTreeCompiler();
    217       Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
     112    private InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, Dataset dataset, double[] targetVariableCache) {
     113      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
    218114      int necessaryArgStackSize = 0;
    219       for (int i = 0; i < code.Length; i++) {
    220         Instruction instr = code[i];
     115      foreach (Instruction instr in code) {
    221116        if (instr.opCode == OpCodes.Variable) {
    222           var variableTreeNode = instr.dynamicNode as VariableTreeNode;
    223           instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
    224           code[i] = instr;
     117          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
     118          if (variableTreeNode.VariableName == TargetVariable)
     119            instr.iArg0 = targetVariableCache;
     120          else
     121            instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
    225122        } else if (instr.opCode == OpCodes.LagVariable) {
    226           var laggedVariableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
     123          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
    227124          instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
    228           code[i] = instr;
    229125        } else if (instr.opCode == OpCodes.VariableCondition) {
    230           var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
     126          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
    231127          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
    232128        } else if (instr.opCode == OpCodes.Call) {
     
    234130        }
    235131      }
    236       var state = new InterpreterState(code, necessaryArgStackSize);
    237132
    238       int nComponents = tree.Root.GetSubtree(0).SubtreeCount;
    239       // produce a n-step forecast for each target variable for all rows
    240       var cachedPrognosedValues = new Dictionary<string, double[]>();
    241       foreach (var targetVariable in targetVariables)
    242         cachedPrognosedValues[targetVariable] = new double[horizon];
    243       foreach (var rowEnum in rows) {
    244         int row = rowEnum;
    245         List<double[]> vProgs = new List<double[]>();
    246         foreach (var horizonRow in Enumerable.Range(row, horizon)) {
    247           int localRow = horizonRow; // create a local variable for the ref parameter
    248           var vPrognosis = from i in Enumerable.Range(0, nComponents)
    249                            select Evaluate(dataset, ref localRow, row - 1, state, cachedPrognosedValues);
    250 
    251           var vPrognosisArr = vPrognosis.ToArray();
    252           vProgs.Add(vPrognosisArr);
    253           // set cachedValues for prognosis of future values
    254           for (int i = 0; i < vPrognosisArr.Length; i++)
    255             cachedPrognosedValues[targetVariables[i]][horizonRow - row] = vPrognosisArr[i];
    256 
    257           state.Reset();
    258         }
    259 
    260         yield return from component in Enumerable.Range(0, nComponents)
    261                      select from prognosisStep in Enumerable.Range(0, vProgs.Count)
    262                             select vProgs[prognosisStep][component];
    263       }
    264     }
    265 
    266     private double Evaluate(Dataset dataset, ref int row, int lastObservedRow, InterpreterState state, Dictionary<string, double[]> cachedPrognosedValues) {
    267       Instruction currentInstr = state.NextInstruction();
    268       switch (currentInstr.opCode) {
    269         case OpCodes.Add: {
    270             double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    271             for (int i = 1; i < currentInstr.nArguments; i++) {
    272               s += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    273             }
    274             return s;
    275           }
    276         case OpCodes.Sub: {
    277             double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    278             for (int i = 1; i < currentInstr.nArguments; i++) {
    279               s -= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    280             }
    281             if (currentInstr.nArguments == 1) s = -s;
    282             return s;
    283           }
    284         case OpCodes.Mul: {
    285             double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    286             for (int i = 1; i < currentInstr.nArguments; i++) {
    287               p *= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    288             }
    289             return p;
    290           }
    291         case OpCodes.Div: {
    292             double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    293             for (int i = 1; i < currentInstr.nArguments; i++) {
    294               p /= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    295             }
    296             if (currentInstr.nArguments == 1) p = 1.0 / p;
    297             return p;
    298           }
    299         case OpCodes.Average: {
    300             double sum = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    301             for (int i = 1; i < currentInstr.nArguments; i++) {
    302               sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    303             }
    304             return sum / currentInstr.nArguments;
    305           }
    306         case OpCodes.Cos: {
    307             return Math.Cos(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    308           }
    309         case OpCodes.Sin: {
    310             return Math.Sin(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    311           }
    312         case OpCodes.Tan: {
    313             return Math.Tan(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    314           }
    315         case OpCodes.Power: {
    316             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    317             double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    318             return Math.Pow(x, y);
    319           }
    320         case OpCodes.Root: {
    321             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    322             double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    323             return Math.Pow(x, 1 / y);
    324           }
    325         case OpCodes.Exp: {
    326             return Math.Exp(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    327           }
    328         case OpCodes.Log: {
    329             return Math.Log(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    330           }
    331         case OpCodes.IfThenElse: {
    332             double condition = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    333             double result;
    334             if (condition > 0.0) {
    335               result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); SkipInstructions(state);
    336             } else {
    337               SkipInstructions(state); result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    338             }
    339             return result;
    340           }
    341         case OpCodes.AND: {
    342             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    343             for (int i = 1; i < currentInstr.nArguments; i++) {
    344               if (result > 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    345               else {
    346                 SkipInstructions(state);
    347               }
    348             }
    349             return result > 0.0 ? 1.0 : -1.0;
    350           }
    351         case OpCodes.OR: {
    352             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    353             for (int i = 1; i < currentInstr.nArguments; i++) {
    354               if (result <= 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    355               else {
    356                 SkipInstructions(state);
    357               }
    358             }
    359             return result > 0.0 ? 1.0 : -1.0;
    360           }
    361         case OpCodes.NOT: {
    362             return Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues) > 0.0 ? -1.0 : 1.0;
    363           }
    364         case OpCodes.GT: {
    365             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    366             double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    367             if (x > y) return 1.0;
    368             else return -1.0;
    369           }
    370         case OpCodes.LT: {
    371             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    372             double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    373             if (x < y) return 1.0;
    374             else return -1.0;
    375           }
    376         case OpCodes.TimeLag: {
    377             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    378             row += timeLagTreeNode.Lag;
    379             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    380             row -= timeLagTreeNode.Lag;
    381             return result;
    382           }
    383         case OpCodes.Integral: {
    384             int savedPc = state.ProgramCounter;
    385             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    386             double sum = 0.0;
    387             for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
    388               row += Math.Sign(timeLagTreeNode.Lag);
    389               sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    390               state.ProgramCounter = savedPc;
    391             }
    392             row -= timeLagTreeNode.Lag;
    393             sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    394             return sum;
    395           }
    396 
    397         //mkommend: derivate calculation taken from:
    398         //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
    399         //one sided smooth differentiatior, N = 4
    400         // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
    401         case OpCodes.Derivative: {
    402             int savedPc = state.ProgramCounter;
    403             double f_0 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    404             state.ProgramCounter = savedPc;
    405             double f_1 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row -= 2;
    406             state.ProgramCounter = savedPc;
    407             double f_3 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    408             state.ProgramCounter = savedPc;
    409             double f_4 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    410             row += 4;
    411 
    412             return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
    413           }
    414         case OpCodes.Call: {
    415             // evaluate sub-trees
    416             double[] argValues = new double[currentInstr.nArguments];
    417             for (int i = 0; i < currentInstr.nArguments; i++) {
    418               argValues[i] = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    419             }
    420             // push on argument values on stack
    421             state.CreateStackFrame(argValues);
    422 
    423             // save the pc
    424             int savedPc = state.ProgramCounter;
    425             // set pc to start of function 
    426             state.ProgramCounter = (ushort)currentInstr.iArg0;
    427             // evaluate the function
    428             double v = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    429 
    430             // delete the stack frame
    431             state.RemoveStackFrame();
    432 
    433             // restore the pc => evaluation will continue at point after my subtrees 
    434             state.ProgramCounter = savedPc;
    435             return v;
    436           }
    437         case OpCodes.Arg: {
    438             return state.GetStackFrameValue((ushort)currentInstr.iArg0);
    439           }
    440         case OpCodes.Variable: {
    441             if (row < 0 || row >= dataset.Rows)
    442               return double.NaN;
    443             var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
    444             if (row <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[row] * variableTreeNode.Weight;
    445             else return cachedPrognosedValues[variableTreeNode.VariableName][row - lastObservedRow - 1] * variableTreeNode.Weight;
    446           }
    447         case OpCodes.LagVariable: {
    448             var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
    449             int actualRow = row + laggedVariableTreeNode.Lag;
    450             if (actualRow < 0 || actualRow >= dataset.Rows)
    451               return double.NaN;
    452             if (actualRow <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
    453             else return cachedPrognosedValues[laggedVariableTreeNode.VariableName][actualRow - lastObservedRow - 1] * laggedVariableTreeNode.Weight;
    454           }
    455         case OpCodes.Constant: {
    456             var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
    457             return constTreeNode.Value;
    458           }
    459 
    460         //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
    461         //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
    462         case OpCodes.VariableCondition: {
    463             if (row < 0 || row >= dataset.Rows)
    464               return double.NaN;
    465             var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
    466             double variableValue;
    467             if (row <= lastObservedRow)
    468               variableValue = ((IList<double>)currentInstr.iArg0)[row];
    469             else
    470               variableValue = cachedPrognosedValues[variableConditionTreeNode.VariableName][row - lastObservedRow - 1];
    471 
    472             double x = variableValue - variableConditionTreeNode.Threshold;
    473             double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    474 
    475             double trueBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    476             double falseBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    477 
    478             return trueBranch * p + falseBranch * (1 - p);
    479           }
    480         default: throw new NotSupportedException();
    481       }
    482     }
    483 
    484     private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
    485       if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
    486         return symbolToOpcode[treeNode.Symbol.GetType()];
    487       else
    488         throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    489     }
    490 
    491     // skips a whole branch
    492     private void SkipInstructions(InterpreterState state) {
    493       int i = 1;
    494       while (i > 0) {
    495         i += state.NextInstruction().nArguments;
    496         i--;
    497       }
     133      return new InterpreterState(code, necessaryArgStackSize);
    498134    }
    499135  }
Note: See TracChangeset for help on using the changeset viewer.