Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionTreeILEmittingInterpreter.cs @ 6755

Last change on this file since 6755 was 6755, checked in by gkronber, 13 years ago

#1480 implemented code to handle root symbols for the il emitting interpreter and fixed code for power symbol.

File size: 22.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2011 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.ObjectModel;
24using System.Linq;
25using System.Collections.Generic;
26using System.Reflection;
27using System.Reflection.Emit;
28using HeuristicLab.Common;
29using HeuristicLab.Core;
30using HeuristicLab.Data;
31using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
32using HeuristicLab.Parameters;
33using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
34
35namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
36  [StorableClass]
37  [Item("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.")]
38  public sealed class SymbolicDataAnalysisExpressionTreeILEmittingInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
39    private static MethodInfo listGetValue = typeof(IList<double>).GetProperty("Item", new Type[] { typeof(int) }).GetGetMethod();
40    private static MethodInfo cos = typeof(Math).GetMethod("Cos", new Type[] { typeof(double) });
41    private static MethodInfo sin = typeof(Math).GetMethod("Sin", new Type[] { typeof(double) });
42    private static MethodInfo tan = typeof(Math).GetMethod("Tan", new Type[] { typeof(double) });
43    private static MethodInfo exp = typeof(Math).GetMethod("Exp", new Type[] { typeof(double) });
44    private static MethodInfo log = typeof(Math).GetMethod("Log", new Type[] { typeof(double) });
45    private static MethodInfo power = typeof(Math).GetMethod("Pow", new Type[] { typeof(double), typeof(double) });
46    private static MethodInfo round = typeof(Math).GetMethod("Round", new Type[] { typeof(double) });
47
48    internal delegate double CompiledFunction(int sampleIndex, IList<double>[] columns);
49    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
50    #region private classes
51    private class InterpreterState {
52      private double[] argumentStack;
53      private int argumentStackPointer;
54      private Instruction[] code;
55      private int pc;
56      public int ProgramCounter {
57        get { return pc; }
58        set { pc = value; }
59      }
60      internal InterpreterState(Instruction[] code, int argumentStackSize) {
61        this.code = code;
62        this.pc = 0;
63        if (argumentStackSize > 0) {
64          this.argumentStack = new double[argumentStackSize];
65        }
66        this.argumentStackPointer = 0;
67      }
68
69      internal void Reset() {
70        this.pc = 0;
71        this.argumentStackPointer = 0;
72      }
73
74      internal Instruction NextInstruction() {
75        return code[pc++];
76      }
77      private void Push(double val) {
78        argumentStack[argumentStackPointer++] = val;
79      }
80      private double Pop() {
81        return argumentStack[--argumentStackPointer];
82      }
83
84      internal void CreateStackFrame(double[] argValues) {
85        // push in reverse order to make indexing easier
86        for (int i = argValues.Length - 1; i >= 0; i--) {
87          argumentStack[argumentStackPointer++] = argValues[i];
88        }
89        Push(argValues.Length);
90      }
91
92      internal void RemoveStackFrame() {
93        int size = (int)Pop();
94        argumentStackPointer -= size;
95      }
96
97      internal double GetStackFrameValue(ushort index) {
98        // layout of stack:
99        // [0]   <- argumentStackPointer
100        // [StackFrameSize = N + 1]
101        // [Arg0] <- argumentStackPointer - 2 - 0
102        // [Arg1] <- argumentStackPointer - 2 - 1
103        // [...]
104        // [ArgN] <- argumentStackPointer - 2 - N
105        // <Begin of stack frame>
106        return argumentStack[argumentStackPointer - index - 2];
107      }
108    }
109    private class OpCodes {
110      public const byte Add = 1;
111      public const byte Sub = 2;
112      public const byte Mul = 3;
113      public const byte Div = 4;
114
115      public const byte Sin = 5;
116      public const byte Cos = 6;
117      public const byte Tan = 7;
118
119      public const byte Log = 8;
120      public const byte Exp = 9;
121
122      public const byte IfThenElse = 10;
123
124      public const byte GT = 11;
125      public const byte LT = 12;
126
127      public const byte AND = 13;
128      public const byte OR = 14;
129      public const byte NOT = 15;
130
131
132      public const byte Average = 16;
133
134      public const byte Call = 17;
135
136      public const byte Variable = 18;
137      public const byte LagVariable = 19;
138      public const byte Constant = 20;
139      public const byte Arg = 21;
140
141      public const byte Power = 22;
142      public const byte Root = 23;
143      public const byte TimeLag = 24;
144      public const byte Integral = 25;
145      public const byte Derivative = 26;
146
147      public const byte VariableCondition = 27;
148    }
149    #endregion
150
151    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
152      { typeof(Addition), OpCodes.Add },
153      { typeof(Subtraction), OpCodes.Sub },
154      { typeof(Multiplication), OpCodes.Mul },
155      { typeof(Division), OpCodes.Div },
156      { typeof(Sine), OpCodes.Sin },
157      { typeof(Cosine), OpCodes.Cos },
158      { typeof(Tangent), OpCodes.Tan },
159      { typeof(Logarithm), OpCodes.Log },
160      { typeof(Exponential), OpCodes.Exp },
161      { typeof(IfThenElse), OpCodes.IfThenElse },
162      { typeof(GreaterThan), OpCodes.GT },
163      { typeof(LessThan), OpCodes.LT },
164      { typeof(And), OpCodes.AND },
165      { typeof(Or), OpCodes.OR },
166      { typeof(Not), OpCodes.NOT},
167      { typeof(Average), OpCodes.Average},
168      { typeof(InvokeFunction), OpCodes.Call },
169      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
170      { typeof(LaggedVariable), OpCodes.LagVariable },
171      { typeof(Constant), OpCodes.Constant },
172      { typeof(Argument), OpCodes.Arg },
173      { typeof(Power),OpCodes.Power},
174      { typeof(Root),OpCodes.Root},
175      { typeof(TimeLag), OpCodes.TimeLag},
176      { typeof(Integral), OpCodes.Integral},
177      { typeof(Derivative), OpCodes.Derivative},
178      { typeof(VariableCondition),OpCodes.VariableCondition}
179    };
180
181    public override bool CanChangeName {
182      get { return false; }
183    }
184    public override bool CanChangeDescription {
185      get { return false; }
186    }
187
188    #region parameter properties
189    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
190      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
191    }
192    #endregion
193
194    #region properties
195    public BoolValue CheckExpressionsWithIntervalArithmetic {
196      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
197      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
198    }
199    #endregion
200
201
202    [StorableConstructor]
203    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(bool deserializing) : base(deserializing) { }
204    private SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter original, Cloner cloner) : base(original, cloner) { }
205    public override IDeepCloneable Clone(Cloner cloner) {
206      return new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(this, cloner);
207    }
208
209    public SymbolicDataAnalysisExpressionTreeILEmittingInterpreter()
210      : base("SymbolicDataAnalysisExpressionTreeILEmittingInterpreter", "Interpreter for symbolic expression trees.") {
211      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
212    }
213
214    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
215      if (CheckExpressionsWithIntervalArithmetic.Value)
216        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
217      var compiler = new SymbolicExpressionTreeCompiler();
218      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
219      int necessaryArgStackSize = 0;
220
221      Dictionary<string, int> doubleVariableNames = dataset.DoubleVariables.Select((x, i) => new { x, i }).ToDictionary(e => e.x, e => e.i);
222      IList<double>[] columns = (from v in doubleVariableNames.Keys
223                                 select dataset.GetReadOnlyDoubleValues(v))
224                                .ToArray();
225
226      for (int i = 0; i < code.Length; i++) {
227        Instruction instr = code[i];
228        if (instr.opCode == OpCodes.Variable) {
229          var variableTreeNode = instr.dynamicNode as VariableTreeNode;
230          instr.iArg0 = doubleVariableNames[variableTreeNode.VariableName];
231          code[i] = instr;
232        } else if (instr.opCode == OpCodes.LagVariable) {
233          var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
234          instr.iArg0 = doubleVariableNames[variableTreeNode.VariableName];
235          code[i] = instr;
236        } else if (instr.opCode == OpCodes.VariableCondition) {
237          var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
238          instr.iArg0 = doubleVariableNames[variableConditionTreeNode.VariableName];
239        } else if (instr.opCode == OpCodes.Call) {
240          necessaryArgStackSize += instr.nArguments + 1;
241        }
242      }
243      var state = new InterpreterState(code, necessaryArgStackSize);
244
245      Type[] methodArgs = { typeof(int), typeof(IList<double>[]) };
246      DynamicMethod testFun = new DynamicMethod("TestFun", typeof(double), methodArgs, typeof(SymbolicDataAnalysisExpressionTreeILEmittingInterpreter).Module);
247
248      ILGenerator il = testFun.GetILGenerator();
249      CompileInstructions(il, state);
250      il.Emit(System.Reflection.Emit.OpCodes.Conv_R8);
251      il.Emit(System.Reflection.Emit.OpCodes.Ret);
252      var function = (CompiledFunction)testFun.CreateDelegate(typeof(CompiledFunction));
253
254      foreach (var row in rows) {
255        yield return function(row, columns);
256      }
257    }
258
259    private void CompileInstructions(ILGenerator il, InterpreterState state) {
260      Instruction currentInstr = state.NextInstruction();
261      int nArgs = currentInstr.nArguments;
262
263      switch (currentInstr.opCode) {
264        case OpCodes.Add: {
265            if (nArgs > 0) {
266              CompileInstructions(il, state);
267            }
268            for (int i = 1; i < nArgs; i++) {
269              CompileInstructions(il, state);
270              il.Emit(System.Reflection.Emit.OpCodes.Add);
271            }
272            return;
273          }
274        case OpCodes.Sub: {
275            if (nArgs == 1) {
276              CompileInstructions(il, state);
277              il.Emit(System.Reflection.Emit.OpCodes.Neg);
278              return;
279            }
280            if (nArgs > 0) {
281              CompileInstructions(il, state);
282            }
283            for (int i = 1; i < nArgs; i++) {
284              CompileInstructions(il, state);
285              il.Emit(System.Reflection.Emit.OpCodes.Sub);
286            }
287            return;
288          }
289        case OpCodes.Mul: {
290            if (nArgs > 0) {
291              CompileInstructions(il, state);
292            }
293            for (int i = 1; i < nArgs; i++) {
294              CompileInstructions(il, state);
295              il.Emit(System.Reflection.Emit.OpCodes.Mul);
296            }
297            return;
298          }
299        case OpCodes.Div: {
300            if (nArgs == 1) {
301              il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0);
302              CompileInstructions(il, state);
303              il.Emit(System.Reflection.Emit.OpCodes.Div);
304              return;
305            }
306            if (nArgs > 0) {
307              CompileInstructions(il, state);
308            }
309            for (int i = 1; i < nArgs; i++) {
310              CompileInstructions(il, state);
311              il.Emit(System.Reflection.Emit.OpCodes.Div);
312            }
313            return;
314          }
315        case OpCodes.Average: {
316            CompileInstructions(il, state);
317            for (int i = 1; i < nArgs; i++) {
318              CompileInstructions(il, state);
319              il.Emit(System.Reflection.Emit.OpCodes.Add);
320            }
321            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, nArgs);
322            il.Emit(System.Reflection.Emit.OpCodes.Div);
323            return;
324          }
325        case OpCodes.Cos: {
326            CompileInstructions(il, state);
327            il.Emit(System.Reflection.Emit.OpCodes.Call, cos);
328            return;
329          }
330        case OpCodes.Sin: {
331            CompileInstructions(il, state);
332            il.Emit(System.Reflection.Emit.OpCodes.Call, sin);
333            return;
334          }
335        case OpCodes.Tan: {
336            CompileInstructions(il, state);
337            il.Emit(System.Reflection.Emit.OpCodes.Call, tan);
338            return;
339          }
340        case OpCodes.Power: {
341            CompileInstructions(il, state);
342            CompileInstructions(il, state);
343            il.Emit(System.Reflection.Emit.OpCodes.Call, round);
344            il.Emit(System.Reflection.Emit.OpCodes.Call, power);
345            return;
346          }
347        case OpCodes.Root: {
348            CompileInstructions(il, state);
349            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // 1 / round(...)
350            CompileInstructions(il, state);
351            il.Emit(System.Reflection.Emit.OpCodes.Call, round);
352            il.Emit(System.Reflection.Emit.OpCodes.Div);
353            il.Emit(System.Reflection.Emit.OpCodes.Call, power);
354            return;
355          }
356        case OpCodes.Exp: {
357            CompileInstructions(il, state);
358            il.Emit(System.Reflection.Emit.OpCodes.Call, exp);
359            return;
360          }
361        case OpCodes.Log: {
362            CompileInstructions(il, state);
363            il.Emit(System.Reflection.Emit.OpCodes.Call, log);
364            return;
365          }
366        case OpCodes.IfThenElse: {
367            Label end = il.DefineLabel();
368            Label c1 = il.DefineLabel();
369            CompileInstructions(il, state);
370            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
371            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
372            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, c1);
373            CompileInstructions(il, state);
374            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
375            il.MarkLabel(c1);
376            CompileInstructions(il, state);
377            il.MarkLabel(end);
378            return;
379          }
380        case OpCodes.AND: {
381            Label falseBranch = il.DefineLabel();
382            Label end = il.DefineLabel();
383            CompileInstructions(il, state);
384            for (int i = 1; i < nArgs; i++) {
385              il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
386              il.Emit(System.Reflection.Emit.OpCodes.Cgt);
387              il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
388              CompileInstructions(il, state);
389            }
390            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
391            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
392            il.Emit(System.Reflection.Emit.OpCodes.Brfalse, falseBranch);
393            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // 1
394            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
395            il.MarkLabel(falseBranch);
396            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // -1
397            il.Emit(System.Reflection.Emit.OpCodes.Neg);
398            il.MarkLabel(end);
399            return;
400          }
401        case OpCodes.OR: {
402            Label trueBranch = il.DefineLabel();
403            Label end = il.DefineLabel();
404            Label resultBranch = il.DefineLabel();
405            CompileInstructions(il, state);
406            for (int i = 1; i < nArgs; i++) {
407              Label nextArgBranch = il.DefineLabel();
408              // complex definition because of special properties of NaN 
409              il.Emit(System.Reflection.Emit.OpCodes.Dup);
410              il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // <= 0       
411              il.Emit(System.Reflection.Emit.OpCodes.Ble, nextArgBranch);
412              il.Emit(System.Reflection.Emit.OpCodes.Br, resultBranch);
413              il.MarkLabel(nextArgBranch);
414              il.Emit(System.Reflection.Emit.OpCodes.Pop);
415              CompileInstructions(il, state);
416            }
417            il.MarkLabel(resultBranch);
418            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
419            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
420            il.Emit(System.Reflection.Emit.OpCodes.Brtrue, trueBranch);
421            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // -1
422            il.Emit(System.Reflection.Emit.OpCodes.Neg);
423            il.Emit(System.Reflection.Emit.OpCodes.Br, end);
424            il.MarkLabel(trueBranch);
425            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // 1
426            il.MarkLabel(end);
427            return;
428          }
429        case OpCodes.NOT: {
430            CompileInstructions(il, state);
431            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4_0); // > 0
432            il.Emit(System.Reflection.Emit.OpCodes.Cgt);
433            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
434            il.Emit(System.Reflection.Emit.OpCodes.Mul);
435            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
436            il.Emit(System.Reflection.Emit.OpCodes.Sub);
437            il.Emit(System.Reflection.Emit.OpCodes.Neg); // * -1
438            return;
439          }
440        case OpCodes.GT: {
441            CompileInstructions(il, state);
442            CompileInstructions(il, state);
443            il.Emit(System.Reflection.Emit.OpCodes.Cgt); // 1 (>) / 0 (otherwise)
444            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
445            il.Emit(System.Reflection.Emit.OpCodes.Mul);
446            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
447            il.Emit(System.Reflection.Emit.OpCodes.Sub);
448            return;
449          }
450        case OpCodes.LT: {
451            CompileInstructions(il, state);
452            CompileInstructions(il, state);
453            il.Emit(System.Reflection.Emit.OpCodes.Clt);
454            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 2.0); // * 2
455            il.Emit(System.Reflection.Emit.OpCodes.Mul);
456            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, 1.0); // - 1
457            il.Emit(System.Reflection.Emit.OpCodes.Sub);
458            return;
459          }
460        case OpCodes.TimeLag: {
461            throw new NotImplementedException();
462          }
463        case OpCodes.Integral: {
464            throw new NotImplementedException();
465          }
466
467        //mkommend: derivate calculation taken from:
468        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
469        //one sided smooth differentiatior, N = 4
470        // y' = 1/8h (f_i + 2f_i-1, -2 f_i-3 - f_i-4)
471        case OpCodes.Derivative: {
472            throw new NotImplementedException();
473          }
474        case OpCodes.Call: {
475            throw new NotImplementedException();
476          }
477        case OpCodes.Arg: {
478            throw new NotImplementedException();
479          }
480        case OpCodes.Variable: {
481            VariableTreeNode varNode = (VariableTreeNode)currentInstr.dynamicNode;
482            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_1); // load columns array
483            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, (int)currentInstr.iArg0); // load correct column of the current variable
484            il.Emit(System.Reflection.Emit.OpCodes.Ldelem_Ref);
485            il.Emit(System.Reflection.Emit.OpCodes.Ldc_I4, 0); // sampleOffset
486            il.Emit(System.Reflection.Emit.OpCodes.Ldarg_0); // sampleIndex
487            il.Emit(System.Reflection.Emit.OpCodes.Add); // row = sampleIndex + sampleOffset
488            il.Emit(System.Reflection.Emit.OpCodes.Call, listGetValue);
489            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, varNode.Weight); // load weight
490            il.Emit(System.Reflection.Emit.OpCodes.Mul);
491            return;
492          }
493        case OpCodes.LagVariable: {
494            throw new NotImplementedException();
495          }
496        case OpCodes.Constant: {
497            ConstantTreeNode constNode = (ConstantTreeNode)currentInstr.dynamicNode;
498            il.Emit(System.Reflection.Emit.OpCodes.Ldc_R8, constNode.Value);
499            return;
500          }
501
502        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
503        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
504        case OpCodes.VariableCondition: {
505            throw new NotImplementedException();
506          }
507        default: throw new NotSupportedException();
508      }
509    }
510
511    private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
512      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
513        return symbolToOpcode[treeNode.Symbol.GetType()];
514      else
515        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
516    }
517
518    // skips a whole branch
519    private void SkipInstructions(InterpreterState state) {
520      int i = 1;
521      while (i > 0) {
522        i += state.NextInstruction().nArguments;
523        i--;
524      }
525    }
526  }
527}
Note: See TracBrowser for help on using the repository browser.