Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9739

Last change on this file since 9739 was 9739, checked in by bburlacu, 11 years ago

#2021: Added separate SymbolicExpressionTreeLinearCompiler. Updated the SymbolicDataAnalysisExpressionTreeLinearInterpreter:

  • moved constants evaluation outside of loop
  • simplified code
  • renamed EvaluateFast method to Evaluate and made it static.
File size: 17.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).")]
35  public class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    public override bool CanChangeName {
40      get { return false; }
41    }
42
43    public override bool CanChangeDescription {
44      get { return false; }
45    }
46
47    #region parameter properties
48
49    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
50      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
51    }
52
53    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
54      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
55    }
56
57    #endregion
58
59    #region properties
60
61    public BoolValue CheckExpressionsWithIntervalArithmetic {
62      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
63      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
64    }
65
66    public IntValue EvaluatedSolutions {
67      get { return EvaluatedSolutionsParameter.Value; }
68      set { EvaluatedSolutionsParameter.Value = value; }
69    }
70
71    #endregion
72
73    [StorableConstructor]
74    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
75      : base(deserializing) {
76    }
77
78    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(
79      SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
80      : base(original, cloner) {
81    }
82
83    public override IDeepCloneable Clone(Cloner cloner) {
84      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
85    }
86
87    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
88      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
89      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
90      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
91    }
92
93    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
94      : base(name, description) {
95      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
96      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
97    }
98
99    [StorableHook(HookType.AfterDeserialization)]
100    private void AfterDeserialization() {
101      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
102        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
103    }
104
105    #region IStatefulItem
106
107    public void InitializeState() {
108      EvaluatedSolutions.Value = 0;
109    }
110
111    public void ClearState() {
112    }
113
114    #endregion
115
116    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
117      if (CheckExpressionsWithIntervalArithmetic.Value)
118        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
119
120      lock (EvaluatedSolutions) {
121        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
122      }
123
124      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
125      PrepareInterpreterState(code, dataset);
126      return rows.Select(row => Evaluate(dataset, ref row, code));
127    }
128
129    private static void PrepareInterpreterState(LinearInstruction[] code, Dataset dataset) {
130      for (int i = code.Length - 1; i >= 0; --i) {
131        var instr = code[i];
132        #region opcode switch
133        switch (instr.opCode) {
134          case OpCodes.Constant: {
135              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
136              instr.value = constTreeNode.Value;
137            }
138            break;
139          case OpCodes.Variable: {
140              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
141              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
142            }
143            break;
144          case OpCodes.LagVariable: {
145              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
146              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
147            }
148            break;
149          case OpCodes.VariableCondition: {
150              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
151              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
152            }
153            break;
154        }
155        #endregion
156      }
157    }
158
159    private static double Evaluate(Dataset dataset, ref int row, LinearInstruction[] code) {
160      for (int i = code.Length - 1; i >= 0; --i) {
161        var instr = code[i];
162        if (instr.opCode == OpCodes.Constant) continue;
163        #region opcode switch
164        switch (instr.opCode) {
165          case OpCodes.Variable: {
166              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
167              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
168              instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
169            }
170            break;
171          case OpCodes.LagVariable: {
172              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
173              int actualRow = row + laggedVariableTreeNode.Lag;
174              if (actualRow < 0 || actualRow >= dataset.Rows) instr.value = double.NaN;
175              instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
176            }
177            break;
178          case OpCodes.VariableCondition: {
179              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
180              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
181              double variableValue = ((IList<double>)instr.iArg0)[row];
182              double x = variableValue - variableConditionTreeNode.Threshold;
183              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
184
185              double trueBranch = code[instr.childIndex].value;
186              double falseBranch = code[instr.childIndex + 1].value;
187
188              instr.value = trueBranch * p + falseBranch * (1 - p);
189            }
190            break;
191          case OpCodes.Add: {
192              double s = code[instr.childIndex].value;
193              for (int j = 1; j != instr.nArguments; ++j) {
194                s += code[instr.childIndex + j].value;
195              }
196              instr.value = s;
197            }
198            break;
199          case OpCodes.Sub: {
200              double s = code[instr.childIndex].value;
201              for (int j = 1; j != instr.nArguments; ++j) {
202                s -= code[instr.childIndex + j].value;
203              }
204              if (instr.nArguments == 1) s = -s;
205              instr.value = s;
206            }
207            break;
208          case OpCodes.Mul: {
209              double p = code[instr.childIndex].value;
210              for (int j = 1; j != instr.nArguments; ++j) {
211                p *= code[instr.childIndex + j].value;
212              }
213              instr.value = p;
214            }
215            break;
216          case OpCodes.Div: {
217              double p = code[instr.childIndex].value;
218              for (int j = 1; j != instr.nArguments; ++j) {
219                p /= code[instr.childIndex + j].value;
220              }
221              if (instr.nArguments == 1) p = 1.0 / p;
222              instr.value = p;
223            }
224            break;
225          case OpCodes.Average: {
226              double s = code[instr.childIndex].value;
227              for (int j = 1; j != instr.nArguments; ++j) {
228                s += code[instr.childIndex + j].value;
229              }
230              instr.value = s / instr.nArguments;
231            }
232            break;
233          case OpCodes.Cos: {
234              instr.value = Math.Cos(code[instr.childIndex].value);
235            }
236            break;
237          case OpCodes.Sin: {
238              instr.value = Math.Sin(code[instr.childIndex].value);
239            }
240            break;
241          case OpCodes.Tan: {
242              instr.value = Math.Tan(code[instr.childIndex].value);
243            }
244            break;
245          case OpCodes.Root: {
246              double x = code[instr.childIndex].value;
247              double y = code[instr.childIndex + 1].value;
248              instr.value = Math.Pow(x, 1 / y);
249            }
250            break;
251          case OpCodes.Exp: {
252              instr.value = Math.Exp(code[instr.childIndex].value);
253            }
254            break;
255          case OpCodes.Log: {
256              instr.value = Math.Log(code[instr.childIndex].value);
257            }
258            break;
259          case OpCodes.Gamma: {
260              var x = code[instr.childIndex].value;
261              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
262            }
263            break;
264          case OpCodes.Psi: {
265              var x = code[instr.childIndex].value;
266              if (double.IsNaN(x)) instr.value = double.NaN;
267              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
268              else instr.value = alglib.psi(x);
269            }
270            break;
271          case OpCodes.Dawson: {
272              var x = code[instr.childIndex].value;
273              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
274            }
275            break;
276          case OpCodes.ExponentialIntegralEi: {
277              var x = code[instr.childIndex].value;
278              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
279            }
280            break;
281          case OpCodes.SineIntegral: {
282              double si, ci;
283              var x = code[instr.childIndex].value;
284              if (double.IsNaN(x)) instr.value = double.NaN;
285              else {
286                alglib.sinecosineintegrals(x, out si, out ci);
287                instr.value = si;
288              }
289            }
290            break;
291          case OpCodes.CosineIntegral: {
292              double si, ci;
293              var x = code[instr.childIndex].value;
294              if (double.IsNaN(x)) instr.value = double.NaN;
295              else {
296                alglib.sinecosineintegrals(x, out si, out ci);
297                instr.value = si;
298              }
299            }
300            break;
301          case OpCodes.HyperbolicSineIntegral: {
302              double shi, chi;
303              var x = code[instr.childIndex].value;
304              if (double.IsNaN(x)) instr.value = double.NaN;
305              else {
306                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
307                instr.value = shi;
308              }
309            }
310            break;
311          case OpCodes.HyperbolicCosineIntegral: {
312              double shi, chi;
313              var x = code[instr.childIndex].value;
314              if (double.IsNaN(x)) instr.value = double.NaN;
315              else {
316                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
317                instr.value = chi;
318              }
319            }
320            break;
321          case OpCodes.FresnelCosineIntegral: {
322              double c = 0, s = 0;
323              var x = code[instr.childIndex].value;
324              if (double.IsNaN(x)) instr.value = double.NaN;
325              else {
326                alglib.fresnelintegral(x, ref c, ref s);
327                instr.value = c;
328              }
329            }
330            break;
331          case OpCodes.FresnelSineIntegral: {
332              double c = 0, s = 0;
333              var x = code[instr.childIndex].value;
334              if (double.IsNaN(x)) instr.value = double.NaN;
335              else {
336                alglib.fresnelintegral(x, ref c, ref s);
337                instr.value = s;
338              }
339            }
340            break;
341          case OpCodes.AiryA: {
342              double ai, aip, bi, bip;
343              var x = code[instr.childIndex].value;
344              if (double.IsNaN(x)) instr.value = double.NaN;
345              else {
346                alglib.airy(x, out ai, out aip, out bi, out bip);
347                instr.value = ai;
348              }
349            }
350            break;
351          case OpCodes.AiryB: {
352              double ai, aip, bi, bip;
353              var x = code[instr.childIndex].value;
354              if (double.IsNaN(x)) instr.value = double.NaN;
355              else {
356                alglib.airy(x, out ai, out aip, out bi, out bip);
357                instr.value = bi;
358              }
359            }
360            break;
361          case OpCodes.Norm: {
362              var x = code[instr.childIndex].value;
363              if (double.IsNaN(x)) instr.value = double.NaN;
364              else instr.value = alglib.normaldistribution(x);
365            }
366            break;
367          case OpCodes.Erf: {
368              var x = code[instr.childIndex].value;
369              if (double.IsNaN(x)) instr.value = double.NaN;
370              else instr.value = alglib.errorfunction(x);
371            }
372            break;
373          case OpCodes.Bessel: {
374              var x = code[instr.childIndex].value;
375              if (double.IsNaN(x)) instr.value = double.NaN;
376              else instr.value = alglib.besseli0(x);
377            }
378            break;
379          case OpCodes.IfThenElse: {
380              double condition = code[instr.childIndex].value;
381              double result;
382              if (condition > 0.0) {
383                result = code[instr.childIndex + 1].value;
384              } else {
385                result = code[instr.childIndex + 2].value;
386              }
387              instr.value = result;
388            }
389            break;
390          case OpCodes.AND: {
391              double result = code[instr.childIndex].value;
392              for (int j = 1; j < instr.nArguments; j++) {
393                if (result > 0.0) result = code[instr.childIndex + j].value;
394                else break;
395              }
396              instr.value = result > 0.0 ? 1.0 : -1.0;
397            }
398            break;
399          case OpCodes.OR: {
400              double result = code[instr.childIndex].value;
401              for (int j = 1; j < instr.nArguments; j++) {
402                if (result <= 0.0) result = code[instr.childIndex + j].value;
403                else break;
404              }
405              instr.value = result > 0.0 ? 1.0 : -1.0;
406            }
407            break;
408          case OpCodes.NOT: {
409              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
410            }
411            break;
412          case OpCodes.GT: {
413              double x = code[instr.childIndex].value;
414              double y = code[instr.childIndex + 1].value;
415              instr.value = x > y ? 1.0 : -1.0;
416            }
417            break;
418          case OpCodes.LT: {
419              double x = code[instr.childIndex].value;
420              double y = code[instr.childIndex + 1].value;
421              instr.value = x < y ? 1.0 : -1.0;
422            }
423            break;
424          case OpCodes.TimeLag: {
425              throw new NotSupportedException();
426            }
427          case OpCodes.Integral: {
428              throw new NotSupportedException();
429            }
430          case OpCodes.Derivative: {
431              throw new NotSupportedException();
432            }
433          case OpCodes.Arg: {
434              throw new NotSupportedException();
435            }
436          default:
437            throw new NotSupportedException();
438        }
439        #endregion
440      }
441      return code[0].value;
442    }
443  }
444}
Note: See TracBrowser for help on using the repository browser.