Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
12/05/11 08:22:36 (12 years ago)
Author:
gkronber
Message:

#1081: Implemented a multivariate symbolic expression tree interpreter for time series prognosis.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeInterpreter.cs

    r6860 r7120  
    2222using System;
    2323using System.Collections.Generic;
     24using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     
    3233  [StorableClass]
    3334  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
    34   public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
     35  public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
     36    ISymbolicDataAnalysisExpressionTreeInterpreter, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
    3537    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    3638    #region private classes
     
    199201
    200202    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
     203      return from prog in GetSymbolicExpressionTreeValues(tree, dataset, new string[] { "#NOTHING#" }, rows, 1)
     204             select prog.First().First();
     205    }
     206
     207    // for each row for each target variable one prognosis (=enumerable of future values)
     208    public IEnumerable<IEnumerable<IEnumerable<double>>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, string[] targetVariables, IEnumerable<int> rows, int horizon) {
    201209      if (CheckExpressionsWithIntervalArithmetic.Value)
    202210        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
     
    223231      var state = new InterpreterState(code, necessaryArgStackSize);
    224232
     233      int nComponents = tree.Root.GetSubtree(0).SubtreeCount;
     234      // produce a n-step forecast for each target variable for all rows
     235      var cachedPrognosedValues = new Dictionary<string, double[]>();
     236      foreach (var targetVariable in targetVariables)
     237        cachedPrognosedValues[targetVariable] = new double[horizon];
    225238      foreach (var rowEnum in rows) {
    226239        int row = rowEnum;
    227         state.Reset();
    228         yield return Evaluate(dataset, ref row, state);
    229       }
    230     }
    231 
    232     private double Evaluate(Dataset dataset, ref int row, InterpreterState state) {
     240        List<double[]> vProgs = new List<double[]>();
     241        foreach (var horizonRow in Enumerable.Range(row, horizon)) {
     242          int localRow = horizonRow; // create a local variable for the ref parameter
     243          var vPrognosis = from i in Enumerable.Range(0, nComponents)
     244                           select Evaluate(dataset, ref localRow, row - 1, state, cachedPrognosedValues);
     245
     246          var vPrognosisArr = vPrognosis.ToArray();
     247          vProgs.Add(vPrognosisArr);
     248          // set cachedValues for prognosis of future values
     249          for (int i = 0; i < vPrognosisArr.Length; i++)
     250            cachedPrognosedValues[targetVariables[i]][horizonRow - row] = vPrognosisArr[i];
     251
     252          state.Reset();
     253        }
     254
     255        yield return from component in Enumerable.Range(0, nComponents)
     256                     select from prognosisStep in Enumerable.Range(0, vProgs.Count)
     257                            select vProgs[prognosisStep][component];
     258      }
     259    }
     260
     261    private double Evaluate(Dataset dataset, ref int row, int lastObservedRow, InterpreterState state, Dictionary<string, double[]> cachedPrognosedValues) {
    233262      Instruction currentInstr = state.NextInstruction();
    234263      switch (currentInstr.opCode) {
    235264        case OpCodes.Add: {
    236             double s = Evaluate(dataset, ref row, state);
    237             for (int i = 1; i < currentInstr.nArguments; i++) {
    238               s += Evaluate(dataset, ref row, state);
     265            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     266            for (int i = 1; i < currentInstr.nArguments; i++) {
     267              s += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    239268            }
    240269            return s;
    241270          }
    242271        case OpCodes.Sub: {
    243             double s = Evaluate(dataset, ref row, state);
    244             for (int i = 1; i < currentInstr.nArguments; i++) {
    245               s -= Evaluate(dataset, ref row, state);
     272            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     273            for (int i = 1; i < currentInstr.nArguments; i++) {
     274              s -= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    246275            }
    247276            if (currentInstr.nArguments == 1) s = -s;
     
    249278          }
    250279        case OpCodes.Mul: {
    251             double p = Evaluate(dataset, ref row, state);
    252             for (int i = 1; i < currentInstr.nArguments; i++) {
    253               p *= Evaluate(dataset, ref row, state);
     280            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     281            for (int i = 1; i < currentInstr.nArguments; i++) {
     282              p *= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    254283            }
    255284            return p;
    256285          }
    257286        case OpCodes.Div: {
    258             double p = Evaluate(dataset, ref row, state);
    259             for (int i = 1; i < currentInstr.nArguments; i++) {
    260               p /= Evaluate(dataset, ref row, state);
     287            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     288            for (int i = 1; i < currentInstr.nArguments; i++) {
     289              p /= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    261290            }
    262291            if (currentInstr.nArguments == 1) p = 1.0 / p;
     
    264293          }
    265294        case OpCodes.Average: {
    266             double sum = Evaluate(dataset, ref row, state);
    267             for (int i = 1; i < currentInstr.nArguments; i++) {
    268               sum += Evaluate(dataset, ref row, state);
     295            double sum = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     296            for (int i = 1; i < currentInstr.nArguments; i++) {
     297              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    269298            }
    270299            return sum / currentInstr.nArguments;
    271300          }
    272301        case OpCodes.Cos: {
    273             return Math.Cos(Evaluate(dataset, ref row, state));
     302            return Math.Cos(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    274303          }
    275304        case OpCodes.Sin: {
    276             return Math.Sin(Evaluate(dataset, ref row, state));
     305            return Math.Sin(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    277306          }
    278307        case OpCodes.Tan: {
    279             return Math.Tan(Evaluate(dataset, ref row, state));
     308            return Math.Tan(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    280309          }
    281310        case OpCodes.Power: {
    282             double x = Evaluate(dataset, ref row, state);
    283             double y = Math.Round(Evaluate(dataset, ref row, state));
     311            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     312            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    284313            return Math.Pow(x, y);
    285314          }
    286315        case OpCodes.Root: {
    287             double x = Evaluate(dataset, ref row, state);
    288             double y = Math.Round(Evaluate(dataset, ref row, state));
     316            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     317            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    289318            return Math.Pow(x, 1 / y);
    290319          }
    291320        case OpCodes.Exp: {
    292             return Math.Exp(Evaluate(dataset, ref row, state));
     321            return Math.Exp(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    293322          }
    294323        case OpCodes.Log: {
    295             return Math.Log(Evaluate(dataset, ref row, state));
     324            return Math.Log(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    296325          }
    297326        case OpCodes.IfThenElse: {
    298             double condition = Evaluate(dataset, ref row, state);
     327            double condition = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    299328            double result;
    300329            if (condition > 0.0) {
    301               result = Evaluate(dataset, ref row, state); SkipInstructions(state);
     330              result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); SkipInstructions(state);
    302331            } else {
    303               SkipInstructions(state); result = Evaluate(dataset, ref row, state);
     332              SkipInstructions(state); result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    304333            }
    305334            return result;
    306335          }
    307336        case OpCodes.AND: {
    308             double result = Evaluate(dataset, ref row, state);
    309             for (int i = 1; i < currentInstr.nArguments; i++) {
    310               if (result > 0.0) result = Evaluate(dataset, ref row, state);
     337            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     338            for (int i = 1; i < currentInstr.nArguments; i++) {
     339              if (result > 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    311340              else {
    312341                SkipInstructions(state);
     
    316345          }
    317346        case OpCodes.OR: {
    318             double result = Evaluate(dataset, ref row, state);
    319             for (int i = 1; i < currentInstr.nArguments; i++) {
    320               if (result <= 0.0) result = Evaluate(dataset, ref row, state);
     347            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     348            for (int i = 1; i < currentInstr.nArguments; i++) {
     349              if (result <= 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    321350              else {
    322351                SkipInstructions(state);
     
    326355          }
    327356        case OpCodes.NOT: {
    328             return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
     357            return Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues) > 0.0 ? -1.0 : 1.0;
    329358          }
    330359        case OpCodes.GT: {
    331             double x = Evaluate(dataset, ref row, state);
    332             double y = Evaluate(dataset, ref row, state);
     360            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     361            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    333362            if (x > y) return 1.0;
    334363            else return -1.0;
    335364          }
    336365        case OpCodes.LT: {
    337             double x = Evaluate(dataset, ref row, state);
    338             double y = Evaluate(dataset, ref row, state);
     366            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     367            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    339368            if (x < y) return 1.0;
    340369            else return -1.0;
     
    343372            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    344373            row += timeLagTreeNode.Lag;
    345             double result = Evaluate(dataset, ref row, state);
     374            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    346375            row -= timeLagTreeNode.Lag;
    347376            return result;
     
    353382            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
    354383              row += Math.Sign(timeLagTreeNode.Lag);
    355               sum += Evaluate(dataset, ref row, state);
     384              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    356385              state.ProgramCounter = savedPc;
    357386            }
    358387            row -= timeLagTreeNode.Lag;
    359             sum += Evaluate(dataset, ref row, state);
     388            sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    360389            return sum;
    361390          }
     
    367396        case OpCodes.Derivative: {
    368397            int savedPc = state.ProgramCounter;
    369             double f_0 = Evaluate(dataset, ref row, state); row--;
     398            double f_0 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    370399            state.ProgramCounter = savedPc;
    371             double f_1 = Evaluate(dataset, ref row, state); row -= 2;
     400            double f_1 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row -= 2;
    372401            state.ProgramCounter = savedPc;
    373             double f_3 = Evaluate(dataset, ref row, state); row--;
     402            double f_3 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    374403            state.ProgramCounter = savedPc;
    375             double f_4 = Evaluate(dataset, ref row, state);
     404            double f_4 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    376405            row += 4;
    377406
     
    382411            double[] argValues = new double[currentInstr.nArguments];
    383412            for (int i = 0; i < currentInstr.nArguments; i++) {
    384               argValues[i] = Evaluate(dataset, ref row, state);
     413              argValues[i] = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    385414            }
    386415            // push on argument values on stack
     
    392421            state.ProgramCounter = (ushort)currentInstr.iArg0;
    393422            // evaluate the function
    394             double v = Evaluate(dataset, ref row, state);
     423            double v = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    395424
    396425            // delete the stack frame
     
    408437              return double.NaN;
    409438            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
    410             return ((IList<double>)currentInstr.iArg0)[row] * variableTreeNode.Weight;
     439            if (row <= lastObservedRow || !cachedPrognosedValues.ContainsKey(variableTreeNode.VariableName)) return ((IList<double>)currentInstr.iArg0)[row] * variableTreeNode.Weight;
     440            else return cachedPrognosedValues[variableTreeNode.VariableName][row - lastObservedRow - 1] * variableTreeNode.Weight;
    411441          }
    412442        case OpCodes.LagVariable: {
     
    415445            if (actualRow < 0 || actualRow >= dataset.Rows)
    416446              return double.NaN;
    417             return ((IList<double>)currentInstr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
     447            if (actualRow <= lastObservedRow || !cachedPrognosedValues.ContainsKey(laggedVariableTreeNode.VariableName)) return ((IList<double>)currentInstr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
     448            else return cachedPrognosedValues[laggedVariableTreeNode.VariableName][actualRow - lastObservedRow - 1] * laggedVariableTreeNode.Weight;
    418449          }
    419450        case OpCodes.Constant: {
     
    428459              return double.NaN;
    429460            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
    430             double variableValue = ((IList<double>)currentInstr.iArg0)[row];
     461            double variableValue;
     462            if (row <= lastObservedRow || !cachedPrognosedValues.ContainsKey(variableConditionTreeNode.VariableName))
     463              variableValue = ((IList<double>)currentInstr.iArg0)[row];
     464            else
     465              variableValue = cachedPrognosedValues[variableConditionTreeNode.VariableName][row - lastObservedRow - 1];
     466
    431467            double x = variableValue - variableConditionTreeNode.Threshold;
    432468            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    433469
    434             double trueBranch = Evaluate(dataset, ref row, state);
    435             double falseBranch = Evaluate(dataset, ref row, state);
     470            double trueBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
     471            double falseBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    436472
    437473            return trueBranch * p + falseBranch * (1 - p);
Note: See TracChangeset for help on using the changeset viewer.