1  #region License Information


2  /* HeuristicLab


 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)


4  *


5  * This file is part of HeuristicLab.


6  *


7  * HeuristicLab is free software: you can redistribute it and/or modify


8  * it under the terms of the GNU General Public License as published by


9  * the Free Software Foundation, either version 3 of the License, or


10  * (at your option) any later version.


11  *


12  * HeuristicLab is distributed in the hope that it will be useful,


13  * but WITHOUT ANY WARRANTY; without even the implied warranty of


14  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the


15  * GNU General Public License for more details.


16  *


17  * You should have received a copy of the GNU General Public License


18  * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.


19  */


20  #endregion


21 


22  using System;


23  using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;


24  using HeuristicLab.Common;


25  using HeuristicLab.Core;


26  using System.Collections.Generic;


27  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;


28  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;


29  using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;


30  using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;


31 


32  namespace HeuristicLab.Problems.DataAnalysis.Symbolic {


/// <summary>
/// Interprets symbolic expression trees by compiling them to a linear instruction
/// array (prefix order) and evaluating that array recursively for each requested row.
/// </summary>
/// <remarks>
/// Boolean-valued operations treat values greater than 0.0 as true and everything else
/// as false; they produce 1.0 for true and -1.0 for false.
/// </remarks>
[StorableClass]
[Item("SimpleArithmeticExpressionInterpreter", "Interpreter for arithmetic symbolic expression trees including function calls.")]
// not thread safe! The dataset, current row, compiled code, program counter and
// argument stack are all kept in instance fields and shared by every evaluation.
public class SimpleArithmeticExpressionInterpreter : NamedItem, ISymbolicExpressionTreeInterpreter {
  /// <summary>
  /// Op codes stored in the compiled instructions and dispatched on in <see cref="Evaluate"/>.
  /// </summary>
  private class OpCodes {
    public const byte Add = 1;
    public const byte Sub = 2;
    public const byte Mul = 3;
    public const byte Div = 4;

    public const byte Sin = 5;
    public const byte Cos = 6;
    public const byte Tan = 7;

    public const byte Log = 8;
    public const byte Exp = 9;

    public const byte IfThenElse = 10;

    public const byte GT = 11;
    public const byte LT = 12;

    public const byte AND = 13;
    public const byte OR = 14;
    public const byte NOT = 15;


    public const byte Average = 16;

    public const byte Call = 17;

    public const byte Variable = 18;
    public const byte LagVariable = 19;
    public const byte Constant = 20;
    public const byte Arg = 21;
  }

  // maps the runtime type of a symbol to the op code that implements it
  private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
    { typeof(Addition), OpCodes.Add },
    { typeof(Subtraction), OpCodes.Sub },
    { typeof(Multiplication), OpCodes.Mul },
    { typeof(Division), OpCodes.Div },
    { typeof(Sine), OpCodes.Sin },
    { typeof(Cosine), OpCodes.Cos },
    { typeof(Tangent), OpCodes.Tan },
    { typeof(Logarithm), OpCodes.Log },
    { typeof(Exponential), OpCodes.Exp },
    { typeof(IfThenElse), OpCodes.IfThenElse },
    { typeof(GreaterThan), OpCodes.GT },
    { typeof(LessThan), OpCodes.LT },
    { typeof(And), OpCodes.AND },
    { typeof(Or), OpCodes.OR },
    { typeof(Not), OpCodes.NOT},
    { typeof(Average), OpCodes.Average},
    { typeof(InvokeFunction), OpCodes.Call },
    { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
    { typeof(LaggedVariable), OpCodes.LagVariable },
    { typeof(Constant), OpCodes.Constant },
    { typeof(Argument), OpCodes.Arg },
  };
  private const int ARGUMENT_STACK_SIZE = 1024;

  // evaluation state; (re)initialized by GetSymbolicExpressionTreeValues
  private Dataset dataset;
  private int row;
  private Instruction[] code;
  private int pc;

  public override bool CanChangeName {
    get { return false; }
  }
  public override bool CanChangeDescription {
    get { return false; }
  }

  public SimpleArithmeticExpressionInterpreter()
    : base() {
  }

  /// <summary>
  /// Compiles <paramref name="tree"/> and yields one evaluated value per entry of <paramref name="rows"/>.
  /// </summary>
  /// <remarks>
  /// This is an iterator: compilation and evaluation only happen while the result is
  /// enumerated. Each row is evaluated with the program counter and the argument stack
  /// pointer reset to zero.
  /// </remarks>
  /// <param name="tree">The symbolic expression tree to evaluate.</param>
  /// <param name="dataset">The dataset that variable and lagged-variable nodes read from.</param>
  /// <param name="rows">The dataset rows to evaluate the tree on.</param>
  /// <returns>The value of the tree for each row, in the order of <paramref name="rows"/>.</returns>
  public IEnumerable<double> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
    this.dataset = dataset;
    var compiler = new SymbolicExpressionTreeCompiler();
    compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
    code = compiler.Compile(tree, MapSymbolToOpCode);
    foreach (var row in rows) {
      this.row = row;
      pc = 0;
      argStackPointer = 0;
      yield return Evaluate();
    }
  }

  /// <summary>
  /// Compiler hook: for variable and lagged-variable instructions, resolves the variable
  /// name to its index in the current dataset and stores it in <c>iArg0</c>, so that
  /// <see cref="Evaluate"/> does not have to look the name up for every row.
  /// </summary>
  /// <param name="instr">The instruction produced by the compiler.</param>
  /// <returns>The (possibly updated) instruction.</returns>
  private Instruction PostProcessInstruction(Instruction instr) {
    if (instr.opCode == OpCodes.Variable) {
      // direct cast: an unexpected node type fails here instead of as a null dereference
      var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
      instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
    } else if (instr.opCode == OpCodes.LagVariable) {
      var variableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
      instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
    }
    return instr;
  }

  /// <summary>
  /// Returns the op code registered for the type of the node's symbol.
  /// </summary>
  /// <param name="treeNode">The tree node whose symbol should be mapped.</param>
  /// <returns>The op code for the symbol.</returns>
  /// <exception cref="NotSupportedException">The symbol type has no registered op code.</exception>
  private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
    byte opCode;
    // single lookup instead of ContainsKey followed by the indexer
    if (symbolToOpcode.TryGetValue(treeNode.Symbol.GetType(), out opCode))
      return opCode;
    throw new NotSupportedException("Symbol: " + treeNode.Symbol);
  }

  // stack holding the argument values of the function invocations that are currently active
  private double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
  private int argStackPointer;

  /// <summary>
  /// Evaluates the instruction at the current program counter, including all of its
  /// sub-expressions, and leaves the program counter behind the evaluated (or skipped) code.
  /// </summary>
  /// <returns>The value of the evaluated expression for the current row.</returns>
  /// <exception cref="ArgumentException">A lagged variable refers to a row outside of the dataset.</exception>
  /// <exception cref="NotSupportedException">The instruction has an unknown op code.</exception>
  public double Evaluate() {
    var currentInstr = code[pc++];
    switch (currentInstr.opCode) {
      case OpCodes.Add: {
          double s = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            s += Evaluate();
          }
          return s;
        }
      case OpCodes.Sub: {
          double s = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            s -= Evaluate();
          }
          // a subtraction with a single argument is a negation
          if (currentInstr.nArguments == 1) s = -s;
          return s;
        }
      case OpCodes.Mul: {
          double p = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            p *= Evaluate();
          }
          return p;
        }
      case OpCodes.Div: {
          double p = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            p /= Evaluate();
          }
          // a division with a single argument is the reciprocal
          if (currentInstr.nArguments == 1) p = 1.0 / p;
          return p;
        }
      case OpCodes.Average: {
          double sum = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            sum += Evaluate();
          }
          return sum / currentInstr.nArguments;
        }
      case OpCodes.Cos: {
          return Math.Cos(Evaluate());
        }
      case OpCodes.Sin: {
          return Math.Sin(Evaluate());
        }
      case OpCodes.Tan: {
          return Math.Tan(Evaluate());
        }
      case OpCodes.Exp: {
          return Math.Exp(Evaluate());
        }
      case OpCodes.Log: {
          return Math.Log(Evaluate());
        }
      case OpCodes.IfThenElse: {
          double condition = Evaluate();
          double result;
          // only the selected branch is evaluated; the other one is skipped
          if (condition > 0.0) {
            result = Evaluate(); SkipBakedCode();
          } else {
            SkipBakedCode(); result = Evaluate();
          }
          return result;
        }
      case OpCodes.AND: {
          double result = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            // short-circuit: once an operand is false the remaining ones are skipped
            if (result <= 0.0) SkipBakedCode();
            else {
              result = Evaluate();
            }
          }
          return result <= 0.0 ? -1.0 : 1.0;
        }
      case OpCodes.OR: {
          double result = Evaluate();
          for (int i = 1; i < currentInstr.nArguments; i++) {
            // short-circuit: once an operand is true the remaining ones are skipped
            if (result > 0.0) SkipBakedCode();
            else {
              result = Evaluate();
            }
          }
          return result > 0.0 ? 1.0 : -1.0;
        }
      case OpCodes.NOT: {
          // negation flips the sign and with it the truth value
          return -Evaluate();
        }
      case OpCodes.GT: {
          double x = Evaluate();
          double y = Evaluate();
          if (x > y) return 1.0;
          else return -1.0;
        }
      case OpCodes.LT: {
          double x = Evaluate();
          double y = Evaluate();
          if (x < y) return 1.0;
          else return -1.0;
        }
      case OpCodes.Call: {
          int argCount = currentInstr.nArguments;
          // evaluate subtrees
          // The values are collected in a temporary array first: an argument may itself
          // contain a function invocation which uses the same region of the argument
          // stack and would overwrite values that have already been pushed.
          var argValues = new double[argCount];
          for (int i = 0; i < argCount; i++) {
            argValues[i] = Evaluate();
          }
          // push on argStack in reverse order
          for (int i = 0; i < argCount; i++) {
            argumentStack[argStackPointer + argCount - i] = argValues[i];
          }
          argStackPointer += argCount;

          // save the pc
          int nextPc = pc;
          // set pc to start of function
          pc = currentInstr.iArg0;
          // evaluate the function
          double v = Evaluate();

          // decrease the argument stack pointer by the number of arguments pushed
          // to set the argStackPointer back to the original location
          argStackPointer -= argCount;

          // restore the pc => evaluation will continue at point after my subtrees
          pc = nextPc;
          return v;
        }
      case OpCodes.Arg: {
          // arguments were pushed in reverse order, so argument 0 is on top of the stack
          return argumentStack[argStackPointer - currentInstr.iArg0];
        }
      case OpCodes.Variable: {
          var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
          return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
        }
      case OpCodes.LagVariable: {
          var lagVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
          int actualRow = row + lagVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows) throw new ArgumentException("Out of range access to dataset row: " + row);
          return dataset[actualRow, currentInstr.iArg0] * lagVariableTreeNode.Weight;
        }
      case OpCodes.Constant: {
          var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
          return constTreeNode.Value;
        }
      default: throw new NotSupportedException();
    }
  }

  // skips a whole branch
  // 'i' counts the instructions that still have to be skipped: every skipped
  // instruction adds its own arguments to the count and removes itself from it.
  protected void SkipBakedCode() {
    int i = 1;
    while (i > 0) {
      i += code[pc++].nArguments;
      i--;
    }
  }
}


298  }

