Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SimpleArithmeticExpressionInterpreter.cs @ 5275

Last change on this file since 5275 was 5275, checked in by gkronber, 13 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 12.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interprets compiled symbolic expression trees on the rows of a dataset.
  /// </summary>
  /// <remarks>
  /// The tree is compiled into a linear array of instructions in prefix order which is then evaluated
  /// recursively, once per requested row. Boolean values are encoded as doubles: a value greater than
  /// zero is treated as true, everything else as false; comparison and logic operators produce 1.0 or -1.0.
  /// Invoked functions receive their arguments through an argument stack of fixed size
  /// (<see cref="ARGUMENT_STACK_SIZE"/> slots); no explicit overflow check is performed.
  /// </remarks>
  [StorableClass]
  [Item("SimpleArithmeticExpressionInterpreter", "Interpreter for arithmetic symbolic expression trees including function calls.")]
  public sealed class SimpleArithmeticExpressionInterpreter : NamedItem, ISymbolicExpressionTreeInterpreter {
    /// <summary>Byte codes of all instructions understood by <see cref="Evaluate"/>.</summary>
    private static class OpCodes {
      public const byte Add = 1;
      public const byte Sub = 2;
      public const byte Mul = 3;
      public const byte Div = 4;

      public const byte Sin = 5;
      public const byte Cos = 6;
      public const byte Tan = 7;

      public const byte Log = 8;
      public const byte Exp = 9;

      public const byte IfThenElse = 10;

      public const byte GT = 11;
      public const byte LT = 12;

      public const byte AND = 13;
      public const byte OR = 14;
      public const byte NOT = 15;

      public const byte Average = 16;

      public const byte Call = 17;

      public const byte Variable = 18;
      public const byte LagVariable = 19;
      public const byte Constant = 20;
      public const byte Arg = 21;
    }

    // Maps the runtime type of a symbol to its op code. The table is never modified after
    // initialization, so a single shared instance is sufficient for all interpreters.
    private static readonly Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
      { typeof(Addition), OpCodes.Add },
      { typeof(Subtraction), OpCodes.Sub },
      { typeof(Multiplication), OpCodes.Mul },
      { typeof(Division), OpCodes.Div },
      { typeof(Sine), OpCodes.Sin },
      { typeof(Cosine), OpCodes.Cos },
      { typeof(Tangent), OpCodes.Tan },
      { typeof(Logarithm), OpCodes.Log },
      { typeof(Exponential), OpCodes.Exp },
      { typeof(IfThenElse), OpCodes.IfThenElse },
      { typeof(GreaterThan), OpCodes.GT },
      { typeof(LessThan), OpCodes.LT },
      { typeof(And), OpCodes.AND },
      { typeof(Or), OpCodes.OR },
      { typeof(Not), OpCodes.NOT},
      { typeof(Average), OpCodes.Average},
      { typeof(InvokeFunction), OpCodes.Call },
      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
      { typeof(LaggedVariable), OpCodes.LagVariable },
      { typeof(Constant), OpCodes.Constant },
      { typeof(Argument), OpCodes.Arg },
    };

    // number of slots of the argument stack that is allocated per enumeration of tree values
    private const int ARGUMENT_STACK_SIZE = 1024;

    /// <summary>The name of the interpreter cannot be changed.</summary>
    public override bool CanChangeName {
      get { return false; }
    }
    /// <summary>The description of the interpreter cannot be changed.</summary>
    public override bool CanChangeDescription {
      get { return false; }
    }

    [StorableConstructor]
    private SimpleArithmeticExpressionInterpreter(bool deserializing) : base(deserializing) { }
    private SimpleArithmeticExpressionInterpreter(SimpleArithmeticExpressionInterpreter original, Cloner cloner) : base(original, cloner) { }

    /// <summary>Creates a copy of this interpreter using the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SimpleArithmeticExpressionInterpreter(this, cloner);
    }

    public SimpleArithmeticExpressionInterpreter()
      : base() {
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> for each of the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator method: compilation of the tree and evaluation of the rows are deferred until the
    /// returned sequence is enumerated, and the work is repeated for every enumeration.
    /// </remarks>
    /// <param name="tree">The symbolic expression tree to evaluate.</param>
    /// <param name="dataset">The dataset that supplies the variable values.</param>
    /// <param name="rows">The row indexes for which the tree is evaluated.</param>
    /// <returns>One value per element of <paramref name="rows"/>, in the same order.</returns>
    /// <exception cref="NotSupportedException">Raised by the op code mapping for symbols that are not supported.</exception>
    /// <exception cref="ArgumentException">Thrown when a lagged variable refers to a row outside of the dataset.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      var compiler = new SymbolicExpressionTreeCompiler();
      Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);

      // Resolve variable names to dataset column indexes once, instead of once per evaluated row.
      // The instruction is copied to a local, modified and written back to the array.
      for (int i = 0; i < code.Length; i++) {
        Instruction instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          // direct cast: an unexpected node type fails here with an InvalidCastException
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
          code[i] = instr;
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.iArg0 = (ushort)dataset.GetVariableIndex(laggedVariableTreeNode.VariableName);
          code[i] = instr;
        }
      }

      // the argument stack is allocated once and reused for all rows
      double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
      foreach (var rowEnum in rows) {
        // the iteration variable cannot be passed by reference, therefore it is copied
        int row = rowEnum;
        int pc = 0;
        int argStackPointer = 0;
        yield return Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
      }
    }

    /// <summary>
    /// Evaluates the instruction at <paramref name="pc"/> together with all of its argument instructions.
    /// </summary>
    /// <param name="dataset">The dataset that supplies the variable values.</param>
    /// <param name="row">The dataset row for which the code is evaluated.</param>
    /// <param name="code">The compiled instructions in prefix order.</param>
    /// <param name="pc">Index of the instruction to evaluate; on return it points behind the evaluated branch.</param>
    /// <param name="argumentStack">Storage for the argument values of invoked functions.</param>
    /// <param name="argStackPointer">Index of the top-most used slot of <paramref name="argumentStack"/>.</param>
    /// <returns>The value of the evaluated branch.</returns>
    /// <exception cref="ArgumentException">Thrown when a lagged variable refers to a row outside of the dataset.</exception>
    /// <exception cref="NotSupportedException">Thrown when an instruction with an unknown op code is encountered.</exception>
    private double Evaluate(Dataset dataset, ref int row, Instruction[] code, ref int pc, double[] argumentStack, ref int argStackPointer) {
      Instruction currentInstr = code[pc++];
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double sum = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            return sum;
          }
        case OpCodes.Sub: {
            double difference = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              difference -= Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            // a subtraction with a single argument is a negation
            if (currentInstr.nArguments == 1) difference = -difference;
            return difference;
          }
        case OpCodes.Mul: {
            double product = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              product *= Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            return product;
          }
        case OpCodes.Div: {
            double quotient = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              quotient /= Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            // a division with a single argument is the reciprocal
            if (currentInstr.nArguments == 1) quotient = 1.0 / quotient;
            return quotient;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer));
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer));
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            double result;
            if (condition > 0.0) {
              // evaluate the 'then' branch and jump over the 'else' branch
              result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
              SkipBakedCode(code, ref pc);
            } else {
              // jump over the 'then' branch and evaluate the 'else' branch
              SkipBakedCode(code, ref pc);
              result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }
            return result;
          }
        case OpCodes.AND: {
            double result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short circuit: as soon as one operand is false the remaining operands are only skipped
              if (result <= 0.0) SkipBakedCode(code, ref pc);
              else {
                result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
              }
            }
            return result <= 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.OR: {
            double result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short circuit: as soon as one operand is true the remaining operands are only skipped
              if (result > 0.0) SkipBakedCode(code, ref pc);
              else {
                result = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            // logical negation is the arithmetic negation of the operand
            return -Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            double y = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            if (x > y) return 1.0;
            else return -1.0;
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            double y = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            if (x < y) return 1.0;
            else return -1.0;
          }
        case OpCodes.Call: {
            int argCount = currentInstr.nArguments;

            // Evaluate all argument sub-trees BEFORE writing anything to the argument stack.
            // An argument can itself contain a function call which uses the slots directly above
            // argStackPointer for its own arguments. Storing each value on the stack right after its
            // evaluation would therefore allow a nested call in a later argument to overwrite the
            // values of earlier arguments.
            double[] argValues = new double[argCount];
            for (int i = 0; i < argCount; i++) {
              argValues[i] = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);
            }

            // push on argStack in reverse order
            // => after the pointer is increased argument i is found at argumentStack[argStackPointer - i]
            for (int i = 0; i < argCount; i++) {
              argumentStack[argStackPointer + argCount - i] = argValues[i];
            }
            argStackPointer += argCount;

            // save the pc
            int nextPc = pc;
            // set pc to start of function
            pc = currentInstr.iArg0;
            // evaluate the function
            double v = Evaluate(dataset, ref row, code, ref pc, argumentStack, ref argStackPointer);

            // decrease the argument stack pointer by the number of arguments pushed
            // to set the argStackPointer back to the original location
            argStackPointer -= argCount;

            // restore the pc => evaluation will continue at point after my subtrees
            pc = nextPc;
            return v;
          }
        case OpCodes.Arg: {
            // iArg0 is the offset of the argument below the top of the argument stack
            return argumentStack[argStackPointer - currentInstr.iArg0];
          }
        case OpCodes.Variable: {
            // iArg0 holds the dataset column index that was resolved before the evaluation started
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            // NOTE(review): the message reports the requested row, not the lagged row that is out of range
            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + row);
            return dataset[actualRow, currentInstr.iArg0] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }
        default: throw new NotSupportedException();
      }
    }

    /// <summary>
    /// Returns the op code for the symbol of <paramref name="treeNode"/>.
    /// </summary>
    /// <param name="treeNode">The tree node whose symbol should be mapped.</param>
    /// <returns>The op code that is registered for the type of the symbol.</returns>
    /// <exception cref="NotSupportedException">Thrown when no op code is registered for the type of the symbol.</exception>
    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
      byte opCode;
      // a single lookup instead of ContainsKey followed by the indexer
      if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
        return opCode;
      else
        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    }

    /// <summary>
    /// Skips a whole branch: advances <paramref name="pc"/> behind the instruction it currently
    /// points to and behind all argument instructions of that instruction.
    /// </summary>
    /// <param name="code">The compiled instructions in prefix order.</param>
    /// <param name="pc">Index of the first instruction of the branch; on return the index of the instruction after the branch.</param>
    private void SkipBakedCode(Instruction[] code, ref int pc) {
      // number of instructions that still have to be skipped; every skipped instruction
      // adds its own arguments to the amount of outstanding work
      int remaining = 1;
      while (remaining > 0) {
        remaining += code[pc++].nArguments;
        remaining--;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.