Free cookie consent management tool by TermsFeed Policy Generator

source: branches/crossvalidation-2434/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 14728

Last change on this file since 14728 was 14029, checked in by gkronber, 9 years ago

#2434: merged trunk changes r12934:14026 from trunk to branch

File size: 21.4 KB
RevLine 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
[6740]26using HeuristicLab.Data;
[5571]27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[6740]28using HeuristicLab.Parameters;
[5571]29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Recursive interpreter for symbolic expression trees. A tree is compiled into a linear
  /// instruction array which is then evaluated once per requested dataset row.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
    ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string EvaluatedSolutionsParameterDescription = "A counter for the total number of solutions the interpreter has evaluated";

    /// <summary>The name of this item is fixed.</summary>
    public override bool CanChangeName {
      get { return false; }
    }

    /// <summary>The description of this item is fixed.</summary>
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Parameter holding the interval-arithmetic check switch.</summary>
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Convenience accessor for the value of <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>.
    /// This interpreter throws when asked to evaluate a tree while the switch is set.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>Convenience accessor for the value of <see cref="EvaluatedSolutionsParameter"/>.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }

    /// <summary>Cloning constructor; all state is copied by the base class.</summary>
    protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original,
      Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    /// <summary>Creates an interpreter with the default name, description and parameters.</summary>
    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    }

    /// <summary>
    /// Creates an interpreter with a custom name and description. Registers both parameters:
    /// the interval-arithmetic switch (initially false) and the evaluation counter (initially 0).
    /// </summary>
    protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
    }

    /// <summary>
    /// Re-creates both parameters as fixed value parameters after deserialization, carrying over
    /// the stored values when the parameters exist and falling back to the defaults otherwise.
    /// </summary>
    /// <remarks>
    /// NOTE(review): presumably this migrates files persisted with a different (non-fixed)
    /// parameter type; confirm against the persistence history.
    /// </remarks>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      var evaluatedSolutions = new IntValue(0);
      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
        evaluatedSolutions = evaluatedSolutionsParameter.Value;
        Parameters.Remove(EvaluatedSolutionsParameterName);
      }
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, evaluatedSolutions));
      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
      }
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>Nothing to clear; the interpreter keeps no other state.</summary>
    public void ClearState() { }
    #endregion

    // guards the increment of the evaluated-solutions counter
    private readonly object syncRoot = new object();

    /// <summary>
    /// Lazily evaluates <paramref name="tree"/> for each row in <paramref name="rows"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator method: the interval-arithmetic check, the counter increment and the
    /// compilation of the tree all run when enumeration starts, not when the method is called,
    /// and they run again for every new enumeration of the returned sequence.
    /// </remarks>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    /// <param name="rows">The row indices to evaluate, in the order in which results are produced.</param>
    /// <returns>One value per row.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown on enumeration when <see cref="CheckExpressionsWithIntervalArithmetic"/> is set.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset,
      IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic) {
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      }

      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        int row = rowEnum; // local copy because Evaluate takes the row by reference
        yield return Evaluate(dataset, ref row, state);
        state.Reset();
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> into instructions, attaches the dataset column to every
    /// variable-reading instruction and sizes the argument stack for function calls.
    /// </summary>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        switch (instr.opCode) {
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
              break;
            }
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
              break;
            }
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
              break;
            }
          case OpCodes.Call: {
              // every call contributes its arguments plus one extra slot to the stack size
              necessaryArgStackSize += instr.nArguments + 1;
              break;
            }
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }

    /// <summary>
    /// Evaluates the instruction at the current program counter of <paramref name="state"/>,
    /// recursively evaluating its operands, and returns the resulting value.
    /// </summary>
    /// <param name="dataset">The dataset; used here for its row count.</param>
    /// <param name="row">
    /// The current row. Time-shifting operations (TimeLag, Integral, Derivative) change it
    /// temporarily while evaluating their operand and restore it before returning.
    /// </param>
    /// <param name="state">The interpreter state holding the code, program counter and argument stack.</param>
    /// <returns>The value of the sub-expression; boolean operations yield 1.0 (true) or -1.0 (false).</returns>
    /// <exception cref="NotSupportedException">Thrown for an opcode this interpreter does not handle.</exception>
    public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            // unary subtraction is negation
            if (currentInstr.nArguments == 1) { s = -s; }
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            // unary division is the reciprocal
            if (currentInstr.nArguments == 1) { p = 1.0 / p; }
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Square: {
            return Math.Pow(Evaluate(dataset, ref row, state), 2);
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            // the exponent is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, y);
          }
        case OpCodes.SquareRoot: {
            return Math.Sqrt(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            // the root degree is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Gamma: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.gammafunction(x);
          }
        case OpCodes.Psi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            // NaN for zero and negative integers
            if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) { return double.NaN; }
            return alglib.psi(x);
          }
        case OpCodes.Dawson: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.dawsonintegral(x);
          }
        case OpCodes.ExponentialIntegralEi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.exponentialintegralei(x);
          }
        case OpCodes.SineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.sinecosineintegrals(x, out si, out ci);
            return si;
          }
        case OpCodes.CosineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.sinecosineintegrals(x, out si, out ci);
            return ci;
          }
        case OpCodes.HyperbolicSineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            return shi;
          }
        case OpCodes.HyperbolicCosineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            return chi;
          }
        case OpCodes.FresnelCosineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.fresnelintegral(x, ref c, ref s);
            return c;
          }
        case OpCodes.FresnelSineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.fresnelintegral(x, ref c, ref s);
            return s;
          }
        case OpCodes.AiryA: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.airy(x, out ai, out aip, out bi, out bip);
            return ai;
          }
        case OpCodes.AiryB: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            alglib.airy(x, out ai, out aip, out bi, out bip);
            return bi;
          }
        case OpCodes.Norm: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.normaldistribution(x);
          }
        case OpCodes.Erf: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.errorfunction(x);
          }
        case OpCodes.Bessel: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.besseli0(x);
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, state);
            double result;
            // only the selected branch is evaluated, the other one is skipped
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state);
              state.SkipInstructions();
            } else {
              state.SkipInstructions();
              result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once an operand is false the remaining ones are skipped
              if (result > 0.0) {
                result = Evaluate(dataset, ref row, state);
              } else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.OR: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once an operand is true the remaining ones are skipped
              if (result <= 0.0) {
                result = Evaluate(dataset, ref row, state);
              } else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.XOR: {
            //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
            // this is equal to a consecutive execution of binary XOR operations.
            int positiveSignals = 0;
            for (int i = 0; i < currentInstr.nArguments; i++) {
              if (Evaluate(dataset, ref row, state) > 0.0) { positiveSignals++; }
            }
            return positiveSignals % 2 != 0 ? 1.0 : -1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.TimeLag: {
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            // evaluate the operand at the shifted row, then restore the row
            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            int savedPc = state.ProgramCounter;
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            double sum = 0.0;
            // sum the operand over the |Lag| rows following the current one in lag direction;
            // the program counter is rewound so that the same operand is evaluated each time
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            // back to the original row; this final evaluation also advances past the operand
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        //mkommend: derivative calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one sided smooth differentiator, N = 4
        // y' = 1/(8h) * (f_i + 2 f_(i-1) - 2 f_(i-3) - f_(i-4))
        case OpCodes.Derivative: {
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state);
            row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state);
            row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state);
            row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            row += 4; // restore the original row

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push the argument values on the stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = (ushort)currentInstr.data;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue((ushort)currentInstr.data);
          }
        case OpCodes.Variable: {
            // rows outside of the dataset (e.g. caused by time lags) yield NaN
            if (row < 0 || row >= dataset.Rows) { return double.NaN; }
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) { return double.NaN; }
            return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            if (row < 0 || row >= dataset.Rows) {
              // Both branches still have to be consumed. Otherwise the program counter stays on
              // the true branch and an enclosing node would evaluate it as its own next operand.
              // SkipInstructions is used as in IfThenElse above, where it skips one branch.
              state.SkipInstructions(); // true branch
              state.SkipInstructions(); // false branch
              return double.NaN;
            }
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = ((IList<double>)currentInstr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = Evaluate(dataset, ref row, state);
            double falseBranch = Evaluate(dataset, ref row, state);

            // blend both branches with the logistic weight p
            return trueBranch * p + falseBranch * (1 - p);
          }
        default:
          throw new NotSupportedException();
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.