Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/Symbolic/SimpleArithmeticExpressionInterpreter.cs @ 3921

Last change on this file since 3921 was 3841, checked in by gkronber, 14 years ago

Extended set of available functions for symbolic regression and added test cases for the extended function set. #1013

File size: 10.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using System.Collections.Generic;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SimpleArithmeticExpressionInterpreter", "Interpreter for arithmetic symbolic expression trees including function calls.")]
35  // not thread safe!
36  public class SimpleArithmeticExpressionInterpreter : NamedItem, ISymbolicExpressionTreeInterpreter {
37    private class OpCodes {
38      public const byte Add = 1;
39      public const byte Sub = 2;
40      public const byte Mul = 3;
41      public const byte Div = 4;
42
43      public const byte Sin = 5;
44      public const byte Cos = 6;
45      public const byte Tan = 7;
46
47      public const byte Log = 8;
48      public const byte Exp = 9;
49
50      public const byte IfThenElse = 10;
51
52      public const byte GT = 11;
53      public const byte LT = 12;
54
55      public const byte AND = 13;
56      public const byte OR = 14;
57      public const byte NOT = 15;
58
59
60      public const byte Average = 16;
61
62      public const byte Call = 17;
63
64      public const byte Variable = 18;
65      public const byte LagVariable = 19;
66      public const byte Constant = 20;
67      public const byte Arg = 21;
68    }
69
70    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
71      { typeof(Addition), OpCodes.Add },
72      { typeof(Subtraction), OpCodes.Sub },
73      { typeof(Multiplication), OpCodes.Mul },
74      { typeof(Division), OpCodes.Div },
75      { typeof(Sine), OpCodes.Sin },
76      { typeof(Cosine), OpCodes.Cos },
77      { typeof(Tangent), OpCodes.Tan },
78      { typeof(Logarithm), OpCodes.Log },
79      { typeof(Exponential), OpCodes.Exp },
80      { typeof(IfThenElse), OpCodes.IfThenElse },
81      { typeof(GreaterThan), OpCodes.GT },
82      { typeof(LessThan), OpCodes.LT },
83      { typeof(And), OpCodes.AND },
84      { typeof(Or), OpCodes.OR },
85      { typeof(Not), OpCodes.NOT},
86      { typeof(Average), OpCodes.Average},
87      { typeof(InvokeFunction), OpCodes.Call },
88      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
89      { typeof(LaggedVariable), OpCodes.LagVariable },
90      { typeof(Constant), OpCodes.Constant },
91      { typeof(Argument), OpCodes.Arg },
92    };
93    private const int ARGUMENT_STACK_SIZE = 1024;
94
95    private Dataset dataset;
96    private int row;
97    private Instruction[] code;
98    private int pc;
99
100    public override bool CanChangeName {
101      get { return false; }
102    }
103    public override bool CanChangeDescription {
104      get { return false; }
105    }
106
107    public SimpleArithmeticExpressionInterpreter()
108      : base() {
109    }
110
111    public IEnumerable<double> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
112      this.dataset = dataset;
113      var compiler = new SymbolicExpressionTreeCompiler();
114      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
115      code = compiler.Compile(tree, MapSymbolToOpCode);
116      foreach (var row in rows) {
117        this.row = row;
118        pc = 0;
119        argStackPointer = 0;
120        yield return Evaluate();
121      }
122    }
123
124    private Instruction PostProcessInstruction(Instruction instr) {
125      if (instr.opCode == OpCodes.Variable) {
126        var variableTreeNode = instr.dynamicNode as VariableTreeNode;
127        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
128      } else if (instr.opCode == OpCodes.LagVariable) {
129        var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
130        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
131      }
132      return instr;
133    }
134
135    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
136      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
137        return symbolToOpcode[treeNode.Symbol.GetType()];
138      else
139        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
140    }
141
142    private double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
143    private int argStackPointer;
144
145    public double Evaluate() {
146      var currentInstr = code[pc++];
147      switch (currentInstr.opCode) {
148        case OpCodes.Add: {
149            double s = 0.0;
150            for (int i = 0; i < currentInstr.nArguments; i++) {
151              s += Evaluate();
152            }
153            return s;
154          }
155        case OpCodes.Sub: {
156            double s = Evaluate();
157            for (int i = 1; i < currentInstr.nArguments; i++) {
158              s -= Evaluate();
159            }
160            if (currentInstr.nArguments == 1) s = -s;
161            return s;
162          }
163        case OpCodes.Mul: {
164            double p = Evaluate();
165            for (int i = 1; i < currentInstr.nArguments; i++) {
166              p *= Evaluate();
167            }
168            return p;
169          }
170        case OpCodes.Div: {
171            double p = Evaluate();
172            for (int i = 1; i < currentInstr.nArguments; i++) {
173              p /= Evaluate();
174            }
175            if (currentInstr.nArguments == 1) p = 1.0 / p;
176            return p;
177          }
178        case OpCodes.Average: {
179            double sum = Evaluate();
180            for (int i = 1; i < currentInstr.nArguments; i++) {
181              sum += Evaluate();
182            }
183            return sum / currentInstr.nArguments;
184          }
185        case OpCodes.Cos: {
186            return Math.Cos(Evaluate());
187          }
188        case OpCodes.Sin: {
189            return Math.Sin(Evaluate());
190          }
191        case OpCodes.Tan: {
192            return Math.Tan(Evaluate());
193          }
194        case OpCodes.Exp: {
195            return Math.Exp(Evaluate());
196          }
197        case OpCodes.Log: {
198            return Math.Log(Evaluate());
199          }
200        case OpCodes.IfThenElse: {
201            double condition = Evaluate();
202            double result;
203            if (condition > 0.0) {
204              result = Evaluate(); SkipBakedCode();
205            } else {
206              SkipBakedCode(); result = Evaluate();
207            }
208            return result;
209          }
210        case OpCodes.AND: {
211            double result = Evaluate();
212            for (int i = 1; i < currentInstr.nArguments; i++) {
213              if (result <= 0.0) SkipBakedCode();
214              else {
215                result = Evaluate();
216              }
217            }
218            return result <= 0.0 ? -1.0 : 1.0;
219          }
220        case OpCodes.OR: {
221            double result = Evaluate();
222            for (int i = 1; i < currentInstr.nArguments; i++) {
223              if (result > 0.0) SkipBakedCode();
224              else {
225                result = Evaluate();
226              }
227            }
228            return result > 0.0 ? 1.0 : -1.0;
229          }
230        case OpCodes.NOT: {
231            return -Evaluate();
232          }
233        case OpCodes.GT: {
234            double x = Evaluate();
235            double y = Evaluate();
236            if (x > y) return 1.0;
237            else return -1.0;
238          }
239        case OpCodes.LT: {
240            double x = Evaluate();
241            double y = Evaluate();
242            if (x < y) return 1.0;
243            else return -1.0;
244          }
245        case OpCodes.Call: {
246            // evaluate sub-trees
247            // push on argStack in reverse order
248            for (int i = 0; i < currentInstr.nArguments; i++) {
249              argumentStack[argStackPointer + currentInstr.nArguments - i] = Evaluate();
250            }
251            argStackPointer += currentInstr.nArguments;
252
253            // save the pc
254            int nextPc = pc;
255            // set pc to start of function 
256            pc = currentInstr.iArg0;
257            // evaluate the function
258            double v = Evaluate();
259
260            // decrease the argument stack pointer by the number of arguments pushed
261            // to set the argStackPointer back to the original location
262            argStackPointer -= currentInstr.nArguments;
263
264            // restore the pc => evaluation will continue at point after my subtrees 
265            pc = nextPc;
266            return v;
267          }
268        case OpCodes.Arg: {
269            return argumentStack[argStackPointer - currentInstr.iArg0];
270          }
271        case OpCodes.Variable: {
272            var variableTreeNode = currentInstr.dynamicNode as VariableTreeNode;
273            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
274          }
275        case OpCodes.LagVariable: {
276            var lagVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
277            int actualRow = row + lagVariableTreeNode.Lag;
278            if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + row);
279            return dataset[actualRow, currentInstr.iArg0] * lagVariableTreeNode.Weight;
280          }
281        case OpCodes.Constant: {
282            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
283            return constTreeNode.Value;
284          }
285        default: throw new NotSupportedException();
286      }
287    }
288
289    // skips a whole branch
290    protected void SkipBakedCode() {
291      int i = 1;
292      while (i > 0) {
293        i += code[pc++].nArguments;
294        i--;
295      }
296    }
297  }
298}
Note: See TracBrowser for help on using the repository browser.