source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9732

Last change on this file since 9732 was 9732, checked in by bburlacu, 6 years ago

#2021: Merged trunk changes for HeuristicLab.Encodings.SymbolicExpressionTreeEncoding and HeuristicLab.Problems.DataAnalysis.Symbolic. Replaced prefix iteration of nodes in the linear interpretation with breadth iteration for simplified logic and extra performance. Reversed unnecessary changes to other projects.

File size: 19.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2012 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  [StorableClass]
33  [Item("SymbolicDataAnalysisExpressionTreeFastInterpreter",
34    "Fast interpreter for symbolic expression trees including automatically defined functions.")]
35  public class SymbolicDataAnalysisExpressionTreeFastInterpreter : ParameterizedNamedItem,
36                                                                   ISymbolicDataAnalysisExpressionTreeInterpreter {
37    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
38    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
39
40    public override bool CanChangeName {
41      get { return false; }
42    }
43
44    public override bool CanChangeDescription {
45      get { return false; }
46    }
47
48    #region parameter properties
49
50    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
51      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
52    }
53
54    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
55      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
56    }
57
58    #endregion
59
60    #region properties
61
62    public BoolValue CheckExpressionsWithIntervalArithmetic {
63      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
64      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
65    }
66
67    public IntValue EvaluatedSolutions {
68      get { return EvaluatedSolutionsParameter.Value; }
69      set { EvaluatedSolutionsParameter.Value = value; }
70    }
71
72    #endregion
73
74    [StorableConstructor]
75    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(bool deserializing)
76      : base(deserializing) {
77    }
78
79    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(
80      SymbolicDataAnalysisExpressionTreeFastInterpreter original, Cloner cloner)
81      : base(original, cloner) {
82    }
83
84    public override IDeepCloneable Clone(Cloner cloner) {
85      return new SymbolicDataAnalysisExpressionTreeFastInterpreter(this, cloner);
86    }
87
88    public SymbolicDataAnalysisExpressionTreeFastInterpreter()
89      : base(
90        "SymbolicDataAnalysisExpressionTreeFastInterpreter",
91        "Interpreter for symbolic expression trees including automatically defined functions.") {
92      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName,
93                                                   "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.",
94                                                   new BoolValue(false)));
95      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName,
96                                                  "A counter for the total number of solutions the interpreter has evaluated",
97                                                  new IntValue(0)));
98    }
99
100    protected SymbolicDataAnalysisExpressionTreeFastInterpreter(string name, string description)
101      : base(name, description) {
102      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName,
103                                                   "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.",
104                                                   new BoolValue(false)));
105      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName,
106                                                  "A counter for the total number of solutions the interpreter has evaluated",
107                                                  new IntValue(0)));
108    }
109
110    [StorableHook(HookType.AfterDeserialization)]
111    private void AfterDeserialization() {
112      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
113        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName,
114                                                    "A counter for the total number of solutions the interpreter has evaluated",
115                                                    new IntValue(0)));
116    }
117
118    #region IStatefulItem
119
120    public void InitializeState() {
121      EvaluatedSolutions.Value = 0;
122    }
123
124    public void ClearState() {
125    }
126
127    #endregion
128
129    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset,
130                                                               IEnumerable<int> rows) {
131      if (CheckExpressionsWithIntervalArithmetic.Value)
132        throw new NotSupportedException(
133          "Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
134
135      lock (EvaluatedSolutions) {
136        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
137      }
138
139      var root = tree.Root.GetSubtree(0).GetSubtree(0);
140      var nodes = new List<ISymbolicExpressionTreeNode> { root };
141      var code = new List<Instruction>{ 
142        new Instruction { dynamicNode = root,
143        nArguments = (byte) root.SubtreeCount, 
144        opCode = OpCodes.MapSymbolToOpCode(root)
145        }
146      };
147
148      // iterate breadth-wise over tree nodes and produce an array of instructions
149      int i = 0;
150      while (i != nodes.Count) {
151        if (nodes[i].SubtreeCount > 0) {
152          // save index of the first child in the instructions array
153          code[i].childIndex = code.Count;
154          for (int j = 0; j != nodes[i].SubtreeCount; ++j) {
155            var s = nodes[i].GetSubtree(j);
156            nodes.Add(s);
157            code.Add(new Instruction {
158              dynamicNode = s,
159              nArguments = (byte)s.SubtreeCount,
160              opCode = OpCodes.MapSymbolToOpCode(s)
161            });
162          }
163        }
164        ++i;
165      }
166      // fill in iArg0 value for terminal nodes
167      foreach (var instr in code) {
168        switch (instr.opCode) {
169          case OpCodes.Variable: {
170              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
171              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
172            }
173            break;
174          case OpCodes.LagVariable: {
175              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
176              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
177            }
178            break;
179          case OpCodes.VariableCondition: {
180              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
181              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
182            }
183            break;
184        }
185      }
186
187      var array = code.ToArray();
188
189      foreach (var rowEnum in rows) {
190        int row = rowEnum;
191        EvaluateFast(dataset, ref row, array);
192        yield return code[0].value;
193      }
194    }
195
196    private void EvaluateFast(Dataset dataset, ref int row, Instruction[] code) {
197      for (int i = code.Length - 1; i >= 0; --i) {
198        var instr = code[i];
199
200        switch (instr.opCode) {
201          case OpCodes.Add: {
202              double s = code[instr.childIndex].value;
203              for (int j = 1; j != instr.nArguments; ++j) {
204                s += code[instr.childIndex + j].value;
205              }
206              instr.value = s;
207            }
208            break;
209          case OpCodes.Sub: {
210              double s = code[instr.childIndex].value;
211              for (int j = 1; j != instr.nArguments; ++j) {
212                s -= code[instr.childIndex + j].value;
213              }
214              if (instr.nArguments == 1) s = -s;
215              instr.value = s;
216            }
217            break;
218          case OpCodes.Mul: {
219              double p = code[instr.childIndex].value;
220              for (int j = 1; j != instr.nArguments; ++j) {
221                p *= code[instr.childIndex + j].value;
222              }
223              instr.value = p;
224            }
225            break;
226          case OpCodes.Div: {
227              double p = code[instr.childIndex].value;
228              for (int j = 1; j != instr.nArguments; ++j) {
229                p /= code[instr.childIndex + j].value;
230              }
231              if (instr.nArguments == 1) p = 1.0 / p;
232              instr.value = p;
233            }
234            break;
235          case OpCodes.Average: {
236              double s = code[instr.childIndex].value;
237              for (int j = 1; j != instr.nArguments; ++j) {
238                s += code[instr.childIndex + j].value;
239              }
240              instr.value = s / instr.nArguments;
241            }
242            break;
243          case OpCodes.Cos: {
244              instr.value = Math.Cos(code[instr.childIndex].value);
245            }
246            break;
247          case OpCodes.Sin: {
248              instr.value = Math.Sin(code[instr.childIndex].value);
249            }
250            break;
251          case OpCodes.Tan: {
252              instr.value = Math.Tan(code[instr.childIndex].value);
253            }
254            break;
255          case OpCodes.Root: {
256              double x = code[instr.childIndex].value;
257              double y = code[instr.childIndex + 1].value;
258              instr.value = Math.Pow(x, 1 / y);
259            }
260            break;
261          case OpCodes.Exp: {
262              instr.value = Math.Exp(code[instr.childIndex].value);
263            }
264            break;
265          case OpCodes.Log: {
266              instr.value = Math.Log(code[instr.childIndex].value);
267            }
268            break;
269          case OpCodes.Gamma: {
270              var x = code[instr.childIndex].value;
271              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
272            }
273            break;
274          case OpCodes.Psi: {
275              var x = code[instr.childIndex].value;
276              if (double.IsNaN(x)) instr.value = double.NaN;
277              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
278              else instr.value = alglib.psi(x);
279            }
280            break;
281          case OpCodes.Dawson: {
282              var x = code[instr.childIndex].value;
283              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
284            }
285            break;
286          case OpCodes.ExponentialIntegralEi: {
287              var x = code[instr.childIndex].value;
288              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
289            }
290            break;
291          case OpCodes.SineIntegral: {
292              double si, ci;
293              var x = code[instr.childIndex].value;
294              if (double.IsNaN(x)) instr.value = double.NaN;
295              else {
296                alglib.sinecosineintegrals(x, out si, out ci);
297                instr.value = si;
298              }
299            }
300            break;
301          case OpCodes.CosineIntegral: {
302              double si, ci;
303              var x = code[instr.childIndex].value;
304              if (double.IsNaN(x)) instr.value = double.NaN;
305              else {
306                alglib.sinecosineintegrals(x, out si, out ci);
307                instr.value = si;
308              }
309            }
310            break;
311          case OpCodes.HyperbolicSineIntegral: {
312              double shi, chi;
313              var x = code[instr.childIndex].value;
314              if (double.IsNaN(x)) instr.value = double.NaN;
315              else {
316                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
317                instr.value = shi;
318              }
319            }
320            break;
321          case OpCodes.HyperbolicCosineIntegral: {
322              double shi, chi;
323              var x = code[instr.childIndex].value;
324              if (double.IsNaN(x)) instr.value = double.NaN;
325              else {
326                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
327                instr.value = chi;
328              }
329            }
330            break;
331          case OpCodes.FresnelCosineIntegral: {
332              double c = 0, s = 0;
333              var x = code[instr.childIndex].value;
334              if (double.IsNaN(x)) instr.value = double.NaN;
335              else {
336                alglib.fresnelintegral(x, ref c, ref s);
337                instr.value = c;
338              }
339            }
340            break;
341          case OpCodes.FresnelSineIntegral: {
342              double c = 0, s = 0;
343              var x = code[instr.childIndex].value;
344              if (double.IsNaN(x)) instr.value = double.NaN;
345              else {
346                alglib.fresnelintegral(x, ref c, ref s);
347                instr.value = s;
348              }
349            }
350            break;
351          case OpCodes.AiryA: {
352              double ai, aip, bi, bip;
353              var x = code[instr.childIndex].value;
354              if (double.IsNaN(x)) instr.value = double.NaN;
355              else {
356                alglib.airy(x, out ai, out aip, out bi, out bip);
357                instr.value = ai;
358              }
359            }
360            break;
361          case OpCodes.AiryB: {
362              double ai, aip, bi, bip;
363              var x = code[instr.childIndex].value;
364              if (double.IsNaN(x)) instr.value = double.NaN;
365              else {
366                alglib.airy(x, out ai, out aip, out bi, out bip);
367                instr.value = bi;
368              }
369            }
370            break;
371          case OpCodes.Norm: {
372              var x = code[instr.childIndex].value;
373              if (double.IsNaN(x)) instr.value = double.NaN;
374              else instr.value = alglib.normaldistribution(x);
375            }
376            break;
377          case OpCodes.Erf: {
378              var x = code[instr.childIndex].value;
379              if (double.IsNaN(x)) instr.value = double.NaN;
380              else instr.value = alglib.errorfunction(x);
381            }
382            break;
383          case OpCodes.Bessel: {
384              var x = code[instr.childIndex].value;
385              if (double.IsNaN(x)) instr.value = double.NaN;
386              else instr.value = alglib.besseli0(x);
387            }
388            break;
389          case OpCodes.IfThenElse: {
390              double condition = code[instr.childIndex].value;
391              double result;
392              if (condition > 0.0) {
393                result = code[instr.childIndex + 1].value;
394              } else {
395                result = code[instr.childIndex + 2].value;
396              }
397              instr.value = result;
398            }
399            break;
400          case OpCodes.AND: {
401              double result = code[instr.childIndex].value;
402              for (int j = 1; j < instr.nArguments; j++) {
403                if (result > 0.0) result = code[instr.childIndex + j].value;
404                else break;
405              }
406              instr.value = result > 0.0 ? 1.0 : -1.0;
407            }
408            break;
409          case OpCodes.OR: {
410              double result = code[instr.childIndex].value;
411              for (int j = 1; j < instr.nArguments; j++) {
412                if (result <= 0.0) result = code[instr.childIndex + j].value;
413                else break;
414              }
415              instr.value = result > 0.0 ? 1.0 : -1.0;
416            }
417            break;
418          case OpCodes.NOT: {
419              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
420            }
421            break;
422          case OpCodes.GT: {
423              double x = code[instr.childIndex].value;
424              double y = code[instr.childIndex + 1].value;
425              instr.value = x > y ? 1.0 : -1.0;
426            }
427            break;
428          case OpCodes.LT: {
429              double x = code[instr.childIndex].value;
430              double y = code[instr.childIndex + 1].value;
431              instr.value = x < y ? 1.0 : -1.0;
432            }
433            break;
434          case OpCodes.TimeLag: {
435              throw new NotSupportedException();
436            }
437          case OpCodes.Integral: {
438              throw new NotSupportedException();
439            }
440          case OpCodes.Derivative: {
441              throw new NotSupportedException();
442            }
443          case OpCodes.Arg: {
444              throw new NotSupportedException();
445            }
446          case OpCodes.Variable: {
447              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
448              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
449              instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
450            }
451            break;
452          case OpCodes.LagVariable: {
453              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
454              int actualRow = row + laggedVariableTreeNode.Lag;
455              if (actualRow < 0 || actualRow >= dataset.Rows) instr.value = double.NaN;
456              instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
457            }
458            break;
459          case OpCodes.Constant: {
460              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
461              instr.value = constTreeNode.Value;
462            }
463            break;
464          case OpCodes.VariableCondition: {
465              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
466              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
467              double variableValue = ((IList<double>)instr.iArg0)[row];
468              double x = variableValue - variableConditionTreeNode.Threshold;
469              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
470
471              double trueBranch = code[instr.childIndex].value;
472              double falseBranch = code[instr.childIndex + 1].value;
473
474              instr.value = trueBranch * p + falseBranch * (1 - p);
475            }
476            break;
477          default:
478            throw new NotSupportedException();
479        }
480      }
481    }
482  }
483}
Note: See TracBrowser for help on using the repository browser.