Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2915-AbsoluteSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 16003

Last change on this file since 16003 was 15944, checked in by gkronber, 7 years ago

#2915 added support for Abs() symbol to tree interpreter and linear interpreter as well as to the infix parser

File size: 23.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Recursive interpreter for symbolic expression trees. The tree is compiled into a linear
  /// instruction array and evaluated row by row; automatically defined functions are handled
  /// via the Call/Arg opcodes and an argument stack kept in <see cref="InterpreterState"/>.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
    ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string EvaluatedSolutionsParameterDescription = "A counter for the total number of solutions the interpreter has evaluated";

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// When true, <see cref="GetSymbolicExpressionTreeValues"/> throws <see cref="NotSupportedException"/>,
    /// because interval arithmetic is not implemented by this interpreter.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>
    /// Number of times <see cref="GetSymbolicExpressionTreeValues"/> has started enumerating a tree
    /// since the last <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }

    protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original,
      Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
      AddDefaultParameters();
    }

    protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
      : base(name, description) {
      AddDefaultParameters();
    }

    // Registers the interval-arithmetic switch (default false) and the evaluation counter (default 0).
    // Shared by the default and the (name, description) constructors so both stay in sync.
    private void AddDefaultParameters() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
    }

    /// <summary>
    /// Re-creates both parameters as <see cref="FixedValueParameter{T}"/> after deserialization.
    /// Whatever was persisted under each name (read through <see cref="IValueParameter{T}"/>) is removed
    /// and replaced by a fixed-value parameter carrying the same value; a missing parameter is added
    /// with its default value.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      var evaluatedSolutions = new IntValue(0);
      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
        evaluatedSolutions = evaluatedSolutionsParameter.Value;
        Parameters.Remove(EvaluatedSolutionsParameterName);
      }
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, evaluatedSolutions));
      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
      }
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluation counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    public void ClearState() { }
    #endregion

    // Guards the read-modify-write of the EvaluatedSolutions counter.
    private readonly object syncRoot = new object();

    /// <summary>
    /// Evaluates <paramref name="tree"/> for each of the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator method: the interval-arithmetic check, the counter increment and the
    /// compilation of the tree all run when enumeration starts, and run again on every new enumeration.
    /// </remarks>
    /// <exception cref="NotSupportedException">
    /// Thrown on enumeration when <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset,
      IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic) {
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      }

      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        // Evaluate takes the row by ref (time-shifting opcodes move it temporarily), so use a local copy.
        int row = rowEnum;
        yield return Evaluate(dataset, ref row, state);
        state.Reset();
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> and binds every dataset-reading instruction to its column.
    /// Double-valued columns are attached to Variable, LagVariable and VariableCondition instructions;
    /// string-valued columns to FactorVariable and BinaryFactorVariable instructions. The argument stack
    /// is sized as the sum of (arguments + 1) over all Call instructions.
    /// </summary>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.FactorVariable) {
          var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
          var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyStringValues(factorTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }

    /// <summary>
    /// Evaluates the subtree that starts at the current program counter of <paramref name="state"/>
    /// for the given row. The program counter must end up just past that subtree, because the caller
    /// reads its next argument from there.
    /// </summary>
    /// <param name="dataset">Dataset whose columns were bound by <see cref="PrepareInterpreterState"/>.</param>
    /// <param name="row">
    /// Row to evaluate. It is passed by ref so that TimeLag, Integral and Derivative can shift it
    /// temporarily; each of them restores it before returning.
    /// </param>
    /// <param name="state">Compiled code, program counter and argument stack.</param>
    /// <returns>
    /// The value of the subtree. Data-reading opcodes yield <see cref="double.NaN"/> for rows outside the dataset.
    /// </returns>
    /// <exception cref="NotSupportedException">The current opcode is not handled by this interpreter.</exception>
    public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            // a unary minus negates its single argument
            if (currentInstr.nArguments == 1) { s = -s; }
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            // a unary division is the reciprocal of its single argument
            if (currentInstr.nArguments == 1) { p = 1.0 / p; }
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Absolute: {
            return Math.Abs(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Square: {
            return Math.Pow(Evaluate(dataset, ref row, state), 2);
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            // the exponent is rounded to an integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, y);
          }
        case OpCodes.SquareRoot: {
            return Math.Sqrt(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            // the root index is rounded to an integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Gamma: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; } else { return alglib.gammafunction(x); }
          }
        case OpCodes.Psi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            // psi has poles at zero and the negative integers
            else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) return double.NaN;
            return alglib.psi(x);
          }
        case OpCodes.Dawson: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.dawsonintegral(x);
          }
        case OpCodes.ExponentialIntegralEi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.exponentialintegralei(x);
          }
        case OpCodes.SineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return si;
            }
          }
        case OpCodes.CosineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return ci;
            }
          }
        case OpCodes.HyperbolicSineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return shi;
            }
          }
        case OpCodes.HyperbolicCosineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return chi;
            }
          }
        case OpCodes.FresnelCosineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return c;
            }
          }
        case OpCodes.FresnelSineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return s;
            }
          }
        case OpCodes.AiryA: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return ai;
            }
          }
        case OpCodes.AiryB: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return bi;
            }
          }
        case OpCodes.Norm: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.normaldistribution(x);
          }
        case OpCodes.Erf: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.errorfunction(x);
          }
        case OpCodes.Bessel: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.besseli0(x);
          }
        case OpCodes.IfThenElse: {
            double condition = Evaluate(dataset, ref row, state);
            double result;
            // exactly one branch is evaluated; the other is skipped to keep the program counter aligned
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state); state.SkipInstructions();
            } else {
              state.SkipInstructions(); result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once an argument is non-positive the remaining ones are skipped
              if (result > 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.OR: {
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              // short-circuit: once an argument is positive the remaining ones are skipped
              if (result <= 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.XOR: {
            //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
            // this is equal to a consecutive execution of binary XOR operations.
            int positiveSignals = 0;
            for (int i = 0; i < currentInstr.nArguments; i++) {
              if (Evaluate(dataset, ref row, state) > 0.0) { positiveSignals++; }
            }
            return positiveSignals % 2 != 0 ? 1.0 : -1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.TimeLag: {
            // evaluate the subtree at row + lag, then restore the row
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            // Sums the subtree over rows row + sign(lag) * k for k = 1..|lag|, plus the current row.
            // The program counter is rewound after each intermediate evaluation; the final evaluation
            // at the current row leaves it past the subtree.
            int savedPc = state.ProgramCounter;
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        //mkommend: derivate calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one sided smooth differentiator, N = 4
        // y' = 1/8h (f_i + 2f_i-1 - 2f_i-3 - f_i-4)
        case OpCodes.Derivative: {
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state); row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            row += 4;

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push on argument values on stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = (ushort)currentInstr.data;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue((ushort)currentInstr.data);
          }
        case OpCodes.Variable: {
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
          }
        case OpCodes.BinaryFactorVariable: {
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            var factorVarTreeNode = (BinaryFactorVariableTreeNode)currentInstr.dynamicNode;
            return ((IList<string>)currentInstr.data)[row] == factorVarTreeNode.VariableValue ? factorVarTreeNode.Weight : 0;
          }
        case OpCodes.FactorVariable: {
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            var factorVarTreeNode = (FactorVariableTreeNode)currentInstr.dynamicNode;
            return factorVarTreeNode.GetValue(((IList<string>)currentInstr.data)[row]);
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) { return double.NaN; }
            return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            if (row < 0 || row >= dataset.Rows) {
              // Both branch subtrees must still be consumed. Otherwise the program counter is left
              // inside this node's children, and the caller would evaluate them as its own arguments.
              state.SkipInstructions();
              state.SkipInstructions();
              return double.NaN;
            }
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
              // smooth blend of both branches, weighted by the logistic function of (value - threshold)
              double variableValue = ((IList<double>)currentInstr.data)[row];
              double x = variableValue - variableConditionTreeNode.Threshold;
              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

              double trueBranch = Evaluate(dataset, ref row, state);
              double falseBranch = Evaluate(dataset, ref row, state);

              return trueBranch * p + falseBranch * (1 - p);
            } else {
              // strict threshold: exactly one branch is evaluated, the other is skipped
              // NOTE(review): with a positive slope the smooth variant above weights the first branch
              // for values above the threshold, whereas this variant selects the first branch for values
              // at or below it; confirm the asymmetry is intended.
              double variableValue = ((IList<double>)currentInstr.data)[row];
              if (variableValue <= variableConditionTreeNode.Threshold) {
                var left = Evaluate(dataset, ref row, state);
                state.SkipInstructions();
                return left;
              } else {
                state.SkipInstructions();
                return Evaluate(dataset, ref row, state);
              }
            }
          }
        default:
          throw new NotSupportedException(string.Format("Opcode {0} is not supported by the symbolic data analysis interpreter.", currentInstr.opCode));
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.