Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeILEmittingInterpreter.cs @ 6741

Last change on this file since 6741 was 6741, checked in by gkronber, 13 years ago

#1640 adapted IL emitting interpreter to work with previous (r6740) changes of the dataset.

File size: 21.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.ObjectModel;
24using System.Linq;
25using System.Collections.Generic;
26using System.Reflection;
27using System.Reflection.Emit;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
36  [StorableClass]
37  [Item("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.")]
38  public sealed class SymbolicDataAnalysisExpressionTreeILEmittingInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
39    private static MethodInfo listGetValue = typeof(IList<double>).GetProperty("Item", new Type[] { typeof(int) }).GetGetMethod();
40    private static MethodInfo cos = typeof(Math).GetMethod("Cos", new Type[] { typeof(double) });
41    private static MethodInfo sin = typeof(Math).GetMethod("Sin", new Type[] { typeof(double) });
42    private static MethodInfo tan = typeof(Math).GetMethod("Tan", new Type[] { typeof(double) });
43    private static MethodInfo exp = typeof(Math).GetMethod("Exp", new Type[] { typeof(double) });
44    private static MethodInfo log = typeof(Math).GetMethod("Log", new Type[] { typeof(double) });
45    private static MethodInfo power = typeof(Math).GetMethod("Pow", new Type[] { typeof(double), typeof(double) });
46
47    internal delegate double CompiledFunction(int sampleIndex, IList<double>[] columns);
48    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
49    #region private classes
50    private class InterpreterState {
51      private double[] argumentStack;
52      private int argumentStackPointer;
53      private Instruction[] code;
54      private int pc;
55      public int ProgramCounter {
56        get { return pc; }
57        set { pc = value; }
58      }
59      internal InterpreterState(Instruction[] code, int argumentStackSize) {
60        this.code = code;
61        this.pc = 0;
62        if (argumentStackSize > 0) {
63          this.argumentStack = new double[argumentStackSize];
64        }
65        this.argumentStackPointer = 0;
66      }
67
68      internal void Reset() {
69        this.pc = 0;
70        this.argumentStackPointer = 0;
71      }
72
73      internal Instruction NextInstruction() {
74        return code[pc++];
75      }
76      private void Push(double val) {
77        argumentStack[argumentStackPointer++] = val;
78      }
79      private double Pop() {
80        return argumentStack[--argumentStackPointer];
81      }
82
83      internal void CreateStackFrame(double[] argValues) {
84        // push in reverse order to make indexing easier
85        for (int i = argValues.Length - 1; i >= 0; i--) {
86          argumentStack[argumentStackPointer++] = argValues[i];
87        }
88        Push(argValues.Length);
89      }
90
91      internal void RemoveStackFrame() {
92        int size = (int)Pop();
93        argumentStackPointer -= size;
94      }
95
96      internal double GetStackFrameValue(ushort index) {
97        // layout of stack:
98        // [0]   <- argumentStackPointer
99        // [StackFrameSize = N + 1]
100        // [Arg0] <- argumentStackPointer - 2 - 0
101        // [Arg1] <- argumentStackPointer - 2 - 1
102        // [...]
103        // [ArgN] <- argumentStackPointer - 2 - N
104        // <Begin of stack frame>
105        return argumentStack[argumentStackPointer - index - 2];
106      }
107    }
108    private class OpCodes {
109      public const byte Add = 1;
110      public const byte Sub = 2;
111      public const byte Mul = 3;
112      public const byte Div = 4;
113
114      public const byte Sin = 5;
115      public const byte Cos = 6;
116      public const byte Tan = 7;
117
118      public const byte Log = 8;
119      public const byte Exp = 9;
120
121      public const byte IfThenElse = 10;
122
123      public const byte GT = 11;
124      public const byte LT = 12;
125
126      public const byte AND = 13;
127      public const byte OR = 14;
128      public const byte NOT = 15;
129
130
131      public const byte Average = 16;
132
133      public const byte Call = 17;
134
135      public const byte Variable = 18;
136      public const byte LagVariable = 19;
137      public const byte Constant = 20;
138      public const byte Arg = 21;
139
140      public const byte Power = 22;
141      public const byte Root = 23;
142      public const byte TimeLag = 24;
143      public const byte Integral = 25;
144      public const byte Derivative = 26;
145
146      public const byte VariableCondition = 27;
147    }
148    #endregion
149
150    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
151      { typeof(Addition), OpCodes.Add },
152      { typeof(Subtraction), OpCodes.Sub },
153      { typeof(Multiplication), OpCodes.Mul },
154      { typeof(Division), OpCodes.Div },
155      { typeof(Sine), OpCodes.Sin },
156      { typeof(Cosine), OpCodes.Cos },
157      { typeof(Tangent), OpCodes.Tan },
158      { typeof(Logarithm), OpCodes.Log },
159      { typeof(Exponential), OpCodes.Exp },
160      { typeof(IfThenElse), OpCodes.IfThenElse },
161      { typeof(GreaterThan), OpCodes.GT },
162      { typeof(LessThan), OpCodes.LT },
163      { typeof(And), OpCodes.AND },
164      { typeof(Or), OpCodes.OR },
165      { typeof(Not), OpCodes.NOT},
166      { typeof(Average), OpCodes.Average},
167      { typeof(InvokeFunction), OpCodes.Call },
168      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
169      { typeof(LaggedVariable), OpCodes.LagVariable },
170      { typeof(Constant), OpCodes.Constant },
171      { typeof(Argument), OpCodes.Arg },
172      { typeof(Power),OpCodes.Power},
173      { typeof(Root),OpCodes.Root},
174      { typeof(TimeLag), OpCodes.TimeLag},
175      { typeof(Integral), OpCodes.Integral},
176      { typeof(Derivative), OpCodes.Derivative},
177      { typeof(VariableCondition),OpCodes.VariableCondition}
178    };
179
180    public override bool CanChangeName {
181      get { return false; }
182    }
183    public override bool CanChangeDescription {
184      get { return false; }
185    }
186
187    #region parameter properties
188    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
189      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
190    }
191    #endregion
192
193    #region properties
194    public BoolValue CheckExpressionsWithIntervalArithmetic {
195      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
196      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
197    }
198    #endregion
199
200
201    [StorableConstructor]
202    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(bool deserializing) : base(deserializing) { }
203    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter original, Cloner cloner) : base(original, cloner) { }
204    public override IDeepCloneable Clone(Cloner cloner) {
205      return new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(this, cloner);
206    }
207
208    public SymbolicDataAnalysisExpressionTreeILEmittingInterpreter()
209      : base("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.") {
210      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
211    }
212
213    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
214      if (CheckExpressionsWithIntervalArithmetic.Value)
215        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
216      var compiler = new SymbolicExpressionTreeCompiler();
217      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
218      int necessaryArgStackSize = 0;
219
220      Dictionary<string, int> doubleVariableNames = dataset.DoubleVariables.Select((x, i) => new { x, i }).ToDictionary(e => e.x, e => e.i);
221      IList<double>[] columns = (from v in doubleVariableNames.Keys
222                                 select dataset.GetReadOnlyDoubleValues(v))
223                                .ToArray();
224
225      for (int i = 0; i < code.Length; i++) {
226        Instruction instr = code[i];
227        if (instr.opCode == OpCodes.Variable) {
228          var variableTreeNode = instr.dynamicNode as VariableTreeNode;
229          instr.iArg0 = doubleVariableNames[variableTreeNode.VariableName];
230          code[i] = instr;
231        } else if (instr.opCode == OpCodes.LagVariable) {
232          var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
233          instr.iArg0 = doubleVariableNames[variableTreeNode.VariableName];
234          code[i] = instr;
235        } else if (instr.opCode == OpCodes.VariableCondition) {
236          var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
237          instr.iArg0 = doubleVariableNames[variableConditionTreeNode.VariableName];
238        } else if (instr.opCode == OpCodes.Call) {
239          necessaryArgStackSize += instr.nArguments + 1;
240        }
241      }
242      var state = new InterpreterState(code, necessaryArgStackSize);
243
244      Type[] methodArgs = { typeof(int), typeof(IList<double>[]) };
245      DynamicMethod testFun = new DynamicMethod("TestFun", typeof(double), methodArgs, typeof(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter).Module);
246
247      ILGenerator il = testFun.GetILGenerator();
248      CompileInstructions(il, state);
249      il.Emit(System.Reflection.Emit.OpCodes.Conv_R8);
250      il.Emit(System.Reflection.Emit.OpCodes.Ret);
251      var function = (CompiledFunction)testFun.CreateDelegate(typeof(CompiledFunction));
252
253      foreach (var row in rows) {
254        yield return function(row, columns);
255      }
256    }
257
258    private void CompileInstructions(ILGenerator il, InterpreterState state) {
259      Instruction currentInstr = state.NextInstruction();
260      int nArgs = currentInstr.nArguments;
261
262      switch (currentInstr.opCode) {
263        case OpCodes.Add: {
264            if (nArgs > 0) {
265              CompileInstructions(il, state);
266            }
267            for (int i = 1; i < nArgs; i++) {
268              CompileInstructions(il, state);
269              il.Emit(System.Reflection.Emit.OpCodes.Add);
270            }
271            return;
272          }
273        case OpCodes.Sub: {
274            if (nArgs == 1) {
275              CompileInstructions(il, state);
276              il.Emit(System.Reflection.Emit.OpCodes.Neg);
277              return;
278            }
279            if (nArgs > 0) {
280              CompileInstructions(il, state);
281            }
282            for (int i = 1; i < nArgs; i++) {
283              CompileInstructions(il, state);
284              il.Emit(System.Reflection.Emit.OpCodes.Sub);
285            }
286            return;
287          }
288        case OpCodes.Mul: {
289            if (nArgs > 0) {
290              CompileInstructions(il, state);
291            }
292            for (int i = 1; i < nArgs; i++) {
293              CompileInstructions(il, state);
294              il.Emit(System.Reflection.Emit.OpCodes.Mul);
295            }
296            return;
297          }
298        case OpCodes.Div: {
299            if (nArgs == 1) {
300              il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0);
301              CompileInstructions(il, state);
302              il.Emit(System.Reflection.Emit.OpCodes.Div);
303              return;
304            }
305            if (nArgs > 0) {
306              CompileInstructions(il, state);
307            }
308            for (int i = 1; i < nArgs; i++) {
309              CompileInstructions(il, state);
310              il.Emit(System.Reflection.Emit.OpCodes.Div);
311            }
312            return;
313          }
314        case OpCodes.Average: {
315            CompileInstructions(il, state);
316            for (int i = 1; i < nArgs; i++) {
317              CompileInstructions(il, state);
318              il.Emit(System.Reflection.Emit.OpCodes.Add);
319            }
320            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, nArgs);
321            il.Emit(System.Reflection.Emit.OpCodes.Div);
322            return;
323          }
324        case OpCodes.Cos: {
325            CompileInstructions(il, state);
326            il.Emit(System.Reflection.Emit.OpCodes.Call, cos);
327            return;
328          }
329        case OpCodes.Sin: {
330            CompileInstructions(il, state);
331            il.Emit(System.Reflection.Emit.OpCodes.Call, sin);
332            return;
333          }
334        case OpCodes.Tan: {
335            CompileInstructions(il, state);
336            il.Emit(System.Reflection.Emit.OpCodes.Call, tan);
337            return;
338          }
339        case OpCodes.Power: {
340            CompileInstructions(il, state);
341            CompileInstructions(il, state);
342            il.Emit(System.Reflection.Emit.OpCodes.Call, power);
343            return;
344          }
345        case OpCodes.Root: {
346            throw new NotImplementedException();
347          }
348        case OpCodes.Exp: {
349            CompileInstructions(il, state);
350            il.Emit(System.Reflection.Emit.OpCodes.Call, exp);
351            return;
352          }
353        case OpCodes.Log: {
354            CompileInstructions(il, state);
355            il.Emit(System.Reflection.Emit.OpCodes.Call, log);
356            return;
357          }
358        case OpCodes.IfThenElse: {
359            Label end = il.DefineLabel();
360            Label c1 = il.DefineLabel();
361            CompileInstructions(il, state);
362            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
363            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
364            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, c1);
365            CompileInstructions(il, state);
366            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
367            il.MarkLabel(c1);
368            CompileInstructions(il, state);
369            il.MarkLabel(end);
370            return;
371          }
372        case OpCodes.AND: {
373            Label falseBranch = il.DefineLabel();
374            Label end = il.DefineLabel();
375            CompileInstructions(il, state);
376            for (int i = 1; i < nArgs; i++) {
377              il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
378              il.Emit(System.Reflection.Emit.OpCodes.Cgt);
379              il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
380              CompileInstructions(il, state);
381            }
382            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
383            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
384            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
385            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // 1
386            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
387            il.MarkLabel(falseBranch);
388            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // -1
389            il.Emit(System.Reflection.Emit.OpCodes.Neg);
390            il.MarkLabel(end);
391            return;
392          }
393        case OpCodes.OR: {
394            Label trueBranch = il.DefineLabel();
395            Label end = il.DefineLabel();
396            Label resultBranch = il.DefineLabel();
397            CompileInstructions(il, state);
398            for (int i = 1; i < nArgs; i++) {
399              Label nextArgBranch = il.DefineLabel();
400              // complex definition because of special properties of NaN 
401              il.Emit(System.Reflection.Emit.OpCodes.Dup);
402              il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // <= 0       
403              il.Emit(System.Reflection.Emit.OpCodes.Ble, nextArgBranch);
404              il.Emit(System.Reflection.Emit.OpCodes.Br, resultBranch);
405              il.MarkLabel(nextArgBranch);
406              il.Emit(System.Reflection.Emit.OpCodes.Pop);
407              CompileInstructions(il, state);
408            }
409            il.MarkLabel(resultBranch);
410            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
411            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
412            il.Emit(System.Reflection.Emit.OpCodes.Brtrue, trueBranch);
413            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // -1
414            il.Emit(System.Reflection.Emit.OpCodes.Neg);
415            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
416            il.MarkLabel(trueBranch);
417            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // 1
418            il.MarkLabel(end);
419            return;
420          }
421        case OpCodes.NOT: {
422            CompileInstructions(il, state);
423            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
424            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
425            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
426            il.Emit(System.Reflection.Emit.OpCodes.Mul);
427            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
428            il.Emit(System.Reflection.Emit.OpCodes.Sub);
429            il.Emit(System.Reflection.Emit.OpCodes.Neg); // * -1
430            return;
431          }
432        case OpCodes.GT: {
433            CompileInstructions(il, state);
434            CompileInstructions(il, state);
435            il.Emit(System.Reflection.Emit.OpCodes.Cgt); // 1 (>) / 0 (otherwise)
436            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
437            il.Emit(System.Reflection.Emit.OpCodes.Mul);
438            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
439            il.Emit(System.Reflection.Emit.OpCodes.Sub);
440            return;
441          }
442        case OpCodes.LT: {
443            CompileInstructions(il, state);
444            CompileInstructions(il, state);
445            il.Emit(System.Reflection.Emit.OpCodes.Clt);
446            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
447            il.Emit(System.Reflection.Emit.OpCodes.Mul);
448            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
449            il.Emit(System.Reflection.Emit.OpCodes.Sub);
450            return;
451          }
452        case OpCodes.TimeLag: {
453            throw new NotImplementedException();
454          }
455        case OpCodes.Integral: {
456            throw new NotImplementedException();
457          }
458
459        //mkommend: derivate calculation taken from:
460        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
461        //one sided smooth differentiatior, N = 4
462        // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
463        case OpCodes.Derivative: {
464            throw new NotImplementedException();
465          }
466        case OpCodes.Call: {
467            throw new NotImplementedException();
468          }
469        case OpCodes.Arg: {
470            throw new NotImplementedException();
471          }
472        case OpCodes.Variable: {
473            VariableTreeNode varNode = (VariableTreeNode)currentInstr.dynamicNode;
474            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_1); // load columns array
475            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, (int)currentInstr.iArg0); // load correct column of the current variable
476            il.Emit(System.Reflection.Emit.OpCodes.Ldelem_Ref);
477            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, 0); // sampleOffset
478            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_0); // sampleIndex
479            il.Emit(System.Reflection.Emit.OpCodes.Add); // row = sampleIndex + sampleOffset
480            il.Emit(System.Reflection.Emit.OpCodes.Call, listGetValue);
481            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, varNode.Weight); // load weight
482            il.Emit(System.Reflection.Emit.OpCodes.Mul);
483            return;
484          }
485        case OpCodes.LagVariable: {
486            throw new NotImplementedException();
487          }
488        case OpCodes.Constant: {
489            ConstantTreeNode constNode = (ConstantTreeNode)currentInstr.dynamicNode;
490            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, constNode.Value);
491            return;
492          }
493
494        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
495        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
496        case OpCodes.VariableCondition: {
497            throw new NotImplementedException();
498          }
499        default: throw new NotSupportedException();
500      }
501    }
502
503    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
504      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
505        return symbolToOpcode[treeNode.Symbol.GetType()];
506      else
507        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
508    }
509
510    // skips a whole branch
511    private void SkipInstructions(InterpreterState state) {
512      int i = 1;
513      while (i > 0) {
514        i += state.NextInstruction().nArguments;
515        i--;
516      }
517    }
518  }
519}
Note: See TracBrowser for help on using the repository browser.