source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 13248

Last change on this file since 13248 was 13248, checked in by mkommend, 7 years ago

#2442: Reintegrated branch for compiled symbolic expression tree interpreter.

File size: 19.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
// Key and description of the switch parameter looked up by CheckExpressionsWithIntervalArithmeticParameter.
private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";

// Key of the counter parameter looked up by EvaluatedSolutionsParameter.
private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

// Helper interpreter that Evaluate delegates to for TimeLag, Derivative and Integral instructions.
private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
41
/// <summary>
/// Always <c>false</c>: this interpreter reports that its name cannot be changed.
/// </summary>
public override bool CanChangeName {
  get {
    const bool nameIsEditable = false;
    return nameIsEditable;
  }
}
45
/// <summary>
/// Always <c>false</c>: this interpreter reports that its description cannot be changed.
/// </summary>
public override bool CanChangeDescription {
  get {
    const bool descriptionIsEditable = false;
    return descriptionIsEditable;
  }
}
49
50    #region parameter properties
/// <summary>
/// The parameter stored under the "CheckExpressionsWithIntervalArithmetic" key of the parameter collection.
/// </summary>
public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
  get {
    var parameter = Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}
54
/// <summary>
/// The parameter stored under the "EvaluatedSolutions" key of the parameter collection.
/// </summary>
public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
  get {
    var parameter = Parameters[EvaluatedSolutionsParameterName];
    return (IFixedValueParameter<IntValue>)parameter;
  }
}
58    #endregion
59
60    #region properties
/// <summary>
/// Convenience accessor for the boolean held by <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>.
/// </summary>
public bool CheckExpressionsWithIntervalArithmetic {
  get {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    return flag.Value;
  }
  set {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    flag.Value = value;
  }
}
/// <summary>
/// Convenience accessor for the integer held by <see cref="EvaluatedSolutionsParameter"/>.
/// </summary>
public int EvaluatedSolutions {
  get {
    var counter = EvaluatedSolutionsParameter.Value;
    return counter.Value;
  }
  set {
    var counter = EvaluatedSolutionsParameter.Value;
    counter.Value = value;
  }
}
69    #endregion
70
/// <summary>
/// Constructor marked for the persistence framework.
/// </summary>
/// <param name="deserializing">Forwarded unchanged to the base class constructor.</param>
[StorableConstructor]
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
  : base(deserializing) {
  // The helper interpreter field carries no storable attribute and is otherwise only assigned in
  // the public and cloning constructors. Without this assignment an instance created through this
  // constructor would keep a null helper, and Evaluate would fail with a NullReferenceException
  // on the first TimeLag, Derivative or Integral instruction.
  interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}
75
/// <summary>
/// Cloning constructor used by <see cref="Clone(Cloner)"/>.
/// </summary>
/// <param name="original">The instance to copy.</param>
/// <param name="cloner">The cloner passed to the base class and used to clone the helper interpreter.</param>
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
  : base(original, cloner) {
  // The helper interpreter is cloned through the same cloner rather than shared with the original.
  var originalHelper = original.interpreter;
  this.interpreter = cloner.Clone(originalHelper);
}
80
/// <summary>
/// Creates a copy of this interpreter via the private cloning constructor.
/// </summary>
/// <param name="cloner">The cloner handed on to the cloning constructor.</param>
/// <returns>The newly created copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
  return copy;
}
84
/// <summary>
/// Creates an interpreter with the default name and description.
/// </summary>
/// <remarks>
/// Delegates to the (name, description) constructor, which registers the two parameters and
/// creates the helper interpreter.
/// </remarks>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
  : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
}
91
/// <summary>
/// Creates an interpreter with the given name and description.
/// </summary>
/// <param name="name">Forwarded to the base class constructor.</param>
/// <param name="description">Forwarded to the base class constructor.</param>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
  : base(name, description) {
  // Interval arithmetic checking is switched off by default.
  var intervalCheckParameter = new FixedValueParameter<BoolValue>(
    CheckExpressionsWithIntervalArithmeticParameterName,
    CheckExpressionsWithIntervalArithmeticParameterDescription,
    new BoolValue(false));
  Parameters.Add(intervalCheckParameter);

  // The evaluation counter starts at zero.
  var counterParameter = new FixedValueParameter<IntValue>(
    EvaluatedSolutionsParameterName,
    "A counter for the total number of solutions the interpreter has evaluated",
    new IntValue(0));
  Parameters.Add(counterParameter);

  interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}
98
/// <summary>
/// Persistence hook that re-registers both parameters as fixed value parameters.
/// </summary>
/// <remarks>
/// For each parameter the currently stored value (if the parameter exists) is kept, the existing
/// parameter is removed, and a new <c>FixedValueParameter</c> holding that value is added.
/// Missing parameters are created with their defaults (0 and false).
/// NOTE(review): presumably this migrates instances persisted with a different parameter type - confirm.
/// </remarks>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  // Evaluation counter: keep the stored value when the parameter is present.
  IntValue counter;
  if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
    counter = new IntValue(0);
  } else {
    var storedCounterParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    counter = storedCounterParameter.Value;
    Parameters.Remove(EvaluatedSolutionsParameterName);
  }
  var counterParameter = new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", counter);
  Parameters.Add(counterParameter);

  // Interval arithmetic switch: same procedure.
  BoolValue intervalCheck;
  if (!Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
    intervalCheck = new BoolValue(false);
  } else {
    var storedCheckParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
    intervalCheck = storedCheckParameter.Value;
  }
  var intervalCheckParameter = new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, intervalCheck);
  Parameters.Add(intervalCheckParameter);
}
116
117    #region IStatefulItem
/// <summary>
/// Resets the evaluated solutions counter to zero.
/// </summary>
public void InitializeState() {
  var counter = EvaluatedSolutionsParameter.Value;
  counter.Value = 0;
}
121
/// <summary>
/// Intentionally empty: this method performs no work.
/// </summary>
public void ClearState() {
}
123    #endregion
124
/// <summary>
/// Compiles the tree into linear instructions and returns one value per requested row.
/// </summary>
/// <remarks>
/// The interval arithmetic check, the counter increment, compilation and instruction preparation
/// all happen immediately when this method is called. The per-row evaluation is a deferred LINQ
/// projection and runs only when the returned sequence is enumerated.
/// </remarks>
/// <param name="tree">The expression tree to compile.</param>
/// <param name="dataset">The dataset the variable values are read from.</param>
/// <param name="rows">The row indices to evaluate.</param>
/// <returns>A lazily evaluated sequence with one value per element of <paramref name="rows"/>.</returns>
/// <exception cref="NotSupportedException">Thrown when interval arithmetic checking is switched on.</exception>
public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
  if (CheckExpressionsWithIntervalArithmetic) {
    throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
  }

  // Increment the evaluated solutions counter while holding a lock on the counter object itself.
  var counter = EvaluatedSolutionsParameter.Value;
  lock (counter) {
    counter.Value = counter.Value + 1;
  }

  var instructions = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
  PrepareInstructions(instructions, dataset);
  return from row in rows
         select Evaluate(dataset, row, instructions);
}
137
/// <summary>
/// Evaluates the prepared instruction sequence for a single dataset row.
/// </summary>
/// <remarks>
/// The instructions are processed from the last index down to index 0. Each instruction reads the
/// values of its arguments from <c>code[childIndex] .. code[childIndex + nArguments - 1]</c> and
/// stores its own result in its <c>value</c> field. Instructions flagged with <c>skip</c> (set by
/// <see cref="PrepareInstructions"/>) are not evaluated here.
/// NOTE(review): this relies on the compiler placing arguments at higher indices than their
/// parent instruction - that is not visible in this file, confirm against the compiler.
/// </remarks>
/// <param name="dataset">The dataset; used for the row count and by the helper interpreter.</param>
/// <param name="row">The row index to evaluate.</param>
/// <param name="code">The instruction sequence; the <c>value</c> fields are overwritten.</param>
/// <returns>The value of the instruction at index 0.</returns>
/// <exception cref="NotSupportedException">Thrown for an op code that has no branch below.</exception>
private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
  for (int i = code.Length - 1; i >= 0; --i) {
    if (code[i].skip) continue;
    #region opcode if
    var instr = code[i];
    if (instr.opCode == OpCodes.Variable) {
      // A row outside the dataset yields NaN. The else branch is required: without it the
      // column would still be indexed with the invalid row and the NaN would never be returned.
      if (row < 0 || row >= dataset.Rows) {
        instr.value = double.NaN;
      } else {
        var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
      }
    } else if (instr.opCode == OpCodes.LagVariable) {
      var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
      int actualRow = row + laggedVariableTreeNode.Lag;
      if (actualRow < 0 || actualRow >= dataset.Rows)
        instr.value = double.NaN;
      else
        instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
    } else if (instr.opCode == OpCodes.VariableCondition) {
      // Same guard as for Variable: an out-of-range row yields NaN instead of indexing the column.
      if (row < 0 || row >= dataset.Rows) {
        instr.value = double.NaN;
      } else {
        var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
        double variableValue = ((IList<double>)instr.data)[row];
        double x = variableValue - variableConditionTreeNode.Threshold;
        // Logistic weight in [0, 1] used to blend the two branches.
        double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

        double trueBranch = code[instr.childIndex].value;
        double falseBranch = code[instr.childIndex + 1].value;

        instr.value = trueBranch * p + falseBranch * (1 - p);
      }
    } else if (instr.opCode == OpCodes.Add) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s;
    } else if (instr.opCode == OpCodes.Sub) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s -= code[instr.childIndex + j].value;
      }
      // A subtraction with a single argument is a negation.
      if (instr.nArguments == 1) s = -s;
      instr.value = s;
    } else if (instr.opCode == OpCodes.Mul) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p *= code[instr.childIndex + j].value;
      }
      instr.value = p;
    } else if (instr.opCode == OpCodes.Div) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p /= code[instr.childIndex + j].value;
      }
      // A division with a single argument is the reciprocal.
      if (instr.nArguments == 1) p = 1.0 / p;
      instr.value = p;
    } else if (instr.opCode == OpCodes.Average) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s / instr.nArguments;
    } else if (instr.opCode == OpCodes.Cos) {
      instr.value = Math.Cos(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Sin) {
      instr.value = Math.Sin(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Tan) {
      instr.value = Math.Tan(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Square) {
      instr.value = Math.Pow(code[instr.childIndex].value, 2);
    } else if (instr.opCode == OpCodes.Power) {
      double x = code[instr.childIndex].value;
      // The exponent is rounded to the nearest integral value before exponentiation.
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, y);
    } else if (instr.opCode == OpCodes.SquareRoot) {
      instr.value = Math.Sqrt(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Root) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = Math.Pow(x, 1 / y);
    } else if (instr.opCode == OpCodes.Exp) {
      instr.value = Math.Exp(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Log) {
      instr.value = Math.Log(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Gamma) {
      // The alglib-based branches below return NaN for a NaN argument without calling alglib.
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
    } else if (instr.opCode == OpCodes.Psi) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      // Non-positive integers (within tolerance) also yield NaN instead of calling alglib.psi.
      else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
      else instr.value = alglib.psi(x);
    } else if (instr.opCode == OpCodes.Dawson) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
    } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
    } else if (instr.opCode == OpCodes.SineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = si;
      }
    } else if (instr.opCode == OpCodes.CosineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = ci;
      }
    } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = shi;
      }
    } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = chi;
      }
    } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = c;
      }
    } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = s;
      }
    } else if (instr.opCode == OpCodes.AiryA) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = ai;
      }
    } else if (instr.opCode == OpCodes.AiryB) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = bi;
      }
    } else if (instr.opCode == OpCodes.Norm) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.normaldistribution(x);
    } else if (instr.opCode == OpCodes.Erf) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.errorfunction(x);
    } else if (instr.opCode == OpCodes.Bessel) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.besseli0(x);
    } else if (instr.opCode == OpCodes.IfThenElse) {
      // Boolean convention used by the logical op codes: a value > 0 counts as true,
      // and results are encoded as 1.0 (true) / -1.0 (false).
      double condition = code[instr.childIndex].value;
      double result;
      if (condition > 0.0) {
        result = code[instr.childIndex + 1].value;
      } else {
        result = code[instr.childIndex + 2].value;
      }
      instr.value = result;
    } else if (instr.opCode == OpCodes.AND) {
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result > 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.OR) {
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result <= 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.NOT) {
      instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
    } else if (instr.opCode == OpCodes.XOR) {
      // True when an odd number of arguments is positive.
      int positiveSignals = 0;
      for (int j = 0; j < instr.nArguments; j++) {
        if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
      }
      instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.GT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x > y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.LT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x < y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
      // These op codes are delegated to the helper interpreter, which evaluates the
      // interpreter state that PrepareInstructions stored in instr.data.
      var state = (InterpreterState)instr.data;
      state.Reset();
      instr.value = interpreter.Evaluate(dataset, ref row, state);
    } else {
      var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
      throw new NotSupportedException(errorText);
    }
    #endregion
  }
  return code[0].value;
}
360
/// <summary>
/// Collects the instruction at <paramref name="startIndex"/> and all instructions reachable
/// through its arguments, in prefix order (an instruction first, then its arguments from first to last).
/// </summary>
/// <param name="code">The instruction sequence to read from; it is not modified.</param>
/// <param name="startIndex">Index of the instruction the traversal starts at.</param>
/// <returns>A new array referencing the collected instructions.</returns>
private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
  var prefixOrder = new List<LinearInstruction>();
  var pending = new Stack<int>();
  pending.Push(startIndex);
  while (pending.Count > 0) {
    var current = code[pending.Pop()];
    prefixOrder.Add(current);
    // Push the arguments last-to-first so that the first argument is popped (and visited) next.
    for (int remaining = current.nArguments; remaining > 0; remaining--) {
      pending.Push(current.childIndex + remaining - 1);
    }
  }
  return prefixOrder.ToArray();
}
374
/// <summary>
/// Attaches the per-instruction data that evaluation reads for each instruction.
/// </summary>
/// <remarks>
/// Constants get their value assigned once and are flagged to be skipped during evaluation.
/// Variable, lagged variable and variable condition instructions get the dataset column of their
/// variable. TimeLag, Integral and Derivative instructions get an <c>InterpreterState</c> built
/// from their prefix sequence, and every other instruction of that sequence is flagged to be skipped.
/// Instructions with any other op code are left untouched.
/// </remarks>
/// <param name="code">The instruction sequence to prepare; its instructions are modified in place.</param>
/// <param name="dataset">The dataset the variable columns are taken from.</param>
public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
  for (int index = 0; index < code.Length; index++) {
    var current = code[index];
    var opCode = current.opCode;
    #region opcode switch
    if (opCode == OpCodes.Constant) {
      current.value = ((ConstantTreeNode)current.dynamicNode).Value;
      // The value is already set, so this instruction is skipped in the evaluation phase.
      current.skip = true;
    } else if (opCode == OpCodes.Variable) {
      var variableNode = (VariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(variableNode.VariableName);
    } else if (opCode == OpCodes.LagVariable) {
      var laggedNode = (LaggedVariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(laggedNode.VariableName);
    } else if (opCode == OpCodes.VariableCondition) {
      var conditionNode = (VariableConditionTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(conditionNode.VariableName);
    } else if (opCode == OpCodes.TimeLag || opCode == OpCodes.Integral || opCode == OpCodes.Derivative) {
      var prefixSequence = GetPrefixSequence(code, index);
      current.data = new InterpreterState(prefixSequence, 0);
      // Element 0 is the instruction itself; everything after it is flagged to be skipped.
      for (int k = 1; k < prefixSequence.Length; k++) {
        prefixSequence[k].skip = true;
      }
    }
    #endregion
  }
}
415  }
416}
Note: See TracBrowser for help on using the repository browser.