Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis Refactoring/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 5618

Last change on this file since 5618 was 5574, checked in by gkronber, 14 years ago

#1418 Added test projects for data-analysis and symbolic data-analysis plugin. Moved grammars to version 3.4. Fixed bugs in interpretation and simplification of 'not' symbols.

File size: 16.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that compiles a symbolic expression tree (including automatically defined functions)
  /// into a linear instruction array and evaluates it recursively, row by row, on a <see cref="Dataset"/>.
  /// Boolean-valued symbols produce 1.0 for true and -1.0 for false; conditions treat values &gt; 0.0 as true.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : NamedItem, ISymbolicDataAnalysisTreeInterpreter {
    /// <summary>
    /// Mutable evaluation state: the compiled program, the program counter and a fixed-size
    /// argument stack used to pass argument values to automatically defined functions.
    /// </summary>
    private sealed class InterpreterState {
      private const int ArgumentStackSize = 1024;
      private readonly double[] argumentStack;
      private int argumentStackPointer;
      private readonly Instruction[] code;
      private int pc;

      public int ProgramCounter {
        get { return pc; }
        set { pc = value; }
      }

      internal InterpreterState(Instruction[] code) {
        this.code = code;
        this.pc = 0;
        this.argumentStack = new double[ArgumentStackSize];
        this.argumentStackPointer = 0;
      }

      /// <summary>Rewinds the program counter and empties the argument stack before the next row is evaluated.</summary>
      internal void Reset() {
        this.pc = 0;
        this.argumentStackPointer = 0;
      }

      /// <summary>Returns the instruction at the program counter and advances the counter by one.</summary>
      internal Instruction NextInstruction() {
        return code[pc++];
      }

      private void Push(double val) {
        argumentStack[argumentStackPointer++] = val;
      }

      private double Pop() {
        return argumentStack[--argumentStackPointer];
      }

      /// <summary>
      /// Pushes a new stack frame: the argument values (in reverse order) followed by the frame size.
      /// </summary>
      internal void CreateStackFrame(double[] argValues) {
        // push in reverse order so that argument i is at a fixed offset below the size slot
        for (int i = argValues.Length - 1; i >= 0; i--) {
          Push(argValues[i]);
        }
        Push(argValues.Length);
      }

      /// <summary>Pops the topmost stack frame created by <see cref="CreateStackFrame"/>.</summary>
      internal void RemoveStackFrame() {
        int size = (int)Pop();
        argumentStackPointer -= size;
      }

      /// <summary>Returns argument <paramref name="index"/> of the topmost stack frame.</summary>
      internal double GetStackFrameValue(ushort index) {
        // layout of the topmost stack frame holding N arguments:
        // [free]      <- argumentStackPointer
        // [N]         <- argumentStackPointer - 1 (frame size)
        // [Arg0]      <- argumentStackPointer - 2 - 0
        // [Arg1]      <- argumentStackPointer - 2 - 1
        // [...]
        // [Arg(N-1)]  <- argumentStackPointer - 2 - (N - 1)
        // <begin of stack frame>
        return argumentStack[argumentStackPointer - index - 2];
      }
    }

    private static class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;

      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;

      public const byte Power = 22;
      public const byte Root = 23;
      public const byte TimeLag = 24;
      public const byte Integral = 25;
      public const byte Derivative = 26;

      public const byte VariableCondition = 27;
    }

    // Maps symbol types to opcodes. The mapping is constant, so one shared instance suffices.
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT },
      { typeof(Average), OpCodes.Average },
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
      { typeof(Power), OpCodes.Power },
      { typeof(Root), OpCodes.Root },
      { typeof(TimeLag), OpCodes.TimeLag },
      { typeof(Integral), OpCodes.Integral },
      { typeof(Derivative), OpCodes.Derivative },
      { typeof(VariableCondition), OpCodes.VariableCondition }
    };

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
    private SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : base() {
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for each of the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator: compilation happens when enumeration starts, and each row is
    /// evaluated lazily as the sequence is consumed.
    /// </remarks>
    /// <param name="tree">The tree to evaluate.</param>
    /// <param name="dataset">The dataset the variable symbols read from.</param>
    /// <param name="rows">The row indices to evaluate.</param>
    /// <returns>One value per entry of <paramref name="rows"/>, in the same order.</returns>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
      ResolveVariableIndices(code, dataset);
      var state = new InterpreterState(code);

      foreach (var rowEnum in rows) {
        // copy into a local because Evaluate takes the row by reference
        int row = rowEnum;
        state.Reset();
        yield return Evaluate(dataset, ref row, state);
      }
    }

    /// <summary>
    /// Stores the dataset column index of the referenced variable in <c>iArg0</c> of every
    /// instruction that reads a variable (Variable, LagVariable, VariableCondition).
    /// </summary>
    private static void ResolveVariableIndices(Instruction[] code, Dataset dataset) {
      for (int i = 0; i < code.Length; i++) {
        Instruction instr = code[i];
        string variableName;
        switch (instr.opCode) {
          case OpCodes.Variable:
            variableName = ((VariableTreeNode)instr.dynamicNode).VariableName;
            break;
          case OpCodes.LagVariable:
            variableName = ((LaggedVariableTreeNode)instr.dynamicNode).VariableName;
            break;
          case OpCodes.VariableCondition:
            variableName = ((VariableConditionTreeNode)instr.dynamicNode).VariableName;
            break;
          default:
            continue; // instruction does not reference a dataset column
        }
        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableName);
        // write the modified copy back; required when Instruction is a value type
        code[i] = instr;
      }
    }

    /// <summary>
    /// Recursively evaluates the subtree that starts at the current program counter for the given row.
    /// On normal return the program counter points to the first instruction after that subtree.
    /// </summary>
    /// <param name="dataset">The dataset the variable symbols read from.</param>
    /// <param name="row">The current row. Time-shifting symbols change it temporarily and restore it before returning.</param>
    /// <param name="state">Program counter and argument stack.</param>
    /// <returns>
    /// The value of the subtree. Boolean symbols return 1.0 (true) or -1.0 (false). TimeLag, Integral and
    /// Derivative return <see cref="double.NaN"/> when they would reach rows outside the dataset.
    /// </returns>
    /// <exception cref="ArgumentException">A lagged variable accesses a row outside the dataset.</exception>
    /// <exception cref="NotSupportedException">The instruction has an unknown opcode.</exception>
    private double Evaluate(Dataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            if (currentInstr.nArguments == 1) s = -s; // unary minus
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            if (currentInstr.nArguments == 1) p = 1.0 / p; // unary division = reciprocal
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            double y = Math.Round(Evaluate(dataset, ref row, state)); // integer exponents only
            return Math.Pow(x, y);
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            double y = Math.Round(Evaluate(dataset, ref row, state)); // integer root degrees only
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, state);
            double result;
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state);
              SkipInstructions(state); // skip the else branch
            } else {
              SkipInstructions(state); // skip the then branch
              result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once false, remaining operands are skipped, not evaluated
              if (result <= 0.0) SkipInstructions(state);
              else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result <= 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.OR: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once true, remaining operands are skipped, not evaluated
              if (result > 0.0) SkipInstructions(state);
              else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) return 1.0;
            else return -1.0;
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) return 1.0;
            else return -1.0;
          }
        case OpCodes.TimeLag: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            if (row + timeLagTreeNode.Lag < 0 || row + timeLagTreeNode.Lag >= dataset.Rows) {
              // the subtree is not evaluated, but its instructions must still be consumed so that
              // the caller continues with the correct next sibling
              SkipInstructions(state);
              return double.NaN;
            }

            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            if (row + timeLagTreeNode.Lag < 0 || row + timeLagTreeNode.Lag >= dataset.Rows) {
              // consume the subtree's instructions to keep the program counter aligned
              SkipInstructions(state);
              return double.NaN;
            }
            int savedPc = state.ProgramCounter;
            double sum = 0.0;
            // evaluate the subtree at row + sign(Lag), ..., row + Lag, rewinding the pc after each pass
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            // finally add the value at the original row; this pass leaves the pc after the subtree
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        //mkommend: derivative calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one-sided smooth differentiator, N = 4
        // y' = 1/(8h) * (f_i + 2 f_{i-1} - 2 f_{i-3} - f_{i-4})
        case OpCodes.Derivative: {
            if (row - 4 < 0) {
              // consume the subtree's instructions to keep the program counter aligned
              SkipInstructions(state);
              return double.NaN;
            }
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state); row -= 2; // f_{i-2} has coefficient 0
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state); // leaves the pc after the subtree
            row += 4;

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate the argument subtrees in the caller's context
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push the argument values on the stack
            state.CreateStackFrame(argValues);

            // save the pc (it now points behind the argument subtrees)
            int savedPc = state.ProgramCounter;
            // jump to the start of the function body
            state.ProgramCounter = currentInstr.iArg0;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation continues after the argument subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue(currentInstr.iArg0);
          }
        case OpCodes.Variable: {
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + actualRow);
            return dataset[actualRow, currentInstr.iArg0] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = dataset[row, currentInstr.iArg0];
            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            // both branches are always evaluated and blended by p
            double trueBranch = Evaluate(dataset, ref row, state);
            double falseBranch = Evaluate(dataset, ref row, state);

            return trueBranch * p + falseBranch * (1 - p);
          }
        default: throw new NotSupportedException();
      }
    }

    /// <summary>Returns the opcode for the symbol of <paramref name="treeNode"/>.</summary>
    /// <exception cref="NotSupportedException">The symbol type has no opcode.</exception>
    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
      byte opCode;
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode)) {
        return opCode;
      }
      throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    /// <summary>
    /// Advances the program counter past one complete subtree (the instruction at the current pc
    /// and all of its descendants) without evaluating it.
    /// </summary>
    private void SkipInstructions(InterpreterState state) {
      // number of subtrees still to be consumed; each instruction consumes one and adds its arguments
      int i = 1;
      while (i > 0) {
        i += state.NextInstruction().nArguments;
        i--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.