Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SimpleArithmeticExpressionInterpreter.cs @ 6646

Last change to this file as of revision 6646 was revision 5467, checked in by gkronber, 14 years ago.

#1325: Merged r5060 from branch into trunk.

File size: 17.0 KB
Rev Line
[3253]1#region License Information
2/* HeuristicLab
[5445]3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[3253]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
[4068]23using System.Collections.Generic;
[4722]24using HeuristicLab.Common;
[3253]25using HeuristicLab.Core;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[4068]27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
[3462]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
[4068]29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[3373]30using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
[3253]31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter for arithmetic symbolic expression trees (including function calls).
  /// A tree is compiled into a flat <see cref="Instruction"/> array in which every instruction
  /// is followed by the instructions of its sub-trees; this array is evaluated recursively
  /// once per requested dataset row.
  /// </summary>
  [StorableClass]
  [Item("SimpleArithmeticExpressionInterpreter", "Interpreter for arithmetic symbolic expression trees including function calls.")]
  public sealed class SimpleArithmeticExpressionInterpreter : NamedItem, ISymbolicExpressionTreeInterpreter {
    /// <summary>
    /// Evaluation state: the compiled program, the program counter and a fixed-size
    /// argument stack holding the argument values of active function calls.
    /// </summary>
    private class InterpreterState {
      private const int ARGUMENT_STACK_SIZE = 1024;
      private readonly double[] argumentStack;
      private int argumentStackPointer;
      private readonly Instruction[] code;
      private int pc;

      /// <summary>Index of the instruction returned by the next call to <see cref="NextInstruction"/>.</summary>
      public int ProgramCounter {
        get { return pc; }
        set { pc = value; }
      }

      internal InterpreterState(Instruction[] code) {
        this.code = code;
        this.pc = 0;
        this.argumentStack = new double[ARGUMENT_STACK_SIZE];
        this.argumentStackPointer = 0;
      }

      /// <summary>Rewinds to the first instruction and empties the argument stack (done before every row).</summary>
      internal void Reset() {
        this.pc = 0;
        this.argumentStackPointer = 0;
      }

      /// <summary>Returns the instruction at the program counter and advances the counter by one.</summary>
      internal Instruction NextInstruction() {
        return code[pc++];
      }

      private void Push(double val) {
        argumentStack[argumentStackPointer++] = val;
      }

      private double Pop() {
        return argumentStack[--argumentStackPointer];
      }

      /// <summary>
      /// Pushes the argument values of a function call, followed by the frame size.
      /// NOTE(review): there is no overflow check; using more than ARGUMENT_STACK_SIZE slots
      /// (deeply nested calls) ends in an IndexOutOfRangeException.
      /// </summary>
      internal void CreateStackFrame(double[] argValues) {
        // push in reverse order so that argument i ends up at argumentStackPointer - 2 - i
        for (int i = argValues.Length - 1; i >= 0; i--) {
          Push(argValues[i]);
        }
        Push(argValues.Length);
      }

      /// <summary>Removes the frame pushed by the matching <see cref="CreateStackFrame"/> call.</summary>
      internal void RemoveStackFrame() {
        int size = (int)Pop();
        argumentStackPointer -= size;
      }

      /// <summary>Returns argument <paramref name="index"/> of the innermost stack frame.</summary>
      internal double GetStackFrameValue(ushort index) {
        // layout of the stack after CreateStackFrame(args), with n = args.Length:
        // [argumentStackPointer]              <- next free slot
        // [argumentStackPointer - 1]          = n (frame size)
        // [argumentStackPointer - 2 - 0]      = args[0]
        // [argumentStackPointer - 2 - 1]      = args[1]
        // [...]
        // [argumentStackPointer - 2 - (n-1)]  = args[n-1]   <- begin of stack frame
        return argumentStack[argumentStackPointer - index - 2];
      }
    }

    /// <summary>Op codes emitted by <see cref="MapSymbolToOpCode"/> and dispatched on in Evaluate.</summary>
    private static class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;

      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;

      public const byte Power = 22;
      public const byte Root = 23;
      public const byte TimeLag = 24;
      public const byte Integral = 25;
      public const byte Derivative = 26;

      public const byte VariableCondition = 27;
    }

    // Immutable lookup table shared by all instances (it is not persisted).
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT },
      { typeof(Average), OpCodes.Average },
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
      { typeof(Power), OpCodes.Power },
      { typeof(Root), OpCodes.Root },
      { typeof(TimeLag), OpCodes.TimeLag },
      { typeof(Integral), OpCodes.Integral },
      { typeof(Derivative), OpCodes.Derivative },
      { typeof(VariableCondition), OpCodes.VariableCondition }
    };

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    [StorableConstructor]
    private SimpleArithmeticExpressionInterpreter(bool deserializing) : base(deserializing) { }
    private SimpleArithmeticExpressionInterpreter(SimpleArithmeticExpressionInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SimpleArithmeticExpressionInterpreter(this, cloner);
    }

    public SimpleArithmeticExpressionInterpreter()
      : base() {
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> on each of the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// The method is an iterator: compilation and variable-index resolution happen on first enumeration,
    /// and one value is produced per row.
    /// </summary>
    /// <exception cref="NotSupportedException">The tree contains a symbol without an op code.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);

      // Resolve variable names to dataset column indices once, before any row is evaluated.
      // Each modified copy is written back into the array: Instruction appears to be a value type
      // (the array element is copied into a local), and the write-back is harmless otherwise.
      for (int i = 0; i < code.Length; i++) {
        Instruction instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = instr.dynamicNode as VariableTreeNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
          code[i] = instr;
        } else if (instr.opCode == OpCodes.LagVariable) {
          var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
          code[i] = instr;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableConditionTreeNode.VariableName);
          // without this write-back the resolved index was lost and column 0 was read instead
          code[i] = instr;
        }
      }
      var state = new InterpreterState(code);

      foreach (var rowEnum in rows) {
        int row = rowEnum;
        state.Reset();
        yield return Evaluate(dataset, ref row, state);
      }
    }

    /// <summary>
    /// Evaluates the sub-tree whose root instruction is at the current program counter.
    /// Invariant: on return, the program counter points just past that entire sub-tree,
    /// whether or not all of its children were actually evaluated. Callers depend on this
    /// to read their remaining arguments.
    /// </summary>
    /// <param name="row">Current dataset row. TimeLag, Integral and Derivative temporarily shift it
    /// and restore it before returning, so it is passed by reference.</param>
    private double Evaluate(Dataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            // unary minus
            if (currentInstr.nArguments == 1) s = -s;
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            // unary division is the reciprocal
            if (currentInstr.nArguments == 1) p = 1.0 / p;
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            // the exponent is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, y);
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            // the root degree is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, state);
            double result;
            // only the selected branch is evaluated; the other one is skipped
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state); SkipInstructions(state);
            } else {
              SkipInstructions(state); result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once false, remaining operands are skipped, not evaluated
              if (result <= 0.0) SkipInstructions(state);
              else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result <= 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.OR: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once true, remaining operands are skipped, not evaluated
              if (result > 0.0) SkipInstructions(state);
              else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return -Evaluate(dataset, ref row, state);
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) return 1.0;
            else return -1.0;
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) return 1.0;
            else return -1.0;
          }
        case OpCodes.TimeLag: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            if (row + timeLagTreeNode.Lag < 0 || row + timeLagTreeNode.Lag >= dataset.Rows) {
              // keep the program counter aligned with the caller's remaining arguments
              SkipInstructions(state);
              return double.NaN;
            }

            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            int lag = timeLagTreeNode.Lag;
            if (row + lag < 0 || row + lag >= dataset.Rows) {
              // keep the program counter aligned with the caller's remaining arguments
              SkipInstructions(state);
              return double.NaN;
            }
            // sums the sub-tree over the current row and the |lag| rows towards row + lag;
            // the sub-tree is re-executed by rewinding the program counter
            int savedPc = state.ProgramCounter;
            int step = Math.Sign(lag);
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(lag); i++) {
              row += step;
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            row -= lag;
            // the final evaluation (at the original row) leaves the program counter past the sub-tree
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        // mkommend: derivative calculation taken from:
        // http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        // one-sided smooth differentiator, N = 4:
        // y' = 1/(8h) * (f_i + 2 f_(i-1) - 2 f_(i-3) - f_(i-4)), evaluated here with h = 1
        case OpCodes.Derivative: {
            if (row - 4 < 0) {
              // keep the program counter aligned with the caller's remaining arguments
              SkipInstructions(state);
              return double.NaN;
            }
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state); row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            row += 4;

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push the argument values on the stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = currentInstr.iArg0;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue(currentInstr.iArg0);
          }
        case OpCodes.Variable: {
            var variableTreeNode = currentInstr.dynamicNode as VariableTreeNode;
            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + row);
            return dataset[actualRow, currentInstr.iArg0] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
            return constTreeNode.Value;
          }

        // mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x))
        // to determine the relative amounts of the true and false branch,
        // see http://en.wikipedia.org/wiki/Logistic_function
        // Here alpha = Slope and x = (variable value - Threshold); both branches are always evaluated.
        case OpCodes.VariableCondition: {
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = dataset[row, currentInstr.iArg0];
            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = Evaluate(dataset, ref row, state);
            double falseBranch = Evaluate(dataset, ref row, state);

            return trueBranch * p + falseBranch * (1 - p);
          }
        default: throw new NotSupportedException();
      }
    }

    /// <summary>Maps a tree node to the op code of its symbol type.</summary>
    /// <exception cref="NotSupportedException">The symbol type has no op code.</exception>
    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
      byte opCode;
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
        return opCode;
      throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    /// <summary>
    /// Advances the program counter past one complete sub-tree without evaluating it.
    /// <c>i</c> counts the instructions still to be read: each instruction read adds its own
    /// argument count and removes itself.
    /// </summary>
    private void SkipInstructions(InterpreterState state) {
      int i = 1;
      while (i > 0) {
        i += state.NextInstruction().nArguments;
        i--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.