source: branches/HeuristicLab.LinqExpressionTreeInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 13141

Last change on this file since 13141 was 13141, checked in by bburlacu, 6 years ago

#2442: Merged files from trunk and updated project file. Implemented missing operations in the CompiledTreeInterpreter: Integral, Derivative, Lag, TimeLag. Adapted lambda signature to accept an array of List<double> in order to make it easier to work with compiled trees. Changed value parameters to fixed value parameters and adjusted interpreter constructors and after-deserialization hooks. Removed function symbol.

From the performance point of view, compiling the tree into a lambda that accepts a double[][] parameter (an array of arrays holding the values of each double variable) accessed with Expression.ArrayIndex is the fastest option, but it can be cumbersome to provide the data as a double[][]. Therefore the List<double>[] variant was chosen. Internally, the List's underlying double array is used for each variable node, resulting in a decent overall speed compromise.

File size: 20.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interprets symbolic expression trees (including automatically defined functions) row by row on a dataset.
  /// The tree is compiled into a linear instruction array once per call of
  /// <see cref="GetSymbolicExpressionTreeValues"/>, and that array is then evaluated recursively by
  /// <see cref="Evaluate"/> for every requested row.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription =
      "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string EvaluatedSolutionsParameterDescription = "A counter for the total number of solutions the interpreter has evaluated";

    public override bool CanChangeName { get { return false; } }
    public override bool CanChangeDescription { get { return false; } }

    #region parameter properties
    /// <summary>Fixed-value parameter holding the interval-arithmetic switch.</summary>
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Fixed-value parameter holding the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// When true, <see cref="GetSymbolicExpressionTreeValues"/> throws <see cref="NotSupportedException"/>,
    /// because interval arithmetic is not implemented by this interpreter.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>Number of trees evaluated since the last <see cref="InitializeState"/>.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
    protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    }

    /// <summary>
    /// Creates the interpreter with a custom name and description (for derived interpreters) and registers
    /// the interval-arithmetic switch (default false) and the evaluated-solutions counter (default 0).
    /// </summary>
    protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName,
        CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Older versions stored these as plain value parameters. Convert them to fixed-value parameters,
      // but keep the persisted values instead of silently resetting them to the defaults.
      ConvertToFixedValueParameter(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0));
      ConvertToFixedValueParameter(CheckExpressionsWithIntervalArithmeticParameterName,
        CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false));
    }

    /// <summary>
    /// Ensures that the parameter <paramref name="name"/> is a fixed-value parameter of type <typeparamref name="T"/>.
    /// An existing fixed-value parameter is left untouched. An existing value parameter is replaced, carrying over
    /// its value. If the parameter is missing, or holds no value, <paramref name="defaultValue"/> is used.
    /// </summary>
    private void ConvertToFixedValueParameter<T>(string name, string description, T defaultValue) where T : class, IItem, new() {
      T value = defaultValue;
      if (Parameters.ContainsKey(name)) {
        IParameter existing = Parameters[name];
        if (existing is IFixedValueParameter<T>) return; // already in the current format
        var valueParameter = existing as IValueParameter<T>;
        if (valueParameter != null && valueParameter.Value != null) value = valueParameter.Value;
        Parameters.Remove(name);
      }
      Parameters.Add(new FixedValueParameter<T>(name, description, value));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>No state needs to be cleared.</summary>
    public void ClearState() {
    }
    #endregion

    /// <summary>
    /// Lazily evaluates <paramref name="tree"/> for each row in <paramref name="rows"/>.
    /// Because this is an iterator, the interval-arithmetic check, the counter increment and the tree
    /// compilation happen when enumeration starts, and again on every re-enumeration.
    /// </summary>
    /// <exception cref="NotSupportedException">If <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (EvaluatedSolutionsParameter.Value) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        int row = rowEnum; // Evaluate takes the row by ref, so use a local copy
        yield return Evaluate(dataset, ref row, state);
        state.Reset();
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> to instructions. Each instruction that reads a variable gets the
    /// variable's column from <paramref name="dataset"/> attached as <c>data</c>. The argument-stack size is
    /// computed from the call instructions.
    /// </summary>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        switch (instr.opCode) {
          case OpCodes.Variable:
            instr.data = dataset.GetReadOnlyDoubleValues(((VariableTreeNode)instr.dynamicNode).VariableName);
            break;
          case OpCodes.LagVariable:
            instr.data = dataset.GetReadOnlyDoubleValues(((LaggedVariableTreeNode)instr.dynamicNode).VariableName);
            break;
          case OpCodes.VariableCondition:
            instr.data = dataset.GetReadOnlyDoubleValues(((VariableConditionTreeNode)instr.dynamicNode).VariableName);
            break;
          case OpCodes.Call:
            necessaryArgStackSize += instr.nArguments + 1;
            break;
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }

    /// <summary>
    /// Evaluates the subtree that starts at the current program counter of <paramref name="state"/> for
    /// <paramref name="row"/>, and leaves the program counter just past that subtree.
    /// Time-shifting operators (TimeLag, Integral, Derivative) temporarily modify <paramref name="row"/> and
    /// restore it before returning. Variable reads outside <c>[0, dataset.Rows)</c> yield NaN.
    /// </summary>
    /// <exception cref="NotSupportedException">If an opcode is encountered that this interpreter does not implement.</exception>
    public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add:
          {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub:
          {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            if (currentInstr.nArguments == 1) s = -s; // unary minus
            return s;
          }
        case OpCodes.Mul:
          {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div:
          {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            if (currentInstr.nArguments == 1) p = 1.0 / p; // unary division is the reciprocal
            return p;
          }
        case OpCodes.Average:
          {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos:
          {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin:
          {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan:
          {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Square:
          {
            return Math.Pow(Evaluate(dataset, ref row, state), 2);
          }
        case OpCodes.Power:
          {
            double x = Evaluate(dataset, ref row, state);
            double y = Math.Round(Evaluate(dataset, ref row, state)); // exponent is rounded to an integer
            return Math.Pow(x, y);
          }
        case OpCodes.SquareRoot:
          {
            return Math.Sqrt(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Root:
          {
            double x = Evaluate(dataset, ref row, state);
            double y = Math.Round(Evaluate(dataset, ref row, state)); // root degree is rounded to an integer
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp:
          {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log:
          {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Gamma:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.gammafunction(x);
          }
        case OpCodes.Psi:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) return double.NaN; // poles at non-positive integers
            return alglib.psi(x);
          }
        case OpCodes.Dawson:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            return alglib.dawsonintegral(x);
          }
        case OpCodes.ExponentialIntegralEi:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            return alglib.exponentialintegralei(x);
          }
        case OpCodes.SineIntegral:
          {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return si;
            }
          }
        case OpCodes.CosineIntegral:
          {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return ci;
            }
          }
        case OpCodes.HyperbolicSineIntegral:
          {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return shi;
            }
          }
        case OpCodes.HyperbolicCosineIntegral:
          {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return chi;
            }
          }
        case OpCodes.FresnelCosineIntegral:
          {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return c;
            }
          }
        case OpCodes.FresnelSineIntegral:
          {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return s;
            }
          }
        case OpCodes.AiryA:
          {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return ai;
            }
          }
        case OpCodes.AiryB:
          {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return bi;
            }
          }
        case OpCodes.Norm:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.normaldistribution(x);
          }
        case OpCodes.Erf:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.errorfunction(x);
          }
        case OpCodes.Bessel:
          {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.besseli0(x);
          }
        case OpCodes.IfThenElse:
          {
            // Only the selected branch is evaluated; the other subtree is skipped to keep the pc aligned.
            double condition = Evaluate(dataset, ref row, state);
            double result;
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state); state.SkipInstructions();
            } else {
              state.SkipInstructions(); result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND:
          {
            // Short-circuit: once an operand is non-positive, the remaining operands are skipped.
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result > 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.OR:
          {
            // Short-circuit: once an operand is positive, the remaining operands are skipped.
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result <= 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT:
          {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.XOR:
          {
            //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
            // this is equal to a consecutive execution of binary XOR operations.
            int positiveSignals = 0;
            for (int i = 0; i < currentInstr.nArguments; i++) {
              if (Evaluate(dataset, ref row, state) > 0.0) positiveSignals++;
            }
            return positiveSignals % 2 != 0 ? 1.0 : -1.0;
          }
        case OpCodes.GT:
          {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) return 1.0;
            else return -1.0;
          }
        case OpCodes.LT:
          {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) return 1.0;
            else return -1.0;
          }
        case OpCodes.TimeLag:
          {
            // Evaluate the subtree at row + lag, then restore the caller's row.
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral:
          {
            // Sums the subtree over the current row and the |lag| rows stepping towards row + lag.
            // The pc is reset after each pass so the same subtree is evaluated repeatedly.
            int savedPc = state.ProgramCounter;
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        //mkommend: derivate calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one sided smooth differentiatior, N = 4
        // y' = 1/(8h) * (f_i + 2 f_{i-1} - 2 f_{i-3} - f_{i-4})
        case OpCodes.Derivative:
          {
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state); row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            row += 4; // restore the caller's row

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call:
          {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push on argument values on stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = (ushort)currentInstr.data;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg:
          {
            return state.GetStackFrameValue((ushort)currentInstr.data);
          }
        case OpCodes.Variable:
          {
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable:
          {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) return double.NaN;
            return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant:
          {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition:
          {
            if (row < 0 || row >= dataset.Rows) {
              // Both branch subtrees must still be consumed. Otherwise the pc stays on the true branch
              // and the caller would evaluate it as its next operand.
              state.SkipInstructions(); // true branch
              state.SkipInstructions(); // false branch
              return double.NaN;
            }
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = ((IList<double>)currentInstr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = Evaluate(dataset, ref row, state);
            double falseBranch = Evaluate(dataset, ref row, state);

            return trueBranch * p + falseBranch * (1 - p);
          }
        default:
          throw new NotSupportedException("Opcode " + currentInstr.opCode + " is not supported by the symbolic data analysis interpreter.");
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.