Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 7120

Last change on this file since 7120 was 7120, checked in by gkronber, 13 years ago

#1081 implemented multi-variate symbolic expression tree interpreter for time series prognosis.

File size: 22.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Stack-based interpreter for symbolic expression trees, including automatically defined functions.
  /// Supports single-output evaluation as well as multi-step prognosis for one or more target variables,
  /// where prognosed target values are fed back into later prognosis steps.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
    ISymbolicDataAnalysisExpressionTreeInterpreter, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    #region private classes
    /// <summary>
    /// Mutable execution state of one evaluation: the compiled program, the program counter
    /// and the argument stack used to pass arguments into automatically defined functions.
    /// </summary>
    private sealed class InterpreterState {
      private readonly double[] argumentStack;
      private int argumentStackPointer;
      private readonly Instruction[] code;
      private int pc;
      public int ProgramCounter {
        get { return pc; }
        set { pc = value; }
      }
      internal InterpreterState(Instruction[] code, int argumentStackSize) {
        this.code = code;
        this.pc = 0;
        // the argument stack is only allocated when the program contains function calls
        if (argumentStackSize > 0) {
          this.argumentStack = new double[argumentStackSize];
        }
        this.argumentStackPointer = 0;
      }

      /// <summary>Rewinds the program counter and clears the argument stack for the next evaluation.</summary>
      internal void Reset() {
        this.pc = 0;
        this.argumentStackPointer = 0;
      }

      /// <summary>Returns the instruction at the program counter and advances the counter.</summary>
      internal Instruction NextInstruction() {
        return code[pc++];
      }
      private void Push(double val) {
        argumentStack[argumentStackPointer++] = val;
      }
      private double Pop() {
        return argumentStack[--argumentStackPointer];
      }

      /// <summary>Pushes the argument values (in reverse order) followed by the frame size.</summary>
      internal void CreateStackFrame(double[] argValues) {
        // push in reverse order to make indexing easier
        for (int i = argValues.Length - 1; i >= 0; i--) {
          Push(argValues[i]);
        }
        Push(argValues.Length);
      }

      /// <summary>Pops the frame size and then discards that many argument values.</summary>
      internal void RemoveStackFrame() {
        int size = (int)Pop();
        argumentStackPointer -= size;
      }

      /// <summary>Reads argument <paramref name="index"/> of the current (top-most) stack frame.</summary>
      internal double GetStackFrameValue(ushort index) {
        // layout of stack:
        // [0]   <- argumentStackPointer
        // [StackFrameSize = N + 1]
        // [Arg0] <- argumentStackPointer - 2 - 0
        // [Arg1] <- argumentStackPointer - 2 - 1
        // [...]
        // [ArgN] <- argumentStackPointer - 2 - N
        // <Begin of stack frame>
        return argumentStack[argumentStackPointer - index - 2];
      }
    }
    private static class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;


      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;

      public const byte Power = 22;
      public const byte Root = 23;
      public const byte TimeLag = 24;
      public const byte Integral = 25;
      public const byte Derivative = 26;

      public const byte VariableCondition = 27;
    }
    #endregion

    // maps each supported symbol type to its opcode; immutable, therefore shared by all instances
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT},
      { typeof(Average), OpCodes.Average},
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
      { typeof(Power),OpCodes.Power},
      { typeof(Root),OpCodes.Root},
      { typeof(TimeLag), OpCodes.TimeLag},
      { typeof(Integral), OpCodes.Integral},
      { typeof(Derivative), OpCodes.Derivative},
      { typeof(VariableCondition),OpCodes.VariableCondition}
    };

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Whether expressions should be checked with interval arithmetic before evaluation.
    /// Not supported by this interpreter: evaluation throws <see cref="NotSupportedException"/> when set.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }
    #endregion


    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
    private SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    }

    /// <summary>
    /// Evaluates the first component of <paramref name="tree"/> for each row in <paramref name="rows"/>.
    /// </summary>
    /// <remarks>
    /// Implemented as a one-step prognosis with a dummy target name, so no prognosed values are fed
    /// back and every variable is read from <paramref name="dataset"/>. The sequence is evaluated lazily.
    /// </remarks>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      return from prog in GetSymbolicExpressionTreeValues(tree, dataset, new string[] { "#NOTHING#" }, rows, 1)
             select prog.First().First();
    }

    /// <summary>
    /// Produces, for each row, one prognosis per tree component (main-body branch).
    /// Each prognosis has <paramref name="horizon"/> values for rows <c>row .. row + horizon - 1</c>.
    /// </summary>
    /// <param name="tree">The tree to evaluate; each subtree of the start node is one component.</param>
    /// <param name="dataset">Source of observed variable values.</param>
    /// <param name="targetVariables">
    /// Names of the prognosed variables. Component <c>i</c> is fed back as <c>targetVariables[i]</c>
    /// for later prognosis steps.
    /// </param>
    /// <param name="rows">Rows at which a prognosis starts; <c>row - 1</c> is the last observed row.</param>
    /// <param name="horizon">Number of prognosis steps per row.</param>
    /// <returns>A lazily evaluated sequence: for each row, for each component, the prognosed values.</returns>
    /// <exception cref="NotSupportedException">If interval arithmetic checking is enabled.</exception>
    // for each row for each target variable one prognosis (=enumerable of future values)
    public IEnumerable<IEnumerable<IEnumerable<double>>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, string[] targetVariables, IEnumerable<int> rows, int horizon) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      for (int i = 0; i < code.Length; i++) {
        Instruction instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
        // `instr` is a copy of the array element when Instruction is a value type, so always write it
        // back; previously the VariableCondition branch omitted this and its column reference was lost.
        code[i] = instr;
      }
      var state = new InterpreterState(code, necessaryArgStackSize);

      int nComponents = tree.Root.GetSubtree(0).SubtreeCount;
      // produce a n-step forecast for each target variable for all rows
      // NOTE(review): this cache is shared across all rows and never cleared. Reads of a target variable
      // at a step not yet prognosed in the current pass therefore see values from a previous row/step
      // (or 0.0 initially). Confirm this is intended.
      var cachedPrognosedValues = new Dictionary<string, double[]>();
      foreach (var targetVariable in targetVariables)
        cachedPrognosedValues[targetVariable] = new double[horizon];
      foreach (var rowEnum in rows) {
        int row = rowEnum;
        List<double[]> vProgs = new List<double[]>();
        foreach (var horizonRow in Enumerable.Range(row, horizon)) {
          int localRow = horizonRow; // local variable for the ref parameter
          // the components are compiled consecutively, so evaluating them one after another
          // walks through the whole main body of the program
          var vPrognosisArr = new double[nComponents];
          for (int c = 0; c < nComponents; c++)
            vPrognosisArr[c] = Evaluate(dataset, ref localRow, row - 1, state, cachedPrognosedValues);
          vProgs.Add(vPrognosisArr);
          // set cachedValues for prognosis of future values
          for (int i = 0; i < vPrognosisArr.Length; i++)
            cachedPrognosedValues[targetVariables[i]][horizonRow - row] = vPrognosisArr[i];

          state.Reset();
        }

        yield return from component in Enumerable.Range(0, nComponents)
                     select from prognosisStep in Enumerable.Range(0, vProgs.Count)
                            select vProgs[prognosisStep][component];
      }
    }

    /// <summary>
    /// Returns the raw (unweighted) value of a variable at <paramref name="row"/>.
    /// Rows up to <paramref name="lastObservedRow"/>, and variables without a prognosis cache entry, are read
    /// from <paramref name="columnValues"/>. Later rows of prognosed variables come from the cache.
    /// </summary>
    /// <param name="columnValues">The dataset column stored in the instruction (expected to be an <see cref="IList{T}"/> of double).</param>
    // NOTE(review): a cache index >= horizon (e.g. a positive lag) throws IndexOutOfRangeException, as before.
    private static double GetVariableValue(object columnValues, string variableName, int row, int lastObservedRow, Dictionary<string, double[]> cachedPrognosedValues) {
      double[] prognosedValues;
      if (row > lastObservedRow && cachedPrognosedValues.TryGetValue(variableName, out prognosedValues))
        return prognosedValues[row - lastObservedRow - 1];
      return ((IList<double>)columnValues)[row];
    }

    /// <summary>
    /// Recursively evaluates the sub-program that starts at the current program counter and advances the
    /// counter past it. <paramref name="row"/> is temporarily shifted by time-lag related symbols and restored
    /// before they return.
    /// </summary>
    private double Evaluate(Dataset dataset, ref int row, int lastObservedRow, InterpreterState state, Dictionary<string, double[]> cachedPrognosedValues) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            if (currentInstr.nArguments == 1) s = -s;
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            if (currentInstr.nArguments == 1) p = 1.0 / p;
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            // the exponent is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
            return Math.Pow(x, y);
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            // the root index is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            double result;
            // only the selected branch is evaluated; the other one is skipped to keep the pc aligned
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
              SkipInstructions(state);
            } else {
              SkipInstructions(state);
              result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            return result;
          }
        case OpCodes.AND: {
            // short-circuit: once a sub-tree is <= 0 the remaining ones are skipped
            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result > 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
              else {
                SkipInstructions(state);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.OR: {
            // short-circuit: once a sub-tree is > 0 the remaining ones are skipped
            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result <= 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
              else {
                SkipInstructions(state);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            if (x > y) return 1.0;
            else return -1.0;
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            if (x < y) return 1.0;
            else return -1.0;
          }
        case OpCodes.TimeLag: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            // sums the sub-tree over the current row and the |Lag| rows in the direction of Lag,
            // re-running the same sub-program for each row
            int savedPc = state.ProgramCounter;
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
              state.ProgramCounter = savedPc;
            }
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            return sum;
          }

        //mkommend: derivative calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one sided smooth differentiator, N = 4
        // y' = 1/8h (f_i + 2f_i-1 - 2f_i-3 - f_i-4)
        case OpCodes.Derivative: {
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            row += 4; // restore the original row

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            }
            // push on argument values on stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = (ushort)currentInstr.iArg0;
            // evaluate the function
            double v = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue((ushort)currentInstr.iArg0);
          }
        case OpCodes.Variable: {
            if (row < 0 || row >= dataset.Rows)
              return double.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return GetVariableValue(currentInstr.iArg0, variableTreeNode.VariableName, row, lastObservedRow, cachedPrognosedValues) * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows)
              return double.NaN;
            return GetVariableValue(currentInstr.iArg0, laggedVariableTreeNode.VariableName, actualRow, lastObservedRow, cachedPrognosedValues) * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            return ((ConstantTreeNode)currentInstr.dynamicNode).Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            if (row < 0 || row >= dataset.Rows) {
              // the true and false branches must still be consumed; otherwise the caller would continue
              // reading instructions from inside this node's sub-trees
              SkipInstructions(state);
              SkipInstructions(state);
              return double.NaN;
            }
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = GetVariableValue(currentInstr.iArg0, variableConditionTreeNode.VariableName, row, lastObservedRow, cachedPrognosedValues);

            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
            double falseBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);

            return trueBranch * p + falseBranch * (1 - p);
          }
        default: throw new NotSupportedException();
      }
    }

    /// <summary>Maps a tree node to the opcode of its symbol type.</summary>
    /// <exception cref="NotSupportedException">If the symbol type is not supported by this interpreter.</exception>
    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
      byte opCode;
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
        return opCode;
      throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    // skips a whole branch: advances the pc past the sub-tree starting at the current instruction
    private void SkipInstructions(InterpreterState state) {
      int i = 1; // number of sub-trees still to be skipped
      while (i > 0) {
        i += state.NextInstruction().nArguments;
        i--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.