Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis SolutionEnsembles/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 5991

Last change on this file since 5991 was 5809, checked in by mkommend, 14 years ago

#1418: Reintegrated branch into trunk.

File size: 18.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interprets symbolic expression trees (including automatically defined functions) by compiling
  /// them into a linear instruction array and evaluating that array recursively for each requested
  /// row of a <see cref="Dataset"/>.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    #region private classes
    /// <summary>
    /// Mutable evaluation state: the compiled program, the program counter and the argument stack
    /// used to pass evaluated argument values into automatically defined functions.
    /// </summary>
    private sealed class InterpreterState {
      private const int ARGUMENT_STACK_SIZE = 1024;
      private readonly double[] argumentStack;
      private int argumentStackPointer;
      private readonly Instruction[] code;
      private int pc;

      /// <summary>Index of the next instruction returned by <see cref="NextInstruction"/>.</summary>
      public int ProgramCounter {
        get { return pc; }
        set { pc = value; }
      }

      internal InterpreterState(Instruction[] code) {
        this.code = code;
        this.pc = 0;
        this.argumentStack = new double[ARGUMENT_STACK_SIZE];
        this.argumentStackPointer = 0;
      }

      // Rewinds to the first instruction and empties the argument stack; called before each row.
      internal void Reset() {
        this.pc = 0;
        this.argumentStackPointer = 0;
      }

      // Returns the instruction at the program counter and advances the counter by one.
      internal Instruction NextInstruction() {
        return code[pc++];
      }

      private void Push(double val) {
        argumentStack[argumentStackPointer++] = val;
      }

      private double Pop() {
        return argumentStack[--argumentStackPointer];
      }

      // Pushes the argument values of a function call followed by the number of arguments.
      // Arguments are pushed in reverse order so that Arg0 lies directly below the frame size,
      // which keeps the index arithmetic in GetStackFrameValue simple.
      // NOTE(review): no overflow check; very deep ADF recursion would exceed ARGUMENT_STACK_SIZE.
      internal void CreateStackFrame(double[] argValues) {
        for (int i = argValues.Length - 1; i >= 0; i--) {
          Push(argValues[i]);
        }
        Push(argValues.Length);
      }

      // Pops the frame size and then discards that many argument values.
      internal void RemoveStackFrame() {
        int size = (int)Pop();
        argumentStackPointer -= size;
      }

      internal double GetStackFrameValue(ushort index) {
        // layout of the stack for a frame with N arguments:
        // [free slot]   <- argumentStackPointer
        // [N]           <- argumentStackPointer - 1   (frame size)
        // [Arg0]        <- argumentStackPointer - 2 - 0
        // [Arg1]        <- argumentStackPointer - 2 - 1
        // [...]
        // [Arg(N-1)]    <- argumentStackPointer - 2 - (N - 1)
        // <begin of stack frame>
        return argumentStack[argumentStackPointer - index - 2];
      }
    }

    /// <summary>Byte op codes emitted by <see cref="MapSymbolToOpCode"/> and dispatched on in Evaluate.</summary>
    private static class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;

      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;

      public const byte Power = 22;
      public const byte Root = 23;
      public const byte TimeLag = 24;
      public const byte Integral = 25;
      public const byte Derivative = 26;

      public const byte VariableCondition = 27;
    }
    #endregion

    // Immutable lookup table from symbol type to op code; shared by all interpreter instances.
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT },
      { typeof(Average), OpCodes.Average },
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
      { typeof(Power), OpCodes.Power },
      { typeof(Root), OpCodes.Root },
      { typeof(TimeLag), OpCodes.TimeLag },
      { typeof(Integral), OpCodes.Integral },
      { typeof(Derivative), OpCodes.Derivative },
      { typeof(VariableCondition), OpCodes.VariableCondition }
    };

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// When true, <see cref="GetSymbolicExpressionTreeValues"/> throws <see cref="NotSupportedException"/>
    /// because interval arithmetic is not implemented by this interpreter.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
    private SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for each row in <paramref name="rows"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Row indices to evaluate, in the order the results are yielded.</param>
    /// <returns>One value per row. The sequence is lazy (iterator), so compilation and the exceptions
    /// below only occur once enumeration starts.</returns>
    /// <exception cref="NotSupportedException">Interval arithmetic checking is enabled, or the tree
    /// contains a symbol without an op code.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);

      // Resolve variable names to dataset column indices once, before any row is evaluated.
      for (int i = 0; i < code.Length; i++) {
        Instruction instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableConditionTreeNode.VariableName);
        }
        // Write the (possibly modified) copy back for every op code. If Instruction is a value type,
        // assigning iArg0 on the local copy alone would be lost. The VariableCondition branch
        // previously lacked this write-back and so always read column 0.
        code[i] = instr;
      }
      var state = new InterpreterState(code);

      foreach (var rowEnum in rows) {
        int row = rowEnum;
        state.Reset();
        yield return Evaluate(dataset, ref row, state);
      }
    }

    /// <summary>
    /// Evaluates the subtree that starts at the current program counter for the given row.
    /// On return, the program counter points just past that subtree, including when NaN is returned.
    /// </summary>
    /// <param name="row">Current row. Passed by reference so time-shifting symbols (TimeLag, Integral,
    /// Derivative) can move it temporarily; each restores it before returning.</param>
    private double Evaluate(Dataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            // unary minus
            if (currentInstr.nArguments == 1) s = -s;
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            // unary division is the reciprocal
            if (currentInstr.nArguments == 1) p = 1.0 / p;
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            // the exponent is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, y);
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            // the root degree is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, 1.0 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.IfThenElse: {
            // a condition value > 0 is treated as true; the branch not taken is skipped unevaluated
            double condition = Evaluate(dataset, ref row, state);
            double result;
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state);
              SkipInstructions(state);
            } else {
              SkipInstructions(state);
              result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            // short-circuit: once an argument is <= 0, the remaining arguments are skipped
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result <= 0.0) {
                SkipInstructions(state);
              } else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result <= 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.OR: {
            // short-circuit: once an argument is > 0, the remaining arguments are skipped
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result > 0.0) {
                SkipInstructions(state);
              } else {
                result = Evaluate(dataset, ref row, state);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            return x > y ? 1.0 : -1.0;
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            return x < y ? 1.0 : -1.0;
          }
        case OpCodes.TimeLag: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            int lag = timeLagTreeNode.Lag;
            if (row + lag < 0 || row + lag >= dataset.Rows) {
              // The child is not evaluated, but its instructions must still be consumed so that the
              // parent continues with its next argument instead of this node's child.
              SkipInstructions(state);
              return double.NaN;
            }

            row += lag;
            double result = Evaluate(dataset, ref row, state);
            row -= lag;
            return result;
          }
        case OpCodes.Integral: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            int lag = timeLagTreeNode.Lag;
            if (row + lag < 0 || row + lag >= dataset.Rows) {
              // keep the program counter aligned with the parent's arguments (see TimeLag)
              SkipInstructions(state);
              return double.NaN;
            }
            // Sums the child over rows row+sign(lag) .. row+lag and then over the current row,
            // re-running the same child instructions each time by rewinding the program counter.
            int savedPc = state.ProgramCounter;
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(lag); i++) {
              row += Math.Sign(lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            row -= lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        // mkommend: derivative calculation taken from:
        // http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        // one-sided smooth differentiator, N = 4:
        // y' = 1/(8h) * (f_i + 2 f_(i-1) - 2 f_(i-3) - f_(i-4))
        case OpCodes.Derivative: {
            if (row - 4 < 0) {
              // keep the program counter aligned with the parent's arguments (see TimeLag)
              SkipInstructions(state);
              return double.NaN;
            }
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state);
            row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state);
            row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state);
            row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            row += 4;

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate the argument subtrees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push the argument values on the stack
            state.CreateStackFrame(argValues);

            // save the pc (now pointing just past the argument subtrees)
            int savedPc = state.ProgramCounter;
            // set pc to the start of the function body
            state.ProgramCounter = currentInstr.iArg0;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation continues at the point after this node's subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue(currentInstr.iArg0);
          }
        case OpCodes.Variable: {
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + actualRow);
            return dataset[actualRow, currentInstr.iArg0] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        // mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-slope * x))
        // to determine the relative weights of the true and false branch, see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = dataset[row, currentInstr.iArg0];
            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            // both branches are always evaluated and blended
            double trueBranch = Evaluate(dataset, ref row, state);
            double falseBranch = Evaluate(dataset, ref row, state);

            return trueBranch * p + falseBranch * (1 - p);
          }
        default: throw new NotSupportedException();
      }
    }

    // Maps a tree node's symbol type to its op code; used as the compiler's op-code callback.
    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
      byte opCode;
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
        return opCode;
      throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    // Advances the program counter past one complete subtree without evaluating it.
    // `pending` counts subtrees still to be consumed: each instruction consumes one and adds its children.
    private void SkipInstructions(InterpreterState state) {
      int pending = 1;
      while (pending > 0) {
        pending += state.NextInstruction().nArguments;
        pending--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.