Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SymbolicTimeSeriesPrognosisInterpreter.cs @ 7842

Last change on this file since 7842 was 7120, checked in by gkronber, 13 years ago

#1081 implemented multi-variate symbolic expression tree interpreter for time series prognosis.

File size: 22.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using System.Linq;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis {
33  [StorableClass]
34  [Item("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
35  public sealed class SymbolicTimeSeriesPrognosisInterpreter : ParameterizedNamedItem, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
    // Key under which the interval-arithmetic check switch is stored in the Parameters collection.
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
37    #region private classes
38    private class InterpreterState {
39      private double[] argumentStack;
40      private int argumentStackPointer;
41      private Instruction[] code;
42      private int pc;
43      public int ProgramCounter {
44        get { return pc; }
45        set { pc = value; }
46      }
47      internal InterpreterState(Instruction[] code, int argumentStackSize) {
48        this.code = code;
49        this.pc = 0;
50        if (argumentStackSize > 0) {
51          this.argumentStack = new double[argumentStackSize];
52        }
53        this.argumentStackPointer = 0;
54      }
55
56      internal void Reset() {
57        this.pc = 0;
58        this.argumentStackPointer = 0;
59      }
60
61      internal Instruction NextInstruction() {
62        return code[pc++];
63      }
64      private void Push(double val) {
65        argumentStack[argumentStackPointer++] = val;
66      }
67      private double Pop() {
68        return argumentStack[--argumentStackPointer];
69      }
70
71      internal void CreateStackFrame(double[] argValues) {
72        // push in reverse order to make indexing easier
73        for (int i = argValues.Length - 1; i >= 0; i--) {
74          argumentStack[argumentStackPointer++] = argValues[i];
75        }
76        Push(argValues.Length);
77      }
78
79      internal void RemoveStackFrame() {
80        int size = (int)Pop();
81        argumentStackPointer -= size;
82      }
83
84      internal double GetStackFrameValue(ushort index) {
85        // layout of stack:
86        // [0]   <- argumentStackPointer
87        // [StackFrameSize = N + 1]
88        // [Arg0] <- argumentStackPointer - 2 - 0
89        // [Arg1] <- argumentStackPointer - 2 - 1
90        // [...]
91        // [ArgN] <- argumentStackPointer - 2 - N
92        // <Begin of stack frame>
93        return argumentStack[argumentStackPointer - index - 2];
94      }
95    }
96    private class OpCodes {
97      public const byte Add = 1;
98      public const byte Sub = 2;
99      public const byte Mul = 3;
100      public const byte Div = 4;
101
102      public const byte Sin = 5;
103      public const byte Cos = 6;
104      public const byte Tan = 7;
105
106      public const byte Log = 8;
107      public const byte Exp = 9;
108
109      public const byte IfThenElse = 10;
110
111      public const byte GT = 11;
112      public const byte LT = 12;
113
114      public const byte AND = 13;
115      public const byte OR = 14;
116      public const byte NOT = 15;
117
118
119      public const byte Average = 16;
120
121      public const byte Call = 17;
122
123      public const byte Variable = 18;
124      public const byte LagVariable = 19;
125      public const byte Constant = 20;
126      public const byte Arg = 21;
127
128      public const byte Power = 22;
129      public const byte Root = 23;
130      public const byte TimeLag = 24;
131      public const byte Integral = 25;
132      public const byte Derivative = 26;
133
134      public const byte VariableCondition = 27;
135    }
136    #endregion
137
138    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
139      { typeof(Addition), OpCodes.Add },
140      { typeof(Subtraction), OpCodes.Sub },
141      { typeof(Multiplication), OpCodes.Mul },
142      { typeof(Division), OpCodes.Div },
143      { typeof(Sine), OpCodes.Sin },
144      { typeof(Cosine), OpCodes.Cos },
145      { typeof(Tangent), OpCodes.Tan },
146      { typeof(Logarithm), OpCodes.Log },
147      { typeof(Exponential), OpCodes.Exp },
148      { typeof(IfThenElse), OpCodes.IfThenElse },
149      { typeof(GreaterThan), OpCodes.GT },
150      { typeof(LessThan), OpCodes.LT },
151      { typeof(And), OpCodes.AND },
152      { typeof(Or), OpCodes.OR },
153      { typeof(Not), OpCodes.NOT},
154      { typeof(Average), OpCodes.Average},
155      { typeof(InvokeFunction), OpCodes.Call },
156      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
157      { typeof(LaggedVariable), OpCodes.LagVariable },
158      { typeof(Constant), OpCodes.Constant },
159      { typeof(Argument), OpCodes.Arg },
160      { typeof(Power),OpCodes.Power},
161      { typeof(Root),OpCodes.Root},
162      { typeof(TimeLag), OpCodes.TimeLag},
163      { typeof(Integral), OpCodes.Integral},
164      { typeof(Derivative), OpCodes.Derivative},
165      { typeof(VariableCondition),OpCodes.VariableCondition}
166    };
167
168    public override bool CanChangeName {
169      get { return false; }
170    }
171    public override bool CanChangeDescription {
172      get { return false; }
173    }
174
175    #region parameter properties
176    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
177      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
178    }
179    #endregion
180
181    #region properties
182    public BoolValue CheckExpressionsWithIntervalArithmetic {
183      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
184      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
185    }
186
187    [Storable]
188    private readonly string[] targetVariables;
189    #endregion
190
191
192    [StorableConstructor]
193    private SymbolicTimeSeriesPrognosisInterpreter(bool deserializing) : base(deserializing) { }
194    private SymbolicTimeSeriesPrognosisInterpreter(SymbolicTimeSeriesPrognosisInterpreter original, Cloner cloner)
195      : base(original, cloner) {
196      this.targetVariables = original.targetVariables;
197    }
198    public override IDeepCloneable Clone(Cloner cloner) {
199      return new SymbolicTimeSeriesPrognosisInterpreter(this, cloner);
200    }
201
202    public SymbolicTimeSeriesPrognosisInterpreter(string[] targetVariables)
203      : base("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
204      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
205      this.targetVariables = targetVariables;
206    }
207
208    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
209      throw new NotSupportedException();
210    }
211
212    // for each row for each target variable one prognosis (=enumerable of future values)
213    public IEnumerable<IEnumerable<IEnumerable<double>>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon) {
214      if (CheckExpressionsWithIntervalArithmetic.Value)
215        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
216      var compiler = new SymbolicExpressionTreeCompiler();
217      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
218      int necessaryArgStackSize = 0;
219      for (int i = 0; i < code.Length; i++) {
220        Instruction instr = code[i];
221        if (instr.opCode == OpCodes.Variable) {
222          var variableTreeNode = instr.dynamicNode as VariableTreeNode;
223          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
224          code[i] = instr;
225        } else if (instr.opCode == OpCodes.LagVariable) {
226          var laggedVariableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
227          instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
228          code[i] = instr;
229        } else if (instr.opCode == OpCodes.VariableCondition) {
230          var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
231          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
232        } else if (instr.opCode == OpCodes.Call) {
233          necessaryArgStackSize += instr.nArguments + 1;
234        }
235      }
236      var state = new InterpreterState(code, necessaryArgStackSize);
237
238      int nComponents = tree.Root.GetSubtree(0).SubtreeCount;
239      // produce a n-step forecast for each target variable for all rows
240      var cachedPrognosedValues = new Dictionary<string, double[]>();
241      foreach (var targetVariable in targetVariables)
242        cachedPrognosedValues[targetVariable] = new double[horizon];
243      foreach (var rowEnum in rows) {
244        int row = rowEnum;
245        List<double[]> vProgs = new List<double[]>();
246        foreach (var horizonRow in Enumerable.Range(row, horizon)) {
247          int localRow = horizonRow; // create a local variable for the ref parameter
248          var vPrognosis = from i in Enumerable.Range(0, nComponents)
249                           select Evaluate(dataset, ref localRow, row - 1, state, cachedPrognosedValues);
250
251          var vPrognosisArr = vPrognosis.ToArray();
252          vProgs.Add(vPrognosisArr);
253          // set cachedValues for prognosis of future values
254          for (int i = 0; i < vPrognosisArr.Length; i++)
255            cachedPrognosedValues[targetVariables[i]][horizonRow - row] = vPrognosisArr[i];
256
257          state.Reset();
258        }
259
260        yield return from component in Enumerable.Range(0, nComponents)
261                     select from prognosisStep in Enumerable.Range(0, vProgs.Count)
262                            select vProgs[prognosisStep][component];
263      }
264    }
265
266    private double Evaluate(Dataset dataset, ref int row, int lastObservedRow, InterpreterState state, Dictionary<string, double[]> cachedPrognosedValues) {
267      Instruction currentInstr = state.NextInstruction();
268      switch (currentInstr.opCode) {
269        case OpCodes.Add: {
270            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
271            for (int i = 1; i < currentInstr.nArguments; i++) {
272              s += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
273            }
274            return s;
275          }
276        case OpCodes.Sub: {
277            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
278            for (int i = 1; i < currentInstr.nArguments; i++) {
279              s -= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
280            }
281            if (currentInstr.nArguments == 1) s = -s;
282            return s;
283          }
284        case OpCodes.Mul: {
285            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
286            for (int i = 1; i < currentInstr.nArguments; i++) {
287              p *= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
288            }
289            return p;
290          }
291        case OpCodes.Div: {
292            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
293            for (int i = 1; i < currentInstr.nArguments; i++) {
294              p /= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
295            }
296            if (currentInstr.nArguments == 1) p = 1.0 / p;
297            return p;
298          }
299        case OpCodes.Average: {
300            double sum = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
301            for (int i = 1; i < currentInstr.nArguments; i++) {
302              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
303            }
304            return sum / currentInstr.nArguments;
305          }
306        case OpCodes.Cos: {
307            return Math.Cos(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
308          }
309        case OpCodes.Sin: {
310            return Math.Sin(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
311          }
312        case OpCodes.Tan: {
313            return Math.Tan(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
314          }
315        case OpCodes.Power: {
316            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
317            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
318            return Math.Pow(x, y);
319          }
320        case OpCodes.Root: {
321            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
322            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
323            return Math.Pow(x, 1 / y);
324          }
325        case OpCodes.Exp: {
326            return Math.Exp(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
327          }
328        case OpCodes.Log: {
329            return Math.Log(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
330          }
331        case OpCodes.IfThenElse: {
332            double condition = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
333            double result;
334            if (condition > 0.0) {
335              result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); SkipInstructions(state);
336            } else {
337              SkipInstructions(state); result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
338            }
339            return result;
340          }
341        case OpCodes.AND: {
342            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
343            for (int i = 1; i < currentInstr.nArguments; i++) {
344              if (result > 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
345              else {
346                SkipInstructions(state);
347              }
348            }
349            return result > 0.0 ? 1.0 : -1.0;
350          }
351        case OpCodes.OR: {
352            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
353            for (int i = 1; i < currentInstr.nArguments; i++) {
354              if (result <= 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
355              else {
356                SkipInstructions(state);
357              }
358            }
359            return result > 0.0 ? 1.0 : -1.0;
360          }
361        case OpCodes.NOT: {
362            return Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues) > 0.0 ? -1.0 : 1.0;
363          }
364        case OpCodes.GT: {
365            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
366            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
367            if (x > y) return 1.0;
368            else return -1.0;
369          }
370        case OpCodes.LT: {
371            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
372            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
373            if (x < y) return 1.0;
374            else return -1.0;
375          }
376        case OpCodes.TimeLag: {
377            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
378            row += timeLagTreeNode.Lag;
379            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
380            row -= timeLagTreeNode.Lag;
381            return result;
382          }
383        case OpCodes.Integral: {
384            int savedPc = state.ProgramCounter;
385            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
386            double sum = 0.0;
387            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
388              row += Math.Sign(timeLagTreeNode.Lag);
389              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
390              state.ProgramCounter = savedPc;
391            }
392            row -= timeLagTreeNode.Lag;
393            sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
394            return sum;
395          }
396
397        //mkommend: derivate calculation taken from:
398        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
399        //one sided smooth differentiatior, N = 4
400        // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
401        case OpCodes.Derivative: {
402            int savedPc = state.ProgramCounter;
403            double f_0 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
404            state.ProgramCounter = savedPc;
405            double f_1 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row -= 2;
406            state.ProgramCounter = savedPc;
407            double f_3 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
408            state.ProgramCounter = savedPc;
409            double f_4 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
410            row += 4;
411
412            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
413          }
414        case OpCodes.Call: {
415            // evaluate sub-trees
416            double[] argValues = new double[currentInstr.nArguments];
417            for (int i = 0; i < currentInstr.nArguments; i++) {
418              argValues[i] = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
419            }
420            // push on argument values on stack
421            state.CreateStackFrame(argValues);
422
423            // save the pc
424            int savedPc = state.ProgramCounter;
425            // set pc to start of function 
426            state.ProgramCounter = (ushort)currentInstr.iArg0;
427            // evaluate the function
428            double v = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
429
430            // delete the stack frame
431            state.RemoveStackFrame();
432
433            // restore the pc => evaluation will continue at point after my subtrees 
434            state.ProgramCounter = savedPc;
435            return v;
436          }
437        case OpCodes.Arg: {
438            return state.GetStackFrameValue((ushort)currentInstr.iArg0);
439          }
440        case OpCodes.Variable: {
441            if (row < 0 || row >= dataset.Rows)
442              return double.NaN;
443            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
444            if (row <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[row] * variableTreeNode.Weight;
445            else return cachedPrognosedValues[variableTreeNode.VariableName][row - lastObservedRow - 1] * variableTreeNode.Weight;
446          }
447        case OpCodes.LagVariable: {
448            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
449            int actualRow = row + laggedVariableTreeNode.Lag;
450            if (actualRow < 0 || actualRow >= dataset.Rows)
451              return double.NaN;
452            if (actualRow <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
453            else return cachedPrognosedValues[laggedVariableTreeNode.VariableName][actualRow - lastObservedRow - 1] * laggedVariableTreeNode.Weight;
454          }
455        case OpCodes.Constant: {
456            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
457            return constTreeNode.Value;
458          }
459
460        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
461        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
462        case OpCodes.VariableCondition: {
463            if (row < 0 || row >= dataset.Rows)
464              return double.NaN;
465            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
466            double variableValue;
467            if (row <= lastObservedRow)
468              variableValue = ((IList<double>)currentInstr.iArg0)[row];
469            else
470              variableValue = cachedPrognosedValues[variableConditionTreeNode.VariableName][row - lastObservedRow - 1];
471
472            double x = variableValue - variableConditionTreeNode.Threshold;
473            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
474
475            double trueBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
476            double falseBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
477
478            return trueBranch * p + falseBranch * (1 - p);
479          }
480        default: throw new NotSupportedException();
481      }
482    }
483
484    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
485      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
486        return symbolToOpcode[treeNode.Symbol.GetType()];
487      else
488        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
489    }
490
491    // skips a whole branch
492    private void SkipInstructions(InterpreterState state) {
493      int i = 1;
494      while (i > 0) {
495        i += state.NextInstruction().nArguments;
496        i--;
497      }
498    }
499  }
500}
Note: See TracBrowser for help on using the repository browser.