source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9734

Last change on this file since 9734 was 9734, checked in by bburlacu, 6 years ago

#2021: Updated license year, fixed interpreter name (SymbolicDataAnalysisExpressionTreeLinearInterpreter) and updated description. Replaced tabs with spaces in Instruction.cs.

File size: 18.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  [StorableClass]
33  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).")]
34  public class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
35    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
36    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
37
38    public override bool CanChangeName {
39      get { return false; }
40    }
41
42    public override bool CanChangeDescription {
43      get { return false; }
44    }
45
46    #region parameter properties
47
48    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
49      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
50    }
51
52    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
53      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
54    }
55
56    #endregion
57
58    #region properties
59
60    public BoolValue CheckExpressionsWithIntervalArithmetic {
61      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
62      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
63    }
64
65    public IntValue EvaluatedSolutions {
66      get { return EvaluatedSolutionsParameter.Value; }
67      set { EvaluatedSolutionsParameter.Value = value; }
68    }
69
70    #endregion
71
72    [StorableConstructor]
73    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
74      : base(deserializing) {
75    }
76
77    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(
78      SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
79      : base(original, cloner) {
80    }
81
82    public override IDeepCloneable Clone(Cloner cloner) {
83      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
84    }
85
86    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
87      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
88      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
89      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
90    }
91
92    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
93      : base(name, description) {
94      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
95      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
96    }
97
98    [StorableHook(HookType.AfterDeserialization)]
99    private void AfterDeserialization() {
100      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
101        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
102    }
103
104    #region IStatefulItem
105
106    public void InitializeState() {
107      EvaluatedSolutions.Value = 0;
108    }
109
110    public void ClearState() {
111    }
112
113    #endregion
114
115    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
116      if (CheckExpressionsWithIntervalArithmetic.Value)
117        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
118
119      lock (EvaluatedSolutions) {
120        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
121      }
122
123      var root = tree.Root.GetSubtree(0).GetSubtree(0);
124      var nodes = new List<ISymbolicExpressionTreeNode> { root };
125      var code = new List<Instruction>{ 
126        new Instruction { dynamicNode = root,
127        nArguments = (byte) root.SubtreeCount, 
128        opCode = OpCodes.MapSymbolToOpCode(root)
129        }
130      };
131
132      // iterate breadth-wise over tree nodes and produce an array of instructions
133      int i = 0;
134      while (i != nodes.Count) {
135        if (nodes[i].SubtreeCount > 0) {
136          // save index of the first child in the instructions array
137          code[i].childIndex = code.Count;
138          for (int j = 0; j != nodes[i].SubtreeCount; ++j) {
139            var s = nodes[i].GetSubtree(j);
140            nodes.Add(s);
141            code.Add(new Instruction {
142              dynamicNode = s,
143              nArguments = (byte)s.SubtreeCount,
144              opCode = OpCodes.MapSymbolToOpCode(s)
145            });
146          }
147        }
148        ++i;
149      }
150      // fill in iArg0 value for terminal nodes
151      foreach (var instr in code) {
152        switch (instr.opCode) {
153          case OpCodes.Variable: {
154              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
155              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
156            }
157            break;
158          case OpCodes.LagVariable: {
159              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
160              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
161            }
162            break;
163          case OpCodes.VariableCondition: {
164              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
165              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
166            }
167            break;
168        }
169      }
170
171      var array = code.ToArray();
172
173      foreach (var rowEnum in rows) {
174        int row = rowEnum;
175        EvaluateFast(dataset, ref row, array);
176        yield return code[0].value;
177      }
178    }
179
180    private void EvaluateFast(Dataset dataset, ref int row, Instruction[] code) {
181      for (int i = code.Length - 1; i >= 0; --i) {
182        var instr = code[i];
183
184        switch (instr.opCode) {
185          case OpCodes.Add: {
186              double s = code[instr.childIndex].value;
187              for (int j = 1; j != instr.nArguments; ++j) {
188                s += code[instr.childIndex + j].value;
189              }
190              instr.value = s;
191            }
192            break;
193          case OpCodes.Sub: {
194              double s = code[instr.childIndex].value;
195              for (int j = 1; j != instr.nArguments; ++j) {
196                s -= code[instr.childIndex + j].value;
197              }
198              if (instr.nArguments == 1) s = -s;
199              instr.value = s;
200            }
201            break;
202          case OpCodes.Mul: {
203              double p = code[instr.childIndex].value;
204              for (int j = 1; j != instr.nArguments; ++j) {
205                p *= code[instr.childIndex + j].value;
206              }
207              instr.value = p;
208            }
209            break;
210          case OpCodes.Div: {
211              double p = code[instr.childIndex].value;
212              for (int j = 1; j != instr.nArguments; ++j) {
213                p /= code[instr.childIndex + j].value;
214              }
215              if (instr.nArguments == 1) p = 1.0 / p;
216              instr.value = p;
217            }
218            break;
219          case OpCodes.Average: {
220              double s = code[instr.childIndex].value;
221              for (int j = 1; j != instr.nArguments; ++j) {
222                s += code[instr.childIndex + j].value;
223              }
224              instr.value = s / instr.nArguments;
225            }
226            break;
227          case OpCodes.Cos: {
228              instr.value = Math.Cos(code[instr.childIndex].value);
229            }
230            break;
231          case OpCodes.Sin: {
232              instr.value = Math.Sin(code[instr.childIndex].value);
233            }
234            break;
235          case OpCodes.Tan: {
236              instr.value = Math.Tan(code[instr.childIndex].value);
237            }
238            break;
239          case OpCodes.Root: {
240              double x = code[instr.childIndex].value;
241              double y = code[instr.childIndex + 1].value;
242              instr.value = Math.Pow(x, 1 / y);
243            }
244            break;
245          case OpCodes.Exp: {
246              instr.value = Math.Exp(code[instr.childIndex].value);
247            }
248            break;
249          case OpCodes.Log: {
250              instr.value = Math.Log(code[instr.childIndex].value);
251            }
252            break;
253          case OpCodes.Gamma: {
254              var x = code[instr.childIndex].value;
255              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
256            }
257            break;
258          case OpCodes.Psi: {
259              var x = code[instr.childIndex].value;
260              if (double.IsNaN(x)) instr.value = double.NaN;
261              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
262              else instr.value = alglib.psi(x);
263            }
264            break;
265          case OpCodes.Dawson: {
266              var x = code[instr.childIndex].value;
267              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
268            }
269            break;
270          case OpCodes.ExponentialIntegralEi: {
271              var x = code[instr.childIndex].value;
272              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
273            }
274            break;
275          case OpCodes.SineIntegral: {
276              double si, ci;
277              var x = code[instr.childIndex].value;
278              if (double.IsNaN(x)) instr.value = double.NaN;
279              else {
280                alglib.sinecosineintegrals(x, out si, out ci);
281                instr.value = si;
282              }
283            }
284            break;
285          case OpCodes.CosineIntegral: {
286              double si, ci;
287              var x = code[instr.childIndex].value;
288              if (double.IsNaN(x)) instr.value = double.NaN;
289              else {
290                alglib.sinecosineintegrals(x, out si, out ci);
291                instr.value = si;
292              }
293            }
294            break;
295          case OpCodes.HyperbolicSineIntegral: {
296              double shi, chi;
297              var x = code[instr.childIndex].value;
298              if (double.IsNaN(x)) instr.value = double.NaN;
299              else {
300                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
301                instr.value = shi;
302              }
303            }
304            break;
305          case OpCodes.HyperbolicCosineIntegral: {
306              double shi, chi;
307              var x = code[instr.childIndex].value;
308              if (double.IsNaN(x)) instr.value = double.NaN;
309              else {
310                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
311                instr.value = chi;
312              }
313            }
314            break;
315          case OpCodes.FresnelCosineIntegral: {
316              double c = 0, s = 0;
317              var x = code[instr.childIndex].value;
318              if (double.IsNaN(x)) instr.value = double.NaN;
319              else {
320                alglib.fresnelintegral(x, ref c, ref s);
321                instr.value = c;
322              }
323            }
324            break;
325          case OpCodes.FresnelSineIntegral: {
326              double c = 0, s = 0;
327              var x = code[instr.childIndex].value;
328              if (double.IsNaN(x)) instr.value = double.NaN;
329              else {
330                alglib.fresnelintegral(x, ref c, ref s);
331                instr.value = s;
332              }
333            }
334            break;
335          case OpCodes.AiryA: {
336              double ai, aip, bi, bip;
337              var x = code[instr.childIndex].value;
338              if (double.IsNaN(x)) instr.value = double.NaN;
339              else {
340                alglib.airy(x, out ai, out aip, out bi, out bip);
341                instr.value = ai;
342              }
343            }
344            break;
345          case OpCodes.AiryB: {
346              double ai, aip, bi, bip;
347              var x = code[instr.childIndex].value;
348              if (double.IsNaN(x)) instr.value = double.NaN;
349              else {
350                alglib.airy(x, out ai, out aip, out bi, out bip);
351                instr.value = bi;
352              }
353            }
354            break;
355          case OpCodes.Norm: {
356              var x = code[instr.childIndex].value;
357              if (double.IsNaN(x)) instr.value = double.NaN;
358              else instr.value = alglib.normaldistribution(x);
359            }
360            break;
361          case OpCodes.Erf: {
362              var x = code[instr.childIndex].value;
363              if (double.IsNaN(x)) instr.value = double.NaN;
364              else instr.value = alglib.errorfunction(x);
365            }
366            break;
367          case OpCodes.Bessel: {
368              var x = code[instr.childIndex].value;
369              if (double.IsNaN(x)) instr.value = double.NaN;
370              else instr.value = alglib.besseli0(x);
371            }
372            break;
373          case OpCodes.IfThenElse: {
374              double condition = code[instr.childIndex].value;
375              double result;
376              if (condition > 0.0) {
377                result = code[instr.childIndex + 1].value;
378              } else {
379                result = code[instr.childIndex + 2].value;
380              }
381              instr.value = result;
382            }
383            break;
384          case OpCodes.AND: {
385              double result = code[instr.childIndex].value;
386              for (int j = 1; j < instr.nArguments; j++) {
387                if (result > 0.0) result = code[instr.childIndex + j].value;
388                else break;
389              }
390              instr.value = result > 0.0 ? 1.0 : -1.0;
391            }
392            break;
393          case OpCodes.OR: {
394              double result = code[instr.childIndex].value;
395              for (int j = 1; j < instr.nArguments; j++) {
396                if (result <= 0.0) result = code[instr.childIndex + j].value;
397                else break;
398              }
399              instr.value = result > 0.0 ? 1.0 : -1.0;
400            }
401            break;
402          case OpCodes.NOT: {
403              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
404            }
405            break;
406          case OpCodes.GT: {
407              double x = code[instr.childIndex].value;
408              double y = code[instr.childIndex + 1].value;
409              instr.value = x > y ? 1.0 : -1.0;
410            }
411            break;
412          case OpCodes.LT: {
413              double x = code[instr.childIndex].value;
414              double y = code[instr.childIndex + 1].value;
415              instr.value = x < y ? 1.0 : -1.0;
416            }
417            break;
418          case OpCodes.TimeLag: {
419              throw new NotSupportedException();
420            }
421          case OpCodes.Integral: {
422              throw new NotSupportedException();
423            }
424          case OpCodes.Derivative: {
425              throw new NotSupportedException();
426            }
427          case OpCodes.Arg: {
428              throw new NotSupportedException();
429            }
430          case OpCodes.Variable: {
431              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
432              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
433              instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
434            }
435            break;
436          case OpCodes.LagVariable: {
437              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
438              int actualRow = row + laggedVariableTreeNode.Lag;
439              if (actualRow < 0 || actualRow >= dataset.Rows) instr.value = double.NaN;
440              instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
441            }
442            break;
443          case OpCodes.Constant: {
444              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
445              instr.value = constTreeNode.Value;
446            }
447            break;
448          case OpCodes.VariableCondition: {
449              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
450              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
451              double variableValue = ((IList<double>)instr.iArg0)[row];
452              double x = variableValue - variableConditionTreeNode.Threshold;
453              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
454
455              double trueBranch = code[instr.childIndex].value;
456              double falseBranch = code[instr.childIndex + 1].value;
457
458              instr.value = trueBranch * p + falseBranch * (1 - p);
459            }
460            break;
461          default:
462            throw new NotSupportedException();
463        }
464      }
465    }
466  }
467}
Note: See TracBrowser for help on using the repository browser.