#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that flattens a symbolic expression tree into an array of
  /// <c>Instruction</c> objects once and then evaluates that array for every requested row.
  /// </summary>
  /// <remarks>
  /// The opcodes TimeLag, Integral, Derivative and Arg, as well as interval-arithmetic
  /// checking, are rejected with a <see cref="NotSupportedException"/>.
  /// </remarks>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeFastInterpreter", "Fast interpreter for symbolic expression trees including automatically defined functions.")]
  public class SymbolicDataAnalysisExpressionTreeFastInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Dedicated lock target for the evaluation counter. The counter object itself is
    // replaceable through the EvaluatedSolutions setter and is therefore not a stable monitor.
    private readonly object syncRoot = new object();

    public override bool CanChangeName { get { return false; } }
    public override bool CanChangeDescription { get { return false; } }

    #region parameter properties
    /// <summary>Parameter holding the interval-arithmetic switch.</summary>
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// When true, <see cref="GetSymbolicExpressionTreeValues"/> throws because this
    /// interpreter does not implement interval-arithmetic checking.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }

    /// <summary>Number of evaluations started since the last <see cref="InitializeState"/>.</summary>
    public IntValue EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value; }
      set { EvaluatedSolutionsParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(bool deserializing) : base(deserializing) { }

    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(SymbolicDataAnalysisExpressionTreeFastInterpreter original, Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeFastInterpreter(this, cloner);
    }

    /// <summary>Creates an interpreter with the default name, description and parameters.</summary>
    public SymbolicDataAnalysisExpressionTreeFastInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeFastInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") { }

    /// <summary>
    /// Creates an interpreter with the given name and description and registers both parameters
    /// (interval-arithmetic switch = false, evaluated solutions = 0).
    /// </summary>
    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    /// <summary>
    /// Adds the evaluated-solutions parameter when the deserialized instance does not contain it.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions.Value = 0;
    }

    /// <summary>No state needs to be released.</summary>
    public void ClearState() {
    }
    #endregion

    /// <summary>
    /// Lazily evaluates <paramref name="tree"/> on every row index in <paramref name="rows"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator: the interval-arithmetic check, the counter increment and the
    /// compilation of the tree all happen when enumeration starts, and again for every new
    /// enumeration of the returned sequence.
    /// </remarks>
    /// <param name="tree">Tree to evaluate; the expression is read from <c>Root.GetSubtree(0).GetSubtree(0)</c>.</param>
    /// <param name="dataset">Dataset providing the variable values.</param>
    /// <param name="rows">Row indices to evaluate, in the order the results are produced.</param>
    /// <returns>One value per row index.</returns>
    /// <exception cref="NotSupportedException">
    /// Interval-arithmetic checking is enabled, or the tree contains a symbol this interpreter
    /// does not handle.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (syncRoot) {
        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
      }

      Instruction[] code = Compile(tree, dataset);
      foreach (int row in rows) {
        EvaluateFast(dataset, row, code);
        yield return code[0].value; // instruction 0 is the root of the expression
      }
    }

    /// <summary>
    /// Flattens the expression into a breadth-first instruction array and binds the dataset
    /// column of every variable-based instruction.
    /// </summary>
    private static Instruction[] Compile(ISymbolicExpressionTree tree, Dataset dataset) {
      // NOTE(review): presumably skips the program-root and start symbols - confirm against the grammar.
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var nodes = new List<ISymbolicExpressionTreeNode> { root };
      var code = new List<Instruction> { CreateInstruction(root) };

      // Breadth-first traversal: the children of node i are appended contiguously, so every
      // child index is larger than its parent's index and siblings are adjacent.
      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        if (node.SubtreeCount == 0) continue;

        // save index of the first child in the instructions array
        code[i].childIndex = code.Count;
        for (int j = 0; j < node.SubtreeCount; ++j) {
          var child = node.GetSubtree(j);
          nodes.Add(child);
          code.Add(CreateInstruction(child));
        }
      }

      // fill in iArg0 (the dataset column) for instructions that read a variable
      foreach (var instr in code) {
        switch (instr.opCode) {
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
        }
      }

      return code.ToArray();
    }

    /// <summary>Creates the instruction for a single tree node (child index is filled in later).</summary>
    private static Instruction CreateInstruction(ISymbolicExpressionTreeNode node) {
      return new Instruction {
        dynamicNode = node,
        nArguments = (byte)node.SubtreeCount,
        opCode = OpCodes.MapSymbolToOpCode(node)
      };
    }

    /// <summary>
    /// Evaluates all instructions for one row and leaves the result in <c>code[0].value</c>.
    /// </summary>
    /// <remarks>
    /// Instructions are processed from last to first. Because children always have larger
    /// indices than their parent, every argument value is up to date when it is read.
    /// Variable accesses outside of <c>[0, dataset.Rows)</c> evaluate to NaN.
    /// </remarks>
    /// <exception cref="NotSupportedException">An instruction has an unsupported opcode.</exception>
    private void EvaluateFast(Dataset dataset, int row, Instruction[] code) {
      for (int i = code.Length - 1; i >= 0; --i) {
        var instr = code[i];

        switch (instr.opCode) {
          case OpCodes.Add: {
              double s = code[instr.childIndex].value;
              for (int j = 1; j != instr.nArguments; ++j) {
                s += code[instr.childIndex + j].value;
              }
              instr.value = s;
            }
            break;
          case OpCodes.Sub: {
              double s = code[instr.childIndex].value;
              for (int j = 1; j != instr.nArguments; ++j) {
                s -= code[instr.childIndex + j].value;
              }
              if (instr.nArguments == 1) s = -s; // unary minus
              instr.value = s;
            }
            break;
          case OpCodes.Mul: {
              double p = code[instr.childIndex].value;
              for (int j = 1; j != instr.nArguments; ++j) {
                p *= code[instr.childIndex + j].value;
              }
              instr.value = p;
            }
            break;
          case OpCodes.Div: {
              double p = code[instr.childIndex].value;
              for (int j = 1; j != instr.nArguments; ++j) {
                p /= code[instr.childIndex + j].value;
              }
              if (instr.nArguments == 1) p = 1.0 / p; // unary division is the reciprocal
              instr.value = p;
            }
            break;
          case OpCodes.Average: {
              double s = code[instr.childIndex].value;
              for (int j = 1; j != instr.nArguments; ++j) {
                s += code[instr.childIndex + j].value;
              }
              instr.value = s / instr.nArguments;
            }
            break;
          case OpCodes.Cos: {
              instr.value = Math.Cos(code[instr.childIndex].value);
            }
            break;
          case OpCodes.Sin: {
              instr.value = Math.Sin(code[instr.childIndex].value);
            }
            break;
          case OpCodes.Tan: {
              instr.value = Math.Tan(code[instr.childIndex].value);
            }
            break;
          case OpCodes.Root: {
              double x = code[instr.childIndex].value;
              double y = code[instr.childIndex + 1].value;
              instr.value = Math.Pow(x, 1 / y);
            }
            break;
          case OpCodes.Exp: {
              instr.value = Math.Exp(code[instr.childIndex].value);
            }
            break;
          case OpCodes.Log: {
              instr.value = Math.Log(code[instr.childIndex].value);
            }
            break;
          case OpCodes.Gamma: {
              var x = code[instr.childIndex].value;
              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
            }
            break;
          case OpCodes.Psi: {
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              // non-positive (almost) integer arguments are mapped to NaN
              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
              else instr.value = alglib.psi(x);
            }
            break;
          case OpCodes.Dawson: {
              var x = code[instr.childIndex].value;
              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
            }
            break;
          case OpCodes.ExponentialIntegralEi: {
              var x = code[instr.childIndex].value;
              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
            }
            break;
          case OpCodes.SineIntegral: {
              double si, ci;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.sinecosineintegrals(x, out si, out ci);
                instr.value = si;
              }
            }
            break;
          case OpCodes.CosineIntegral: {
              double si, ci;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.sinecosineintegrals(x, out si, out ci);
                instr.value = ci; // the cosine integral is the second output
              }
            }
            break;
          case OpCodes.HyperbolicSineIntegral: {
              double shi, chi;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
                instr.value = shi;
              }
            }
            break;
          case OpCodes.HyperbolicCosineIntegral: {
              double shi, chi;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
                instr.value = chi;
              }
            }
            break;
          case OpCodes.FresnelCosineIntegral: {
              double c = 0, s = 0;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.fresnelintegral(x, ref c, ref s);
                instr.value = c;
              }
            }
            break;
          case OpCodes.FresnelSineIntegral: {
              double c = 0, s = 0;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.fresnelintegral(x, ref c, ref s);
                instr.value = s;
              }
            }
            break;
          case OpCodes.AiryA: {
              double ai, aip, bi, bip;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.airy(x, out ai, out aip, out bi, out bip);
                instr.value = ai;
              }
            }
            break;
          case OpCodes.AiryB: {
              double ai, aip, bi, bip;
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else {
                alglib.airy(x, out ai, out aip, out bi, out bip);
                instr.value = bi;
              }
            }
            break;
          case OpCodes.Norm: {
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else instr.value = alglib.normaldistribution(x);
            }
            break;
          case OpCodes.Erf: {
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else instr.value = alglib.errorfunction(x);
            }
            break;
          case OpCodes.Bessel: {
              var x = code[instr.childIndex].value;
              if (double.IsNaN(x)) instr.value = double.NaN;
              else instr.value = alglib.besseli0(x);
            }
            break;
          case OpCodes.IfThenElse: {
              // a condition value > 0 counts as true
              double condition = code[instr.childIndex].value;
              double result;
              if (condition > 0.0) {
                result = code[instr.childIndex + 1].value;
              } else {
                result = code[instr.childIndex + 2].value;
              }
              instr.value = result;
            }
            break;
          case OpCodes.AND: {
              // boolean results are encoded as 1.0 (true) and -1.0 (false)
              double result = code[instr.childIndex].value;
              for (int j = 1; j < instr.nArguments; j++) {
                if (result > 0.0) result = code[instr.childIndex + j].value;
                else break;
              }
              instr.value = result > 0.0 ? 1.0 : -1.0;
            }
            break;
          case OpCodes.OR: {
              double result = code[instr.childIndex].value;
              for (int j = 1; j < instr.nArguments; j++) {
                if (result <= 0.0) result = code[instr.childIndex + j].value;
                else break;
              }
              instr.value = result > 0.0 ? 1.0 : -1.0;
            }
            break;
          case OpCodes.NOT: {
              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
            }
            break;
          case OpCodes.GT: {
              double x = code[instr.childIndex].value;
              double y = code[instr.childIndex + 1].value;
              instr.value = x > y ? 1.0 : -1.0;
            }
            break;
          case OpCodes.LT: {
              double x = code[instr.childIndex].value;
              double y = code[instr.childIndex + 1].value;
              instr.value = x < y ? 1.0 : -1.0;
            }
            break;
          case OpCodes.TimeLag: {
              throw new NotSupportedException();
            }
          case OpCodes.Integral: {
              throw new NotSupportedException();
            }
          case OpCodes.Derivative: {
              throw new NotSupportedException();
            }
          case OpCodes.Arg: {
              throw new NotSupportedException();
            }
          case OpCodes.Variable: {
              if (row < 0 || row >= dataset.Rows) {
                instr.value = double.NaN;
              } else {
                var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
                instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
              }
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              int actualRow = row + laggedVariableTreeNode.Lag;
              if (actualRow < 0 || actualRow >= dataset.Rows) {
                instr.value = double.NaN;
              } else {
                instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
              }
            }
            break;
          case OpCodes.Constant: {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
            }
            break;
          case OpCodes.VariableCondition: {
              if (row < 0 || row >= dataset.Rows) {
                instr.value = double.NaN;
              } else {
                var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
                double variableValue = ((IList<double>)instr.iArg0)[row];
                double x = variableValue - variableConditionTreeNode.Threshold;
                // logistic blend between the two branches
                double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

                double trueBranch = code[instr.childIndex].value;
                double falseBranch = code[instr.childIndex + 1].value;

                instr.value = trueBranch * p + falseBranch * (1 - p);
              }
            }
            break;
          default:
            throw new NotSupportedException();
        }
      }
    }
  }
}