Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9871

Last change on this file since 9871 was 9871, checked in by bburlacu, 11 years ago

#2021: Replaced the opcode switch block in the Evaluate() method with an if-else block, as it apparently improves performance.

File size: 17.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that evaluates a symbolic expression tree on dataset rows by first compiling the tree
  /// into a flat array of linear instructions and then processing that array in a loop instead of
  /// recursing over the tree. TimeLag, Derivative and Integral sub-expressions are handed over to a
  /// wrapped <see cref="SymbolicDataAnalysisExpressionTreeInterpreter"/>.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Interpreter used for the TimeLag / Derivative / Integral opcodes.
    // The field is not marked as storable; it is re-created after deserialization if it is missing.
    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

    /// <summary>The name of this item is fixed.</summary>
    public override bool CanChangeName {
      get { return false; }
    }

    /// <summary>The description of this item is fixed.</summary>
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Parameter holding the interval-arithmetic switch.</summary>
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Value of the interval-arithmetic switch. This interpreter throws a
    /// <see cref="NotSupportedException"/> on evaluation when the switch is set.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }

    /// <summary>Number of trees passed to <see cref="GetSymbolicExpressionTreeValues"/> since the last reset.</summary>
    public IntValue EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value; }
      set { EvaluatedSolutionsParameter.Value = value; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>Cloning constructor; the wrapped interpreter is cloned through <paramref name="cloner"/>.</summary>
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
      : base(original, cloner) {
      interpreter = cloner.Clone(original.interpreter);
    }

    /// <summary>Creates a deep clone of this interpreter.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
    }

    /// <summary>
    /// Creates an interpreter with interval-arithmetic checking switched off and the
    /// evaluated-solutions counter at zero.
    /// </summary>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    /// <summary>Re-creates the wrapped interpreter if it is missing after deserialization.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions.Value = 0;
    }

    /// <summary>Nothing to clear.</summary>
    public void ClearState() { }
    #endregion

    /// <summary>
    /// Compiles <paramref name="tree"/> into linear code and returns the value of the expression for
    /// each of the given rows.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">The dataset that supplies the variable values.</param>
    /// <param name="rows">The row indices to evaluate.</param>
    /// <returns>
    /// A lazily evaluated sequence with one value per row. Rows are evaluated only while the sequence is
    /// enumerated, and every enumeration works on the same compiled instruction array.
    /// </returns>
    /// <exception cref="NotSupportedException">Thrown when interval-arithmetic checking is switched on.</exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (EvaluatedSolutions) {
        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
      }

      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates the compiled expression for a single row.
    /// </summary>
    /// <remarks>
    /// The instruction array is processed from its last element to its first. Every instruction reads
    /// the values already stored in its child instructions, which are located at
    /// <c>childIndex .. childIndex + nArguments - 1</c>, and stores its own result in <c>value</c>.
    /// Instructions flagged with <c>skip</c> (see <see cref="PrepareInstructions"/>) are left untouched.
    /// The opcodes are dispatched with an if-else chain on purpose; the changeset note for this
    /// revision reports it as faster than a switch block.
    /// </remarks>
    /// <param name="dataset">The dataset; used for the row-range check and by the wrapped interpreter.</param>
    /// <param name="row">The row index to evaluate.</param>
    /// <param name="code">The prepared instruction array; the <c>value</c> fields are overwritten.</param>
    /// <returns>The value of the first instruction after evaluation, i.e. the result of the expression.</returns>
    /// <exception cref="NotSupportedException">Thrown for an opcode that is not handled here.</exception>
    private double Evaluate(Dataset dataset, int row, LinearInstruction[] code) {
      for (int i = code.Length - 1; i >= 0; --i) {
        var instr = code[i];
        if (instr.skip) continue;
        #region opcode if
        if (instr.opCode == OpCodes.Variable) {
          // Out-of-range rows yield NaN. The else is required; without it the column would be
          // indexed with the invalid row right after the NaN assignment.
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
          }
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          int actualRow = row + laggedVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows)
            instr.value = double.NaN;
          else
            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          // Same guard as for Variable: an out-of-range row yields NaN instead of indexing the column.
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
            double variableValue = ((IList<double>)instr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            // logistic weight in (0, 1) that blends the two branches
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = code[instr.childIndex].value;
            double falseBranch = code[instr.childIndex + 1].value;

            instr.value = trueBranch * p + falseBranch * (1 - p);
          }
        } else if (instr.opCode == OpCodes.Add) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s;
        } else if (instr.opCode == OpCodes.Sub) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s -= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) s = -s; // unary minus
          instr.value = s;
        } else if (instr.opCode == OpCodes.Mul) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p *= code[instr.childIndex + j].value;
          }
          instr.value = p;
        } else if (instr.opCode == OpCodes.Div) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p /= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) p = 1.0 / p; // a single argument means reciprocal
          instr.value = p;
        } else if (instr.opCode == OpCodes.Average) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s / instr.nArguments;
        } else if (instr.opCode == OpCodes.Cos) {
          instr.value = Math.Cos(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Sin) {
          instr.value = Math.Sin(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tan) {
          instr.value = Math.Tan(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Square) {
          instr.value = Math.Pow(code[instr.childIndex].value, 2);
        } else if (instr.opCode == OpCodes.Power) {
          double x = code[instr.childIndex].value;
          double y = Math.Round(code[instr.childIndex + 1].value); // the exponent is rounded to an integer
          instr.value = Math.Pow(x, y);
        } else if (instr.opCode == OpCodes.SquareRoot) {
          instr.value = Math.Sqrt(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Root) {
          double x = code[instr.childIndex].value;
          // NOTE(review): unlike Power, the degree is not rounded here - confirm that this is intended.
          double y = code[instr.childIndex + 1].value;
          instr.value = Math.Pow(x, 1 / y);
        } else if (instr.opCode == OpCodes.Exp) {
          instr.value = Math.Exp(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Log) {
          instr.value = Math.Log(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Gamma) {
          // The alglib-based opcodes below pass NaN through without calling into alglib.
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
        } else if (instr.opCode == OpCodes.Psi) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          // non-positive integer arguments are mapped to NaN as well
          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
          else instr.value = alglib.psi(x);
        } else if (instr.opCode == OpCodes.Dawson) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
        } else if (instr.opCode == OpCodes.SineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = si;
          }
        } else if (instr.opCode == OpCodes.CosineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = ci;
          }
        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = shi;
          }
        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = chi;
          }
        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = c;
          }
        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = s;
          }
        } else if (instr.opCode == OpCodes.AiryA) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = ai;
          }
        } else if (instr.opCode == OpCodes.AiryB) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = bi;
          }
        } else if (instr.opCode == OpCodes.Norm) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.normaldistribution(x);
        } else if (instr.opCode == OpCodes.Erf) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.errorfunction(x);
        } else if (instr.opCode == OpCodes.Bessel) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.besseli0(x);
        } else if (instr.opCode == OpCodes.IfThenElse) {
          // Both branches have already been evaluated; the condition only selects the result.
          double condition = code[instr.childIndex].value;
          double result;
          if (condition > 0.0) {
            result = code[instr.childIndex + 1].value;
          } else {
            result = code[instr.childIndex + 2].value;
          }
          instr.value = result;
        } else if (instr.opCode == OpCodes.AND) {
          // Boolean results are encoded as 1.0 (true) and -1.0 (false); values > 0 count as true.
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result > 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.OR) {
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result <= 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.NOT) {
          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
        } else if (instr.opCode == OpCodes.GT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x > y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.LT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x < y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
          // The whole subtree is evaluated by the wrapped interpreter from the state
          // that PrepareInstructions attached to this instruction.
          var state = (InterpreterState)instr.data;
          state.Reset();
          instr.value = interpreter.Evaluate(dataset, ref row, state);
        } else {
          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
          throw new NotSupportedException(errorText);
        }
        #endregion
      }
      return code[0].value;
    }

    /// <summary>
    /// Collects the instructions of the subtree rooted at <paramref name="startIndex"/> in prefix
    /// (pre-order) sequence: a node first, then its children from left to right.
    /// </summary>
    /// <remarks>
    /// The children of an instruction occupy <c>childIndex .. childIndex + nArguments - 1</c> in
    /// <paramref name="code"/>, so the members of a subtree are generally not contiguous in the array.
    /// An explicit stack is used to visit exactly the nodes of the subtree. A plain forward walk would
    /// miss siblings and would pick up instructions that belong to other subtrees.
    /// </remarks>
    /// <param name="code">The compiled instruction array.</param>
    /// <param name="startIndex">Index of the subtree root within <paramref name="code"/>.</param>
    /// <returns>The instructions of the subtree in prefix sequence; the root is the first element.</returns>
    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
      var prefix = new List<LinearInstruction>();
      var pending = new Stack<int>();
      pending.Push(startIndex);
      while (pending.Count > 0) {
        var instr = code[pending.Pop()];
        prefix.Add(instr);
        // push the children right-to-left so that the leftmost child is popped (visited) first
        for (int j = instr.nArguments - 1; j >= 0; --j) {
          pending.Push(instr.childIndex + j);
        }
      }
      return prefix.ToArray();
    }

    /// <summary>
    /// Performs the row-independent setup of the compiled code before the rows are evaluated.
    /// </summary>
    /// <remarks>
    /// Constants get their value once and are flagged to be skipped. Variable-based instructions get
    /// the values of their dataset column attached. TimeLag, Integral and Derivative instructions get an
    /// <see cref="InterpreterState"/> over the prefix sequence of their subtree, and every other
    /// instruction of that sequence is flagged to be skipped. All remaining opcodes need no preparation.
    /// </remarks>
    /// <param name="code">The compiled instruction array; modified in place.</param>
    /// <param name="dataset">The dataset from which the variable columns are taken.</param>
    private static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
      for (int i = 0; i != code.Length; ++i) {
        var instr = code[i];
        #region opcode switch
        switch (instr.opCode) {
          case OpCodes.Constant: {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
            }
            break;
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
          case OpCodes.TimeLag:
          case OpCodes.Integral:
          case OpCodes.Derivative: {
              var seq = GetPrefixSequence(code, i);
              var interpreterState = new InterpreterState(seq, 0);
              instr.data = interpreterState;
              // seq[0] is this instruction; it stays active and evaluates the rest of its subtree
              for (int j = 1; j != seq.Length; ++j)
                seq[j].skip = true;
            }
            break;
        }
        #endregion
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.