source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9828

Last change on this file since 9828 was 9828, checked in by mkommend, 9 years ago

#2021: Integrated the linear interpreter in the trunk and restructed interpreter unit tests.

File size: 18.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
40
41    public override bool CanChangeName {
42      get { return false; }
43    }
44
45    public override bool CanChangeDescription {
46      get { return false; }
47    }
48
49    #region parameter properties
50    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
51      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
52    }
53
54    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
55      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
56    }
57    #endregion
58
59    #region properties
60    public BoolValue CheckExpressionsWithIntervalArithmetic {
61      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
62      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
63    }
64    public IntValue EvaluatedSolutions {
65      get { return EvaluatedSolutionsParameter.Value; }
66      set { EvaluatedSolutionsParameter.Value = value; }
67    }
68    #endregion
69
70    [StorableConstructor]
71    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
72      : base(deserializing) {
73    }
74
75    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
76      : base(original, cloner) {
77      interpreter = cloner.Clone(original.interpreter);
78    }
79
80    public override IDeepCloneable Clone(Cloner cloner) {
81      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
82    }
83
84    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
85      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
86      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
87      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
88      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
89    }
90
91    [StorableHook(HookType.AfterDeserialization)]
92    private void AfterDeserialization() {
93      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
94    }
95
96    #region IStatefulItem
97    public void InitializeState() {
98      EvaluatedSolutions.Value = 0;
99    }
100
101    public void ClearState() { }
102    #endregion
103
104    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
105      if (CheckExpressionsWithIntervalArithmetic.Value)
106        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
107
108      lock (EvaluatedSolutions) {
109        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
110      }
111
112      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
113      PrepareInstructions(code, dataset);
114      return rows.Select(row => Evaluate(dataset, row, code));
115    }
116
117    private double Evaluate(Dataset dataset, int row, LinearInstruction[] code) {
118      for (int i = code.Length - 1; i >= 0; --i) {
119        if (code[i].skip) continue;
120        #region opcode switch
121        var instr = code[i];
122        switch (instr.opCode) {
123          case OpCodes.Variable: {
124              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
125              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
126              instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
127            }
128            break;
129          case OpCodes.LagVariable: {
130              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
131              int actualRow = row + laggedVariableTreeNode.Lag;
132              if (actualRow < 0 || actualRow >= dataset.Rows)
133                instr.value = double.NaN;
134              else
135                instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
136            }
137            break;
138          case OpCodes.VariableCondition: {
139              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
140              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
141              double variableValue = ((IList<double>)instr.data)[row];
142              double x = variableValue - variableConditionTreeNode.Threshold;
143              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
144
145              double trueBranch = code[instr.childIndex].value;
146              double falseBranch = code[instr.childIndex + 1].value;
147
148              instr.value = trueBranch * p + falseBranch * (1 - p);
149            }
150            break;
151          case OpCodes.Add: {
152              double s = code[instr.childIndex].value;
153              for (int j = 1; j != instr.nArguments; ++j) {
154                s += code[instr.childIndex + j].value;
155              }
156              instr.value = s;
157            }
158            break;
159          case OpCodes.Sub: {
160              double s = code[instr.childIndex].value;
161              for (int j = 1; j != instr.nArguments; ++j) {
162                s -= code[instr.childIndex + j].value;
163              }
164              if (instr.nArguments == 1) s = -s;
165              instr.value = s;
166            }
167            break;
168          case OpCodes.Mul: {
169              double p = code[instr.childIndex].value;
170              for (int j = 1; j != instr.nArguments; ++j) {
171                p *= code[instr.childIndex + j].value;
172              }
173              instr.value = p;
174            }
175            break;
176          case OpCodes.Div: {
177              double p = code[instr.childIndex].value;
178              for (int j = 1; j != instr.nArguments; ++j) {
179                p /= code[instr.childIndex + j].value;
180              }
181              if (instr.nArguments == 1) p = 1.0 / p;
182              instr.value = p;
183            }
184            break;
185          case OpCodes.Average: {
186              double s = code[instr.childIndex].value;
187              for (int j = 1; j != instr.nArguments; ++j) {
188                s += code[instr.childIndex + j].value;
189              }
190              instr.value = s / instr.nArguments;
191            }
192            break;
193          case OpCodes.Cos: {
194              instr.value = Math.Cos(code[instr.childIndex].value);
195            }
196            break;
197          case OpCodes.Sin: {
198              instr.value = Math.Sin(code[instr.childIndex].value);
199            }
200            break;
201          case OpCodes.Tan: {
202              instr.value = Math.Tan(code[instr.childIndex].value);
203            }
204            break;
205          case OpCodes.Square: {
206              instr.value = Math.Pow(code[instr.childIndex].value, 2);
207            }
208            break;
209          case OpCodes.Power: {
210              double x = code[instr.childIndex].value;
211              double y = Math.Round(code[instr.childIndex + 1].value);
212              instr.value = Math.Pow(x, y);
213            }
214            break;
215          case OpCodes.SquareRoot: {
216              instr.value = Math.Sqrt(code[instr.childIndex].value);
217            }
218            break;
219          case OpCodes.Root: {
220              double x = code[instr.childIndex].value;
221              double y = code[instr.childIndex + 1].value;
222              instr.value = Math.Pow(x, 1 / y);
223            }
224            break;
225          case OpCodes.Exp: {
226              instr.value = Math.Exp(code[instr.childIndex].value);
227            }
228            break;
229          case OpCodes.Log: {
230              instr.value = Math.Log(code[instr.childIndex].value);
231            }
232            break;
233          case OpCodes.Gamma: {
234              var x = code[instr.childIndex].value;
235              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
236            }
237            break;
238          case OpCodes.Psi: {
239              var x = code[instr.childIndex].value;
240              if (double.IsNaN(x)) instr.value = double.NaN;
241              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
242              else instr.value = alglib.psi(x);
243            }
244            break;
245          case OpCodes.Dawson: {
246              var x = code[instr.childIndex].value;
247              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
248            }
249            break;
250          case OpCodes.ExponentialIntegralEi: {
251              var x = code[instr.childIndex].value;
252              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
253            }
254            break;
255          case OpCodes.SineIntegral: {
256              double si, ci;
257              var x = code[instr.childIndex].value;
258              if (double.IsNaN(x)) instr.value = double.NaN;
259              else {
260                alglib.sinecosineintegrals(x, out si, out ci);
261                instr.value = si;
262              }
263            }
264            break;
265          case OpCodes.CosineIntegral: {
266              double si, ci;
267              var x = code[instr.childIndex].value;
268              if (double.IsNaN(x)) instr.value = double.NaN;
269              else {
270                alglib.sinecosineintegrals(x, out si, out ci);
271                instr.value = ci;
272              }
273            }
274            break;
275          case OpCodes.HyperbolicSineIntegral: {
276              double shi, chi;
277              var x = code[instr.childIndex].value;
278              if (double.IsNaN(x)) instr.value = double.NaN;
279              else {
280                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
281                instr.value = shi;
282              }
283            }
284            break;
285          case OpCodes.HyperbolicCosineIntegral: {
286              double shi, chi;
287              var x = code[instr.childIndex].value;
288              if (double.IsNaN(x)) instr.value = double.NaN;
289              else {
290                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
291                instr.value = chi;
292              }
293            }
294            break;
295          case OpCodes.FresnelCosineIntegral: {
296              double c = 0, s = 0;
297              var x = code[instr.childIndex].value;
298              if (double.IsNaN(x)) instr.value = double.NaN;
299              else {
300                alglib.fresnelintegral(x, ref c, ref s);
301                instr.value = c;
302              }
303            }
304            break;
305          case OpCodes.FresnelSineIntegral: {
306              double c = 0, s = 0;
307              var x = code[instr.childIndex].value;
308              if (double.IsNaN(x)) instr.value = double.NaN;
309              else {
310                alglib.fresnelintegral(x, ref c, ref s);
311                instr.value = s;
312              }
313            }
314            break;
315          case OpCodes.AiryA: {
316              double ai, aip, bi, bip;
317              var x = code[instr.childIndex].value;
318              if (double.IsNaN(x)) instr.value = double.NaN;
319              else {
320                alglib.airy(x, out ai, out aip, out bi, out bip);
321                instr.value = ai;
322              }
323            }
324            break;
325          case OpCodes.AiryB: {
326              double ai, aip, bi, bip;
327              var x = code[instr.childIndex].value;
328              if (double.IsNaN(x)) instr.value = double.NaN;
329              else {
330                alglib.airy(x, out ai, out aip, out bi, out bip);
331                instr.value = bi;
332              }
333            }
334            break;
335          case OpCodes.Norm: {
336              var x = code[instr.childIndex].value;
337              if (double.IsNaN(x)) instr.value = double.NaN;
338              else instr.value = alglib.normaldistribution(x);
339            }
340            break;
341          case OpCodes.Erf: {
342              var x = code[instr.childIndex].value;
343              if (double.IsNaN(x)) instr.value = double.NaN;
344              else instr.value = alglib.errorfunction(x);
345            }
346            break;
347          case OpCodes.Bessel: {
348              var x = code[instr.childIndex].value;
349              if (double.IsNaN(x)) instr.value = double.NaN;
350              else instr.value = alglib.besseli0(x);
351            }
352            break;
353          case OpCodes.IfThenElse: {
354              double condition = code[instr.childIndex].value;
355              double result;
356              if (condition > 0.0) {
357                result = code[instr.childIndex + 1].value;
358              } else {
359                result = code[instr.childIndex + 2].value;
360              }
361              instr.value = result;
362            }
363            break;
364          case OpCodes.AND: {
365              double result = code[instr.childIndex].value;
366              for (int j = 1; j < instr.nArguments; j++) {
367                if (result > 0.0) result = code[instr.childIndex + j].value;
368                else break;
369              }
370              instr.value = result > 0.0 ? 1.0 : -1.0;
371            }
372            break;
373          case OpCodes.OR: {
374              double result = code[instr.childIndex].value;
375              for (int j = 1; j < instr.nArguments; j++) {
376                if (result <= 0.0) result = code[instr.childIndex + j].value;
377                else break;
378              }
379              instr.value = result > 0.0 ? 1.0 : -1.0;
380            }
381            break;
382          case OpCodes.NOT: {
383              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
384            }
385            break;
386          case OpCodes.GT: {
387              double x = code[instr.childIndex].value;
388              double y = code[instr.childIndex + 1].value;
389              instr.value = x > y ? 1.0 : -1.0;
390            }
391            break;
392          case OpCodes.LT: {
393              double x = code[instr.childIndex].value;
394              double y = code[instr.childIndex + 1].value;
395              instr.value = x < y ? 1.0 : -1.0;
396            }
397            break;
398          case OpCodes.TimeLag:
399          case OpCodes.Integral:
400          case OpCodes.Derivative: {
401              var state = (InterpreterState)instr.data;
402              state.Reset();
403              instr.value = interpreter.Evaluate(dataset, ref row, state);
404            }
405            break;
406          default:
407            var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
408            throw new NotSupportedException(errorText);
409        }
410        #endregion
411      }
412      return code[0].value;
413    }
414
415    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
416      var list = new List<LinearInstruction>();
417      int i = startIndex;
418      while (i != code.Length) {
419        var instr = code[i];
420        list.Add(instr);
421        i = instr.nArguments > 0 ? instr.childIndex : i + 1;
422      }
423      return list.ToArray();
424    }
425
426    private static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
427      for (int i = 0; i != code.Length; ++i) {
428        var instr = code[i];
429        #region opcode switch
430        switch (instr.opCode) {
431          case OpCodes.Constant: {
432              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
433              instr.value = constTreeNode.Value;
434              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
435            }
436            break;
437          case OpCodes.Variable: {
438              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
439              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
440            }
441            break;
442          case OpCodes.LagVariable: {
443              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
444              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
445            }
446            break;
447          case OpCodes.VariableCondition: {
448              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
449              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
450            }
451            break;
452          case OpCodes.TimeLag:
453          case OpCodes.Integral:
454          case OpCodes.Derivative: {
455              var seq = GetPrefixSequence(code, i);
456              var interpreterState = new InterpreterState(seq, 0);
457              instr.data = interpreterState;
458              for (int j = 1; j != seq.Length; ++j)
459                seq[j].skip = true;
460            }
461            break;
462        }
463        #endregion
464      }
465    }
466  }
467}
Note: See TracBrowser for help on using the repository browser.