Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeILEmittingInterpreter.cs @ 6740

Last change on this file since 6740 was 6740, checked in by mkommend, 13 years ago

#1597, #1609, #1640:

  • Corrected TableFileParser to handle empty rows correctly.
  • Refactored DataSet to store values in List<List> instead of a two-dimensional array.
  • Enable importing and storing string and datetime values.
  • Changed data access methods in dataset and adapted all concerning classes.
  • Changed interpreter to store the variable values for all rows during the compilation step.
File size: 21.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Reflection;
25using System.Reflection.Emit;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that translates a symbolic expression tree into IL (via <see cref="DynamicMethod"/>)
  /// and evaluates the resulting delegate once per requested row.
  /// </summary>
  /// <remarks>
  /// Only a subset of the symbols is compiled. Root, TimeLag, Integral, Derivative, InvokeFunction (Call),
  /// Argument, LaggedVariable and VariableCondition throw <see cref="NotImplementedException"/> during compilation.
  /// </remarks>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.")]
  public sealed class SymbolicDataAnalysisExpressionTreeILEmittingInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    // Getter of IList<double>[int]; the emitted code calls it to read the value of a variable in the current row.
    private static readonly MethodInfo listGetValue = typeof(IList<double>).GetProperty("Item", new Type[] { typeof(int) }).GetGetMethod();
    private static readonly MethodInfo cos = typeof(Math).GetMethod("Cos", new Type[] { typeof(double) });
    private static readonly MethodInfo sin = typeof(Math).GetMethod("Sin", new Type[] { typeof(double) });
    private static readonly MethodInfo tan = typeof(Math).GetMethod("Tan", new Type[] { typeof(double) });
    private static readonly MethodInfo exp = typeof(Math).GetMethod("Exp", new Type[] { typeof(double) });
    private static readonly MethodInfo log = typeof(Math).GetMethod("Log", new Type[] { typeof(double) });
    private static readonly MethodInfo sign = typeof(Math).GetMethod("Sign", new Type[] { typeof(double) });
    private static readonly MethodInfo power = typeof(Math).GetMethod("Pow", new Type[] { typeof(double), typeof(double) });
    private static readonly MethodInfo sqrt = typeof(Math).GetMethod("Sqrt", new Type[] { typeof(double) });
    private static readonly MethodInfo isNan = typeof(double).GetMethod("IsNaN", new Type[] { typeof(double) });
    private static readonly MethodInfo abs = typeof(Math).GetMethod("Abs", new Type[] { typeof(double) });
    private const double EPSILON = 1.0E-6;

    // Original signature of the compiled function. This interpreter no longer emits it (variable values are now
    // passed as pre-fetched columns, see CompiledColumnFunction); it is kept only because it is internal and
    // could be referenced elsewhere in the assembly.
    internal delegate double CompiledFunction(Dataset dataset, int sampleIndex);
    // Signature of the emitted method: evaluates the tree for one row, reading variable values from `columns`.
    internal delegate double CompiledColumnFunction(int sampleIndex, IList<double>[] columns);

    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    #region private classes
    // Cursor over the linearized instruction array produced by SymbolicExpressionTreeCompiler,
    // plus an argument stack sized for Call instructions (Call itself is not compiled yet).
    private class InterpreterState {
      private double[] argumentStack;
      private int argumentStackPointer;
      private Instruction[] code;
      private int pc;
      public int ProgramCounter {
        get { return pc; }
        set { pc = value; }
      }
      internal InterpreterState(Instruction[] code, int argumentStackSize) {
        this.code = code;
        this.pc = 0;
        if (argumentStackSize > 0) {
          this.argumentStack = new double[argumentStackSize];
        }
        this.argumentStackPointer = 0;
      }

      internal void Reset() {
        this.pc = 0;
        this.argumentStackPointer = 0;
      }

      // Returns the instruction at the program counter and advances it.
      internal Instruction NextInstruction() {
        return code[pc++];
      }
      private void Push(double val) {
        argumentStack[argumentStackPointer++] = val;
      }
      private double Pop() {
        return argumentStack[--argumentStackPointer];
      }

      internal void CreateStackFrame(double[] argValues) {
        // push in reverse order to make indexing easier
        for (int i = argValues.Length - 1; i >= 0; i--) {
          argumentStack[argumentStackPointer++] = argValues[i];
        }
        Push(argValues.Length);
      }

      internal void RemoveStackFrame() {
        int size = (int)Pop();
        argumentStackPointer -= size;
      }

      internal double GetStackFrameValue(ushort index) {
        // layout of stack:
        // [0]   <- argumentStackPointer
        // [StackFrameSize = N + 1]
        // [Arg0] <- argumentStackPointer - 2 - 0
        // [Arg1] <- argumentStackPointer - 2 - 1
        // [...]
        // [ArgN] <- argumentStackPointer - 2 - N
        // <Begin of stack frame>
        return argumentStack[argumentStackPointer - index - 2];
      }
    }
    // Opcodes assigned to symbols by MapSymbolToOpCode (distinct from System.Reflection.Emit.OpCodes).
    private class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;


      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;

      public const byte Power = 22;
      public const byte Root = 23;
      public const byte TimeLag = 24;
      public const byte Integral = 25;
      public const byte Derivative = 26;

      public const byte VariableCondition = 27;
    }
    #endregion

    // Maps each supported symbol type to its opcode; shared by all instances because it never changes.
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT},
      { typeof(Average), OpCodes.Average},
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
      { typeof(Power),OpCodes.Power},
      { typeof(Root),OpCodes.Root},
      { typeof(TimeLag), OpCodes.TimeLag},
      { typeof(Integral), OpCodes.Integral},
      { typeof(Derivative), OpCodes.Derivative},
      { typeof(VariableCondition),OpCodes.VariableCondition}
    };

    public override bool CanChangeName {
      get { return false; }
    }
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Interval-arithmetic checking is not supported by this interpreter; when set to true,
    /// enumerating <see cref="GetSymbolicExpressionTreeValues"/> throws <see cref="NotSupportedException"/>.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }
    #endregion


    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(bool deserializing) : base(deserializing) { }
    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeILEmittingInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> to IL and lazily yields the tree's value for each row in <paramref name="rows"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">Source of the variable values referenced by the tree.</param>
    /// <param name="rows">Row indices to evaluate, in the order the results are yielded.</param>
    /// <returns>One value per row. This is an iterator, so compilation happens on first enumeration.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown on enumeration when interval arithmetic checking is enabled or the tree contains an unmapped symbol.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);

      // Fetch every referenced variable column once, before compilation. The emitted code receives the
      // columns as an array argument and addresses them by the index recorded in columnIndexes.
      var columnIndexes = new Dictionary<string, int>();
      var columns = new List<IList<double>>();
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        string variableName = null;
        if (instr.opCode == OpCodes.Variable) {
          variableName = ((VariableTreeNode)instr.dynamicNode).VariableName;
        } else if (instr.opCode == OpCodes.LagVariable) {
          variableName = ((LaggedVariableTreeNode)instr.dynamicNode).VariableName;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          variableName = ((VariableConditionTreeNode)instr.dynamicNode).VariableName;
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
        if (variableName != null && !columnIndexes.ContainsKey(variableName)) {
          columnIndexes.Add(variableName, columns.Count);
          columns.Add(dataset.GetReadOnlyDoubleValues(variableName));
        }
      }
      var state = new InterpreterState(code, necessaryArgStackSize);

      // arg 0: sampleIndex (row), arg 1: pre-fetched variable columns
      Type[] methodArgs = { typeof(int), typeof(IList<double>[]) };
      DynamicMethod testFun = new DynamicMethod("TestFun", typeof(double), methodArgs, typeof(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter).Module);

      ILGenerator il = testFun.GetILGenerator();
      CompileInstructions(il, state, columnIndexes);
      il.Emit(System.Reflection.Emit.OpCodes.Ret);
      var function = (CompiledColumnFunction)testFun.CreateDelegate(typeof(CompiledColumnFunction));

      IList<double>[] columnArray = columns.ToArray();
      foreach (var row in rows) {
        yield return function(row, columnArray);
      }
    }

    /// <summary>
    /// Emits IL for the instruction at the current program counter and, recursively, for all of its arguments.
    /// The emitted code pushes the subtree's value as a float64; every arithmetic and comparison operand
    /// is therefore a float64, because IL does not allow mixing int32 and float operands.
    /// </summary>
    /// <param name="il">Generator of the dynamic method being built.</param>
    /// <param name="state">Cursor over the compiled instructions; advanced past the whole subtree.</param>
    /// <param name="columnIndexes">Index of each variable's column in the columns array argument (arg 1).</param>
    private void CompileInstructions(ILGenerator il, InterpreterState state, IDictionary<string, int> columnIndexes) {
      Instruction currentInstr = state.NextInstruction();
      int nArgs = currentInstr.nArguments;

      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            if (nArgs > 0) {
              CompileInstructions(il, state, columnIndexes);
            }
            for (int i = 1; i < nArgs; i++) {
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Add);
            }
            return;
          }
        case OpCodes.Sub: {
            if (nArgs == 1) {
              // unary minus
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Neg);
              return;
            }
            if (nArgs > 0) {
              CompileInstructions(il, state, columnIndexes);
            }
            for (int i = 1; i < nArgs; i++) {
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Sub);
            }
            return;
          }
        case OpCodes.Mul: {
            if (nArgs > 0) {
              CompileInstructions(il, state, columnIndexes);
            }
            for (int i = 1; i < nArgs; i++) {
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Mul);
            }
            return;
          }
        case OpCodes.Div: {
            if (nArgs == 1) {
              // unary division is the reciprocal 1 / x
              il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0);
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Div);
              return;
            }
            if (nArgs > 0) {
              CompileInstructions(il, state, columnIndexes);
            }
            for (int i = 1; i < nArgs; i++) {
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Div);
            }
            return;
          }
        case OpCodes.Average: {
            CompileInstructions(il, state, columnIndexes);
            for (int i = 1; i < nArgs; i++) {
              CompileInstructions(il, state, columnIndexes);
              il.Emit(System.Reflection.Emit.OpCodes.Add);
            }
            // the divisor must be a float64 to match the float64 sum
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, (double)nArgs);
            il.Emit(System.Reflection.Emit.OpCodes.Div);
            return;
          }
        case OpCodes.Cos: {
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, cos);
            return;
          }
        case OpCodes.Sin: {
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, sin);
            return;
          }
        case OpCodes.Tan: {
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, tan);
            return;
          }
        case OpCodes.Power: {
            CompileInstructions(il, state, columnIndexes);
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, power);
            return;
          }
        case OpCodes.Root: {
            throw new NotImplementedException();
          }
        case OpCodes.Exp: {
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, exp);
            return;
          }
        case OpCodes.Log: {
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Call, log);
            return;
          }
        case OpCodes.IfThenElse: {
            Label end = il.DefineLabel();
            Label elseBranch = il.DefineLabel();
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // condition > 0 ?
            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, elseBranch);
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
            il.MarkLabel(elseBranch);
            CompileInstructions(il, state, columnIndexes);
            il.MarkLabel(end);
            return;
          }
        case OpCodes.AND: {
            // short-circuit: the first argument that is not > 0 jumps to the false result (-1)
            Label falseBranch = il.DefineLabel();
            Label end = il.DefineLabel();
            CompileInstructions(il, state, columnIndexes);
            for (int i = 1; i < nArgs; i++) {
              il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // > 0 ?
              il.Emit(System.Reflection.Emit.OpCodes.Cgt);
              il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
              CompileInstructions(il, state, columnIndexes);
            }
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // > 0 ?
            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0);
            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
            il.MarkLabel(falseBranch);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, -1.0);
            il.MarkLabel(end);
            return;
          }
        case OpCodes.OR: {
            Label trueBranch = il.DefineLabel();
            Label end = il.DefineLabel();
            Label resultBranch = il.DefineLabel();
            CompileInstructions(il, state, columnIndexes);
            for (int i = 1; i < nArgs; i++) {
              Label nextArgBranch = il.DefineLabel();
              // Only a value <= 0 continues with the next argument. For floats, ble does not branch when a
              // value is NaN (unordered), so NaN stops evaluation and is mapped to -1 by the final "> 0" test.
              il.Emit(System.Reflection.Emit.OpCodes.Dup);
              il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // <= 0 ?
              il.Emit(System.Reflection.Emit.OpCodes.Ble, nextArgBranch);
              il.Emit(System.Reflection.Emit.OpCodes.Br, resultBranch);
              il.MarkLabel(nextArgBranch);
              il.Emit(System.Reflection.Emit.OpCodes.Pop);
              CompileInstructions(il, state, columnIndexes);
            }
            il.MarkLabel(resultBranch);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // > 0 ?
            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
            il.Emit(System.Reflection.Emit.OpCodes.Brtrue, trueBranch);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, -1.0);
            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
            il.MarkLabel(trueBranch);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0);
            il.MarkLabel(end);
            return;
          }
        case OpCodes.NOT: {
            // -(2 * (x > 0) - 1): -1 if x > 0, otherwise 1
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 0.0); // > 0
            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
            il.Emit(System.Reflection.Emit.OpCodes.Conv_R8); // int32 0/1 -> float64 before arithmetic
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
            il.Emit(System.Reflection.Emit.OpCodes.Mul);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
            il.Emit(System.Reflection.Emit.OpCodes.Sub);
            il.Emit(System.Reflection.Emit.OpCodes.Neg); // * -1
            return;
          }
        case OpCodes.GT: {
            // 2 * (x > y) - 1: 1 if x > y, otherwise -1
            CompileInstructions(il, state, columnIndexes);
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Cgt); // 1 (>) / 0 (otherwise)
            il.Emit(System.Reflection.Emit.OpCodes.Conv_R8); // int32 0/1 -> float64 before arithmetic
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
            il.Emit(System.Reflection.Emit.OpCodes.Mul);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
            il.Emit(System.Reflection.Emit.OpCodes.Sub);
            return;
          }
        case OpCodes.LT: {
            // 2 * (x < y) - 1: 1 if x < y, otherwise -1
            CompileInstructions(il, state, columnIndexes);
            CompileInstructions(il, state, columnIndexes);
            il.Emit(System.Reflection.Emit.OpCodes.Clt); // 1 (<) / 0 (otherwise)
            il.Emit(System.Reflection.Emit.OpCodes.Conv_R8); // int32 0/1 -> float64 before arithmetic
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
            il.Emit(System.Reflection.Emit.OpCodes.Mul);
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
            il.Emit(System.Reflection.Emit.OpCodes.Sub);
            return;
          }
        case OpCodes.TimeLag: {
            throw new NotImplementedException();
          }
        case OpCodes.Integral: {
            throw new NotImplementedException();
          }

        //mkommend: derivative calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one sided smooth differentiator, N = 4
        // y' = 1/(8h) (f_i + 2 f_i-1 - 2 f_i-3 - f_i-4)
        case OpCodes.Derivative: {
            throw new NotImplementedException();
          }
        case OpCodes.Call: {
            throw new NotImplementedException();
          }
        case OpCodes.Arg: {
            throw new NotImplementedException();
          }
        case OpCodes.Variable: {
            // columns[columnIndex][sampleIndex] * weight
            VariableTreeNode varNode = (VariableTreeNode)currentInstr.dynamicNode;
            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_1); // columns array
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, columnIndexes[varNode.VariableName]);
            il.Emit(System.Reflection.Emit.OpCodes.Ldelem_Ref); // column of this variable
            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_0); // sampleIndex
            il.Emit(System.Reflection.Emit.OpCodes.Callvirt, listGetValue); // interface call: column[sampleIndex]
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, varNode.Weight);
            il.Emit(System.Reflection.Emit.OpCodes.Mul);
            return;
          }
        case OpCodes.LagVariable: {
            throw new NotImplementedException();
          }
        case OpCodes.Constant: {
            ConstantTreeNode constNode = (ConstantTreeNode)currentInstr.dynamicNode;
            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, constNode.Value);
            return;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            throw new NotImplementedException();
          }
        default: throw new NotSupportedException();
      }
    }

    /// <summary>Returns the opcode for the node's symbol type.</summary>
    /// <exception cref="NotSupportedException">The symbol type has no opcode mapping.</exception>
    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
      byte opCode;
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
        return opCode;
      throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    // Skips a whole branch: advances the program counter past the current instruction and all its descendants.
    private void SkipInstructions(InterpreterState state) {
      int i = 1;
      while (i > 0) {
        i += state.NextInstruction().nArguments;
        i--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.