Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.LinqExpressionTreeInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 13141

Last change on this file since 13141 was 13141, checked in by bburlacu, 7 years ago

#2442: Merged files from trunk and updated project file. Implemented missing operations in the CompiledTreeInterpreter: Integral, Derivative, Lag, TimeLag. Adapted lambda signature to accept an array of List<double> in order to make it easier to work with compiled trees. Changed value parameters to fixed value parameters and adjusted interpreter constructors and after serialization hooks. Removed function symbol.

From a performance point of view, compiling the tree into a lambda that accepts a double[][] parameter (an array of arrays holding the values of each double variable), accessed with Expression.ArrayIndex, is the fastest option, but it can be cumbersome to provide the data as a double[][]. Therefore the variant with List<double>[] was chosen. Internally, for each variable node the List's underlying double array is used, resulting in a decent overall speed compromise.

File size: 19.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that evaluates a symbolic expression tree without recursion: the tree is compiled
  /// into a flat array of <see cref="LinearInstruction"/>s which is then processed from the last
  /// instruction to the first for every requested row.
  /// </summary>
  /// <remarks>
  /// TimeLag, Derivative and Integral sub-expressions are not evaluated by the linear loop itself;
  /// they are delegated to an internal <see cref="SymbolicDataAnalysisExpressionTreeInterpreter"/>.
  /// Symbols that are handled by neither path cause a <see cref="NotSupportedException"/>.
  /// </remarks>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Parameter descriptions are shared by the constructor and the deserialization hook.
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription =
      "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterDescription =
      "A counter for the total number of solutions the interpreter has evaluated";

    // Recursive interpreter used for the TimeLag / Derivative / Integral sub-expressions.
    // The field is not marked as storable, so it is re-created in AfterDeserialization.
    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

    /// <summary>Always <c>false</c>: the name of this item cannot be changed.</summary>
    public override bool CanChangeName {
      get { return false; }
    }

    /// <summary>Always <c>false</c>: the description of this item cannot be changed.</summary>
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Parameter holding the "check expressions with interval arithmetic" switch.</summary>
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Value of <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>. When set to <c>true</c>,
    /// <see cref="GetSymbolicExpressionTreeValues"/> throws because this interpreter does not
    /// implement interval arithmetic.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>Value of <see cref="EvaluatedSolutionsParameter"/>.</summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
      : base(deserializing) {
    }

    /// <summary>Cloning constructor; the internal recursive interpreter is cloned as well.</summary>
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
      : base(original, cloner) {
      interpreter = cloner.Clone(original.interpreter);
    }

    /// <summary>Creates a deep clone of this interpreter using the given <paramref name="cloner"/>.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
    }

    /// <summary>Creates an interpreter with the default name and description.</summary>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
    }

    /// <summary>
    /// Creates an interpreter with the given name and description. Both parameters are added with
    /// their defaults (interval arithmetic check off, evaluated solutions 0).
    /// </summary>
    /// <param name="name">Name of the item.</param>
    /// <param name="description">Description of the item.</param>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName,
        CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    /// <summary>
    /// Restores the non-persisted state after deserialization and replaces both parameters by
    /// fixed value parameters. Values found in the deserialized parameters are carried over;
    /// the defaults are only used when a parameter is missing or holds no value.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();

      var evaluatedSolutions = new IntValue(0);
      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
        var storedParameter = Parameters[EvaluatedSolutionsParameterName] as IValueParameter<IntValue>;
        if (storedParameter != null && storedParameter.Value != null) evaluatedSolutions = storedParameter.Value;
        Parameters.Remove(EvaluatedSolutionsParameterName);
      }
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, evaluatedSolutions));

      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
        var storedParameter = Parameters[CheckExpressionsWithIntervalArithmeticParameterName] as IValueParameter<BoolValue>;
        if (storedParameter != null && storedParameter.Value != null) checkExpressionsWithIntervalArithmetic = storedParameter.Value;
        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
      }
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName,
        CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated solutions counter to 0.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>Nothing to clear; the interpreter keeps no further state.</summary>
    public void ClearState() { }
    #endregion

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// The evaluated solutions counter is incremented and the tree is compiled immediately; the
    /// per-row evaluation happens lazily while the returned sequence is enumerated.
    /// </remarks>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    /// <param name="rows">The row indices to evaluate.</param>
    /// <returns>One value per element of <paramref name="rows"/>, in the same order.</returns>
    /// <exception cref="NotSupportedException">
    /// Thrown when <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled, or (during
    /// enumeration) when the tree contains a symbol this interpreter does not handle.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (EvaluatedSolutionsParameter.Value) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }

      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates the prepared instruction array for a single row and returns the value of the
    /// first instruction (<c>code[0]</c>).
    /// </summary>
    /// <remarks>
    /// Instructions are processed from the end of the array towards the start, and every
    /// instruction reads the already computed values of its arguments at
    /// <c>childIndex .. childIndex + nArguments - 1</c>. Instructions flagged with <c>skip</c>
    /// are left untouched. Variable and VariableCondition instructions yield NaN for rows outside
    /// <c>[0, dataset.Rows)</c>; LagVariable instructions do the same for the lagged row.
    /// </remarks>
    /// <param name="dataset">The dataset the instructions were prepared for.</param>
    /// <param name="row">The row index to evaluate.</param>
    /// <param name="code">Instructions that have been processed by <see cref="PrepareInstructions"/>.</param>
    /// <returns>The value of the expression for <paramref name="row"/>.</returns>
    /// <exception cref="NotSupportedException">Thrown for an instruction with an unhandled opcode.</exception>
    private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
      for (int i = code.Length - 1; i >= 0; --i) {
        if (code[i].skip) continue;
        #region opcode if
        var instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
          }
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          int actualRow = row + laggedVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows)
            instr.value = double.NaN;
          else
            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
            double variableValue = ((IList<double>)instr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            // logistic weight that blends the two branches
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = code[instr.childIndex].value;
            double falseBranch = code[instr.childIndex + 1].value;

            instr.value = trueBranch * p + falseBranch * (1 - p);
          }
        } else if (instr.opCode == OpCodes.Add) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s;
        } else if (instr.opCode == OpCodes.Sub) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s -= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) s = -s; // unary minus
          instr.value = s;
        } else if (instr.opCode == OpCodes.Mul) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p *= code[instr.childIndex + j].value;
          }
          instr.value = p;
        } else if (instr.opCode == OpCodes.Div) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p /= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) p = 1.0 / p; // unary division is the reciprocal
          instr.value = p;
        } else if (instr.opCode == OpCodes.Average) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s / instr.nArguments;
        } else if (instr.opCode == OpCodes.Cos) {
          instr.value = Math.Cos(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Sin) {
          instr.value = Math.Sin(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tan) {
          instr.value = Math.Tan(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Square) {
          instr.value = Math.Pow(code[instr.childIndex].value, 2);
        } else if (instr.opCode == OpCodes.Power) {
          double x = code[instr.childIndex].value;
          double y = Math.Round(code[instr.childIndex + 1].value); // the exponent is rounded to an integer
          instr.value = Math.Pow(x, y);
        } else if (instr.opCode == OpCodes.SquareRoot) {
          instr.value = Math.Sqrt(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Root) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = Math.Pow(x, 1 / y);
        } else if (instr.opCode == OpCodes.Exp) {
          instr.value = Math.Exp(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Log) {
          instr.value = Math.Log(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Gamma) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
        } else if (instr.opCode == OpCodes.Psi) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          // non-positive integers are mapped to NaN instead of being passed to alglib.psi
          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
          else instr.value = alglib.psi(x);
        } else if (instr.opCode == OpCodes.Dawson) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
        } else if (instr.opCode == OpCodes.SineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = si;
          }
        } else if (instr.opCode == OpCodes.CosineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = ci;
          }
        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = shi;
          }
        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = chi;
          }
        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = c;
          }
        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = s;
          }
        } else if (instr.opCode == OpCodes.AiryA) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = ai;
          }
        } else if (instr.opCode == OpCodes.AiryB) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = bi;
          }
        } else if (instr.opCode == OpCodes.Norm) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.normaldistribution(x);
        } else if (instr.opCode == OpCodes.Erf) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.errorfunction(x);
        } else if (instr.opCode == OpCodes.Bessel) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.besseli0(x);
        } else if (instr.opCode == OpCodes.IfThenElse) {
          double condition = code[instr.childIndex].value;
          double result;
          if (condition > 0.0) {
            result = code[instr.childIndex + 1].value;
          } else {
            result = code[instr.childIndex + 2].value;
          }
          instr.value = result;
        } else if (instr.opCode == OpCodes.AND) {
          // values greater than 0 count as true; the result is encoded as 1.0 / -1.0
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result > 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.OR) {
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result <= 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.NOT) {
          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
        } else if (instr.opCode == OpCodes.XOR) {
          // true when an odd number of arguments is greater than 0
          int positiveSignals = 0;
          for (int j = 0; j < instr.nArguments; j++) {
            if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
          }
          instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.GT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x > y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.LT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x < y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
          // the whole sub-expression is evaluated by the recursive interpreter, using the
          // interpreter state created for this instruction in PrepareInstructions
          var state = (InterpreterState)instr.data;
          state.Reset();
          instr.value = interpreter.Evaluate(dataset, ref row, state);
        } else {
          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
          throw new NotSupportedException(errorText);
        }
        #endregion
      }
      return code[0].value;
    }

    /// <summary>
    /// Collects the instruction at <paramref name="startIndex"/> and all instructions reachable
    /// from it through <c>childIndex</c>/<c>nArguments</c>, in prefix order (an instruction
    /// first, then its arguments from first to last).
    /// </summary>
    /// <param name="code">The complete instruction array.</param>
    /// <param name="startIndex">Index of the instruction the sequence starts with.</param>
    /// <returns>A new array referencing the collected instructions.</returns>
    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
      var pending = new Stack<int>();
      var sequence = new List<LinearInstruction>();
      pending.Push(startIndex);
      while (pending.Count > 0) {
        var instr = code[pending.Pop()];
        // push the arguments last-to-first so that the first argument is popped next
        for (int j = instr.nArguments - 1; j >= 0; j--) pending.Push(instr.childIndex + j);
        sequence.Add(instr);
      }
      return sequence.ToArray();
    }

    /// <summary>
    /// Prepares compiled instructions for evaluation against <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item><description>Constant: the value is stored once and the instruction is flagged with <c>skip</c>.</description></item>
    /// <item><description>Variable, LagVariable, VariableCondition: <c>data</c> is set to the read-only double values of the referenced variable.</description></item>
    /// <item><description>TimeLag, Integral, Derivative: <c>data</c> is set to an <see cref="InterpreterState"/> over the prefix sequence of the sub-expression, and every instruction of that sequence except the first is flagged with <c>skip</c>.</description></item>
    /// </list>
    /// Instructions with other opcodes are left unchanged.
    /// </remarks>
    /// <param name="code">The instructions to prepare; they are modified in place.</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
      for (int i = 0; i != code.Length; ++i) {
        var instr = code[i];
        #region opcode switch
        switch (instr.opCode) {
          case OpCodes.Constant:
            {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
            }
            break;
          case OpCodes.Variable:
            {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable:
            {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition:
            {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
          case OpCodes.TimeLag:
          case OpCodes.Integral:
          case OpCodes.Derivative:
            {
              var seq = GetPrefixSequence(code, i);
              var interpreterState = new InterpreterState(seq, 0);
              instr.data = interpreterState;
              // seq[0] is this instruction; the rest of the sub-expression is evaluated by the
              // recursive interpreter and must not be evaluated by the linear loop
              for (int j = 1; j != seq.Length; ++j)
                seq[j].skip = true;
            }
            break;
        }
        #endregion
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.