source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9826

Last change on this file since 9826 was 9826, checked in by mkommend, 9 years ago

#2021: Minor code changes in the linear interpreter branch.

File size: 19.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
40
41    public override bool CanChangeName {
42      get { return false; }
43    }
44
45    public override bool CanChangeDescription {
46      get { return false; }
47    }
48
49    #region parameter properties
50    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
51      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
52    }
53
54    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
55      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
56    }
57    #endregion
58
59    #region properties
60    public BoolValue CheckExpressionsWithIntervalArithmetic {
61      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
62      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
63    }
64    public IntValue EvaluatedSolutions {
65      get { return EvaluatedSolutionsParameter.Value; }
66      set { EvaluatedSolutionsParameter.Value = value; }
67    }
68    #endregion
69
70    [StorableConstructor]
71    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
72      : base(deserializing) {
73    }
74
75    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(
76      SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
77      : base(original, cloner) {
78      interpreter = original.interpreter;
79    }
80
81    public override IDeepCloneable Clone(Cloner cloner) {
82      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
83    }
84
85    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
86      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
87      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
88      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
89      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
90    }
91
92    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
93      : base(name, description) {
94      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
95      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
96      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
97    }
98
99    [StorableHook(HookType.AfterDeserialization)]
100    private void AfterDeserialization() {
101      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
102        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
103      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
104    }
105
106    #region IStatefulItem
107
108    public void InitializeState() {
109      EvaluatedSolutions.Value = 0;
110    }
111
112    public void ClearState() {
113    }
114
115    #endregion
116
117    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
118      if (CheckExpressionsWithIntervalArithmetic.Value)
119        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
120
121      lock (EvaluatedSolutions) {
122        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
123      }
124
125      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
126      PrepareInstructions(code, dataset);
127      return rows.Select(row => Evaluate(dataset, row, code));
128    }
129
130    private double Evaluate(Dataset dataset, int row, LinearInstruction[] code) {
131      for (int i = code.Length - 1; i >= 0; --i) {
132        if (code[i].skip) continue;
133        #region opcode switch
134        var instr = code[i];
135        switch (instr.opCode) {
136          case OpCodes.Variable: {
137              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
138              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
139              instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
140            }
141            break;
142          case OpCodes.LagVariable: {
143              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
144              int actualRow = row + laggedVariableTreeNode.Lag;
145              if (actualRow < 0 || actualRow >= dataset.Rows)
146                instr.value = double.NaN;
147              else
148                instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
149            }
150            break;
151          case OpCodes.VariableCondition: {
152              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
153              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
154              double variableValue = ((IList<double>)instr.data)[row];
155              double x = variableValue - variableConditionTreeNode.Threshold;
156              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
157
158              double trueBranch = code[instr.childIndex].value;
159              double falseBranch = code[instr.childIndex + 1].value;
160
161              instr.value = trueBranch * p + falseBranch * (1 - p);
162            }
163            break;
164          case OpCodes.Add: {
165              double s = code[instr.childIndex].value;
166              for (int j = 1; j != instr.nArguments; ++j) {
167                s += code[instr.childIndex + j].value;
168              }
169              instr.value = s;
170            }
171            break;
172          case OpCodes.Sub: {
173              double s = code[instr.childIndex].value;
174              for (int j = 1; j != instr.nArguments; ++j) {
175                s -= code[instr.childIndex + j].value;
176              }
177              if (instr.nArguments == 1) s = -s;
178              instr.value = s;
179            }
180            break;
181          case OpCodes.Mul: {
182              double p = code[instr.childIndex].value;
183              for (int j = 1; j != instr.nArguments; ++j) {
184                p *= code[instr.childIndex + j].value;
185              }
186              instr.value = p;
187            }
188            break;
189          case OpCodes.Div: {
190              double p = code[instr.childIndex].value;
191              for (int j = 1; j != instr.nArguments; ++j) {
192                p /= code[instr.childIndex + j].value;
193              }
194              if (instr.nArguments == 1) p = 1.0 / p;
195              instr.value = p;
196            }
197            break;
198          case OpCodes.Average: {
199              double s = code[instr.childIndex].value;
200              for (int j = 1; j != instr.nArguments; ++j) {
201                s += code[instr.childIndex + j].value;
202              }
203              instr.value = s / instr.nArguments;
204            }
205            break;
206          case OpCodes.Cos: {
207              instr.value = Math.Cos(code[instr.childIndex].value);
208            }
209            break;
210          case OpCodes.Sin: {
211              instr.value = Math.Sin(code[instr.childIndex].value);
212            }
213            break;
214          case OpCodes.Tan: {
215              instr.value = Math.Tan(code[instr.childIndex].value);
216            }
217            break;
218          case OpCodes.Square: {
219              instr.value = Math.Pow(code[instr.childIndex].value, 2);
220            }
221            break;
222          case OpCodes.Power: {
223              double x = code[instr.childIndex].value;
224              double y = Math.Round(code[instr.childIndex + 1].value);
225              instr.value = Math.Pow(x, y);
226            }
227            break;
228          case OpCodes.SquareRoot: {
229              instr.value = Math.Sqrt(code[instr.childIndex].value);
230            }
231            break;
232          case OpCodes.Root: {
233              double x = code[instr.childIndex].value;
234              double y = code[instr.childIndex + 1].value;
235              instr.value = Math.Pow(x, 1 / y);
236            }
237            break;
238          case OpCodes.Exp: {
239              instr.value = Math.Exp(code[instr.childIndex].value);
240            }
241            break;
242          case OpCodes.Log: {
243              instr.value = Math.Log(code[instr.childIndex].value);
244            }
245            break;
246          case OpCodes.Gamma: {
247              var x = code[instr.childIndex].value;
248              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
249            }
250            break;
251          case OpCodes.Psi: {
252              var x = code[instr.childIndex].value;
253              if (double.IsNaN(x)) instr.value = double.NaN;
254              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
255              else instr.value = alglib.psi(x);
256            }
257            break;
258          case OpCodes.Dawson: {
259              var x = code[instr.childIndex].value;
260              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
261            }
262            break;
263          case OpCodes.ExponentialIntegralEi: {
264              var x = code[instr.childIndex].value;
265              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
266            }
267            break;
268          case OpCodes.SineIntegral: {
269              double si, ci;
270              var x = code[instr.childIndex].value;
271              if (double.IsNaN(x)) instr.value = double.NaN;
272              else {
273                alglib.sinecosineintegrals(x, out si, out ci);
274                instr.value = si;
275              }
276            }
277            break;
278          case OpCodes.CosineIntegral: {
279              double si, ci;
280              var x = code[instr.childIndex].value;
281              if (double.IsNaN(x)) instr.value = double.NaN;
282              else {
283                alglib.sinecosineintegrals(x, out si, out ci);
284                instr.value = ci;
285              }
286            }
287            break;
288          case OpCodes.HyperbolicSineIntegral: {
289              double shi, chi;
290              var x = code[instr.childIndex].value;
291              if (double.IsNaN(x)) instr.value = double.NaN;
292              else {
293                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
294                instr.value = shi;
295              }
296            }
297            break;
298          case OpCodes.HyperbolicCosineIntegral: {
299              double shi, chi;
300              var x = code[instr.childIndex].value;
301              if (double.IsNaN(x)) instr.value = double.NaN;
302              else {
303                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
304                instr.value = chi;
305              }
306            }
307            break;
308          case OpCodes.FresnelCosineIntegral: {
309              double c = 0, s = 0;
310              var x = code[instr.childIndex].value;
311              if (double.IsNaN(x)) instr.value = double.NaN;
312              else {
313                alglib.fresnelintegral(x, ref c, ref s);
314                instr.value = c;
315              }
316            }
317            break;
318          case OpCodes.FresnelSineIntegral: {
319              double c = 0, s = 0;
320              var x = code[instr.childIndex].value;
321              if (double.IsNaN(x)) instr.value = double.NaN;
322              else {
323                alglib.fresnelintegral(x, ref c, ref s);
324                instr.value = s;
325              }
326            }
327            break;
328          case OpCodes.AiryA: {
329              double ai, aip, bi, bip;
330              var x = code[instr.childIndex].value;
331              if (double.IsNaN(x)) instr.value = double.NaN;
332              else {
333                alglib.airy(x, out ai, out aip, out bi, out bip);
334                instr.value = ai;
335              }
336            }
337            break;
338          case OpCodes.AiryB: {
339              double ai, aip, bi, bip;
340              var x = code[instr.childIndex].value;
341              if (double.IsNaN(x)) instr.value = double.NaN;
342              else {
343                alglib.airy(x, out ai, out aip, out bi, out bip);
344                instr.value = bi;
345              }
346            }
347            break;
348          case OpCodes.Norm: {
349              var x = code[instr.childIndex].value;
350              if (double.IsNaN(x)) instr.value = double.NaN;
351              else instr.value = alglib.normaldistribution(x);
352            }
353            break;
354          case OpCodes.Erf: {
355              var x = code[instr.childIndex].value;
356              if (double.IsNaN(x)) instr.value = double.NaN;
357              else instr.value = alglib.errorfunction(x);
358            }
359            break;
360          case OpCodes.Bessel: {
361              var x = code[instr.childIndex].value;
362              if (double.IsNaN(x)) instr.value = double.NaN;
363              else instr.value = alglib.besseli0(x);
364            }
365            break;
366          case OpCodes.IfThenElse: {
367              double condition = code[instr.childIndex].value;
368              double result;
369              if (condition > 0.0) {
370                result = code[instr.childIndex + 1].value;
371              } else {
372                result = code[instr.childIndex + 2].value;
373              }
374              instr.value = result;
375            }
376            break;
377          case OpCodes.AND: {
378              double result = code[instr.childIndex].value;
379              for (int j = 1; j < instr.nArguments; j++) {
380                if (result > 0.0) result = code[instr.childIndex + j].value;
381                else break;
382              }
383              instr.value = result > 0.0 ? 1.0 : -1.0;
384            }
385            break;
386          case OpCodes.OR: {
387              double result = code[instr.childIndex].value;
388              for (int j = 1; j < instr.nArguments; j++) {
389                if (result <= 0.0) result = code[instr.childIndex + j].value;
390                else break;
391              }
392              instr.value = result > 0.0 ? 1.0 : -1.0;
393            }
394            break;
395          case OpCodes.NOT: {
396              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
397            }
398            break;
399          case OpCodes.GT: {
400              double x = code[instr.childIndex].value;
401              double y = code[instr.childIndex + 1].value;
402              instr.value = x > y ? 1.0 : -1.0;
403            }
404            break;
405          case OpCodes.LT: {
406              double x = code[instr.childIndex].value;
407              double y = code[instr.childIndex + 1].value;
408              instr.value = x < y ? 1.0 : -1.0;
409            }
410            break;
411          case OpCodes.TimeLag:
412          case OpCodes.Integral:
413          case OpCodes.Derivative: {
414              var state = (InterpreterState)instr.data;
415              state.Reset();
416              instr.value = interpreter.Evaluate(dataset, ref row, state);
417            }
418            break;
419          default:
420            var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
421            throw new NotSupportedException(errorText);
422        }
423        #endregion
424      }
425      return code[0].value;
426    }
427
428    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
429      var list = new List<LinearInstruction>();
430      int i = startIndex;
431      while (i != code.Length) {
432        var instr = code[i];
433        list.Add(instr);
434        i = instr.nArguments > 0 ? instr.childIndex : i + 1;
435      }
436      return list.ToArray();
437    }
438
439    private static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
440      for (int i = 0; i != code.Length; ++i) {
441        var instr = code[i];
442        #region opcode switch
443        switch (instr.opCode) {
444          case OpCodes.Constant: {
445              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
446              instr.value = constTreeNode.Value;
447              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
448            }
449            break;
450          case OpCodes.Variable: {
451              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
452              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
453            }
454            break;
455          case OpCodes.LagVariable: {
456              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
457              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
458            }
459            break;
460          case OpCodes.VariableCondition: {
461              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
462              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
463            }
464            break;
465          case OpCodes.TimeLag:
466          case OpCodes.Integral:
467          case OpCodes.Derivative: {
468              var seq = GetPrefixSequence(code, i);
469              var interpreterState = new InterpreterState(seq, 0);
470              instr.data = interpreterState;
471              for (int j = 1; j != seq.Length; ++j)
472                seq[j].skip = true;
473            }
474            break;
475        }
476        #endregion
477      }
478    }
479  }
480}
Note: See TracBrowser for help on using the repository browser.