Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 5894

Last change on this file since 5894 was 5894, checked in by gkronber, 13 years ago

#1453: Added an ErrorState property to online evaluators to indicate if the result value is valid or if there has been an error in the calculation. Adapted all classes that use one of the online evaluators to check this property.

File size: 18.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
28using HeuristicLab.Data;
29using HeuristicLab.Parameters;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  [StorableClass]
33  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
34  public sealed class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
35    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
36    #region private classes
37    private class InterpreterState {
38      private const int ARGUMENT_STACK_SIZE = 1024;
39      private double[] argumentStack;
40      private int argumentStackPointer;
41      private Instruction[] code;
42      private int pc;
43      public int ProgramCounter {
44        get { return pc; }
45        set { pc = value; }
46      }
47      internal InterpreterState(Instruction[] code) {
48        this.code = code;
49        this.pc = 0;
50        this.argumentStack = new double[ARGUMENT_STACK_SIZE];
51        this.argumentStackPointer = 0;
52      }
53
54      internal void Reset() {
55        this.pc = 0;
56        this.argumentStackPointer = 0;
57      }
58
59      internal Instruction NextInstruction() {
60        return code[pc++];
61      }
62      private void Push(double val) {
63        argumentStack[argumentStackPointer++] = val;
64      }
65      private double Pop() {
66        return argumentStack[--argumentStackPointer];
67      }
68
69      internal void CreateStackFrame(double[] argValues) {
70        // push in reverse order to make indexing easier
71        for (int i = argValues.Length - 1; i >= 0; i--) {
72          argumentStack[argumentStackPointer++] = argValues[i];
73        }
74        Push(argValues.Length);
75      }
76
77      internal void RemoveStackFrame() {
78        int size = (int)Pop();
79        argumentStackPointer -= size;
80      }
81
82      internal double GetStackFrameValue(ushort index) {
83        // layout of stack:
84        // [0]   <- argumentStackPointer
85        // [StackFrameSize = N + 1]
86        // [Arg0] <- argumentStackPointer - 2 - 0
87        // [Arg1] <- argumentStackPointer - 2 - 1
88        // [...]
89        // [ArgN] <- argumentStackPointer - 2 - N
90        // <Begin of stack frame>
91        return argumentStack[argumentStackPointer - index - 2];
92      }
93    }
94    private class OpCodes {
95      public const byte Add = 1;
96      public const byte Sub = 2;
97      public const byte Mul = 3;
98      public const byte Div = 4;
99
100      public const byte Sin = 5;
101      public const byte Cos = 6;
102      public const byte Tan = 7;
103
104      public const byte Log = 8;
105      public const byte Exp = 9;
106
107      public const byte IfThenElse = 10;
108
109      public const byte GT = 11;
110      public const byte LT = 12;
111
112      public const byte AND = 13;
113      public const byte OR = 14;
114      public const byte NOT = 15;
115
116
117      public const byte Average = 16;
118
119      public const byte Call = 17;
120
121      public const byte Variable = 18;
122      public const byte LagVariable = 19;
123      public const byte Constant = 20;
124      public const byte Arg = 21;
125
126      public const byte Power = 22;
127      public const byte Root = 23;
128      public const byte TimeLag = 24;
129      public const byte Integral = 25;
130      public const byte Derivative = 26;
131
132      public const byte VariableCondition = 27;
133    }
134    #endregion
135
136    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
137      { typeof(Addition), OpCodes.Add },
138      { typeof(Subtraction), OpCodes.Sub },
139      { typeof(Multiplication), OpCodes.Mul },
140      { typeof(Division), OpCodes.Div },
141      { typeof(Sine), OpCodes.Sin },
142      { typeof(Cosine), OpCodes.Cos },
143      { typeof(Tangent), OpCodes.Tan },
144      { typeof(Logarithm), OpCodes.Log },
145      { typeof(Exponential), OpCodes.Exp },
146      { typeof(IfThenElse), OpCodes.IfThenElse },
147      { typeof(GreaterThan), OpCodes.GT },
148      { typeof(LessThan), OpCodes.LT },
149      { typeof(And), OpCodes.AND },
150      { typeof(Or), OpCodes.OR },
151      { typeof(Not), OpCodes.NOT},
152      { typeof(Average), OpCodes.Average},
153      { typeof(InvokeFunction), OpCodes.Call },
154      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
155      { typeof(LaggedVariable), OpCodes.LagVariable },
156      { typeof(Constant), OpCodes.Constant },
157      { typeof(Argument), OpCodes.Arg },
158      { typeof(Power),OpCodes.Power},
159      { typeof(Root),OpCodes.Root},
160      { typeof(TimeLag), OpCodes.TimeLag},
161      { typeof(Integral), OpCodes.Integral},
162      { typeof(Derivative), OpCodes.Derivative},
163      { typeof(VariableCondition),OpCodes.VariableCondition}
164    };
165
166    public override bool CanChangeName {
167      get { return false; }
168    }
169    public override bool CanChangeDescription {
170      get { return false; }
171    }
172
173    #region parameter properties
174    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
175      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
176    }
177    #endregion
178
179    #region properties
180    public BoolValue CheckExpressionsWithIntervalArithmetic {
181      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
182      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
183    }
184    #endregion
185
186
187    [StorableConstructor]
188    private SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
189    private SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
190    public override IDeepCloneable Clone(Cloner cloner) {
191      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
192    }
193
194    public SymbolicDataAnalysisExpressionTreeInterpreter()
195      : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
196      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
197    }
198
199    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
200      if (CheckExpressionsWithIntervalArithmetic.Value)
201        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
202      var compiler = new SymbolicExpressionTreeCompiler();
203      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
204
205      for (int i = 0; i < code.Length; i++) {
206        Instruction instr = code[i];
207        if (instr.opCode == OpCodes.Variable) {
208          var variableTreeNode = instr.dynamicNode as VariableTreeNode;
209          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
210          instr.dArg0 = variableTreeNode.Weight;
211        } else if (instr.opCode == OpCodes.LagVariable) {
212          var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
213          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
214          instr.dArg0 = variableTreeNode.Weight;
215        } else if (instr.opCode == OpCodes.VariableCondition) {
216          var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
217          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableConditionTreeNode.VariableName);
218          instr.dArg0 = variableConditionTreeNode.Threshold;
219        } else if (instr.opCode == OpCodes.Constant) {
220          var constTreeNode = instr.dynamicNode as ConstantTreeNode;
221          instr.dArg0 = constTreeNode.Value;
222        }
223      }
224      var state = new InterpreterState(code);
225
226      foreach (var rowEnum in rows) {
227        int row = rowEnum;
228        state.Reset();
229        yield return Evaluate(dataset, ref row, state);
230      }
231    }
232
233    private double Evaluate(Dataset dataset, ref int row, InterpreterState state) {
234      Instruction currentInstr = state.NextInstruction();
235      switch (currentInstr.opCode) {
236        case OpCodes.Add: {
237            double s = Evaluate(dataset, ref row, state);
238            for (int i = 1; i < currentInstr.nArguments; i++) {
239              s += Evaluate(dataset, ref row, state);
240            }
241            return s;
242          }
243        case OpCodes.Sub: {
244            double s = Evaluate(dataset, ref row, state);
245            for (int i = 1; i < currentInstr.nArguments; i++) {
246              s -= Evaluate(dataset, ref row, state);
247            }
248            if (currentInstr.nArguments == 1) s = -s;
249            return s;
250          }
251        case OpCodes.Mul: {
252            double p = Evaluate(dataset, ref row, state);
253            for (int i = 1; i < currentInstr.nArguments; i++) {
254              p *= Evaluate(dataset, ref row, state);
255            }
256            return p;
257          }
258        case OpCodes.Div: {
259            double p = Evaluate(dataset, ref row, state);
260            for (int i = 1; i < currentInstr.nArguments; i++) {
261              p /= Evaluate(dataset, ref row, state);
262            }
263            if (currentInstr.nArguments == 1) p = 1.0 / p;
264            return p;
265          }
266        case OpCodes.Average: {
267            double sum = Evaluate(dataset, ref row, state);
268            for (int i = 1; i < currentInstr.nArguments; i++) {
269              sum += Evaluate(dataset, ref row, state);
270            }
271            return sum / currentInstr.nArguments;
272          }
273        case OpCodes.Cos: {
274            return Math.Cos(Evaluate(dataset, ref row, state));
275          }
276        case OpCodes.Sin: {
277            return Math.Sin(Evaluate(dataset, ref row, state));
278          }
279        case OpCodes.Tan: {
280            return Math.Tan(Evaluate(dataset, ref row, state));
281          }
282        case OpCodes.Power: {
283            double x = Evaluate(dataset, ref row, state);
284            double y = Math.Round(Evaluate(dataset, ref row, state));
285            return Math.Pow(x, y);
286          }
287        case OpCodes.Root: {
288            double x = Evaluate(dataset, ref row, state);
289            double y = Math.Round(Evaluate(dataset, ref row, state));
290            return Math.Pow(x, 1 / y);
291          }
292        case OpCodes.Exp: {
293            return Math.Exp(Evaluate(dataset, ref row, state));
294          }
295        case OpCodes.Log: {
296            return Math.Log(Evaluate(dataset, ref row, state));
297          }
298        case OpCodes.IfThenElse: {
299            double condition = Evaluate(dataset, ref row, state);
300            double result;
301            if (condition > 0.0) {
302              result = Evaluate(dataset, ref row, state); SkipInstructions(state);
303            } else {
304              SkipInstructions(state); result = Evaluate(dataset, ref row, state);
305            }
306            return result;
307          }
308        case OpCodes.AND: {
309            double result = Evaluate(dataset, ref row, state);
310            for (int i = 1; i < currentInstr.nArguments; i++) {
311              if (result <= 0.0) SkipInstructions(state);
312              else {
313                result = Evaluate(dataset, ref row, state);
314              }
315            }
316            return result <= 0.0 ? -1.0 : 1.0;
317          }
318        case OpCodes.OR: {
319            double result = Evaluate(dataset, ref row, state);
320            for (int i = 1; i < currentInstr.nArguments; i++) {
321              if (result > 0.0) SkipInstructions(state);
322              else {
323                result = Evaluate(dataset, ref row, state);
324              }
325            }
326            return result > 0.0 ? 1.0 : -1.0;
327          }
328        case OpCodes.NOT: {
329            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
330          }
331        case OpCodes.GT: {
332            double x = Evaluate(dataset, ref row, state);
333            double y = Evaluate(dataset, ref row, state);
334            if (x > y) return 1.0;
335            else return -1.0;
336          }
337        case OpCodes.LT: {
338            double x = Evaluate(dataset, ref row, state);
339            double y = Evaluate(dataset, ref row, state);
340            if (x < y) return 1.0;
341            else return -1.0;
342          }
343        case OpCodes.TimeLag: {
344            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
345            if (row + timeLagTreeNode.Lag < 0 || row + timeLagTreeNode.Lag >= dataset.Rows)
346              return double.NaN;
347
348            row += timeLagTreeNode.Lag;
349            double result = Evaluate(dataset, ref row, state);
350            row -= timeLagTreeNode.Lag;
351            return result;
352          }
353        case OpCodes.Integral: {
354            int savedPc = state.ProgramCounter;
355            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
356            if (row + timeLagTreeNode.Lag < 0 || row + timeLagTreeNode.Lag >= dataset.Rows)
357              return double.NaN;
358            double sum = 0.0;
359            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
360              row += Math.Sign(timeLagTreeNode.Lag);
361              sum += Evaluate(dataset, ref row, state);
362              state.ProgramCounter = savedPc;
363            }
364            row -= timeLagTreeNode.Lag;
365            sum += Evaluate(dataset, ref row, state);
366            return sum;
367          }
368
369        //mkommend: derivate calculation taken from:
370        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
371        //one sided smooth differentiatior, N = 4
372        // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
373        case OpCodes.Derivative: {
374            if (row - 4 < 0) return double.NaN;
375            int savedPc = state.ProgramCounter;
376            double f_0 = Evaluate(dataset, ref row, state); ; row--;
377            state.ProgramCounter = savedPc;
378            double f_1 = Evaluate(dataset, ref row, state); ; row -= 2;
379            state.ProgramCounter = savedPc;
380            double f_3 = Evaluate(dataset, ref row, state); ; row--;
381            state.ProgramCounter = savedPc;
382            double f_4 = Evaluate(dataset, ref row, state); ;
383            row += 4;
384
385            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
386          }
387        case OpCodes.Call: {
388            // evaluate sub-trees
389            double[] argValues = new double[currentInstr.nArguments];
390            for (int i = 0; i < currentInstr.nArguments; i++) {
391              argValues[i] = Evaluate(dataset, ref row, state);
392            }
393            // push on argument values on stack
394            state.CreateStackFrame(argValues);
395
396            // save the pc
397            int savedPc = state.ProgramCounter;
398            // set pc to start of function 
399            state.ProgramCounter = currentInstr.iArg0;
400            // evaluate the function
401            double v = Evaluate(dataset, ref row, state);
402
403            // delete the stack frame
404            state.RemoveStackFrame();
405
406            // restore the pc => evaluation will continue at point after my subtrees 
407            state.ProgramCounter = savedPc;
408            return v;
409          }
410        case OpCodes.Arg: {
411            return state.GetStackFrameValue(currentInstr.iArg0);
412          }
413        case OpCodes.Variable: {
414            return dataset[row, currentInstr.iArg0] * currentInstr.dArg0;
415          }
416        case OpCodes.LagVariable: {
417            var laggedVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
418            int actualRow = row + laggedVariableTreeNode.Lag;
419            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + row);
420            return dataset[actualRow, currentInstr.iArg0] * currentInstr.dArg0;
421          }
422        case OpCodes.Constant: {
423            return currentInstr.dArg0;
424          }
425
426        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
427        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
428        case OpCodes.VariableCondition: {
429            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
430            double variableValue = dataset[row, currentInstr.iArg0];
431            double x = variableValue - currentInstr.dArg0;
432            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
433
434            double trueBranch = Evaluate(dataset, ref row, state);
435            double falseBranch = Evaluate(dataset, ref row, state);
436
437            return trueBranch * p + falseBranch * (1 - p);
438          }
439        default: throw new NotSupportedException();
440      }
441    }
442
443    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
444      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
445        return symbolToOpcode[treeNode.Symbol.GetType()];
446      else
447        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
448    }
449
450    // skips a whole branch
451    private void SkipInstructions(InterpreterState state) {
452      int i = 1;
453      while (i > 0) {
454        i += state.NextInstruction().nArguments;
455        i--;
456      }
457    }
458  }
459}
Note: See TracBrowser for help on using the repository browser.