Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9758

Last change on this file since 9758 was 9758, checked in by bburlacu, 11 years ago

#2021:

  • Derived the LinearInstruction class from Instruction.
  • Added missing symbols to the linear interpreter
  • Changed description for the linear interpreter
  • Added more helpful exception message when a symbol is not supported.
  • Added evaluation test for the linear interpreter
File size: 18.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees. This interpreter is faster but does not support Integral, Derivative, TimeLag or ADF function nodes.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
37    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
38
39    public override bool CanChangeName {
40      get { return false; }
41    }
42
43    public override bool CanChangeDescription {
44      get { return false; }
45    }
46
47    #region parameter properties
48    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
49      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
50    }
51
52    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
53      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
54    }
55    #endregion
56
57    #region properties
58    public BoolValue CheckExpressionsWithIntervalArithmetic {
59      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
60      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
61    }
62    public IntValue EvaluatedSolutions {
63      get { return EvaluatedSolutionsParameter.Value; }
64      set { EvaluatedSolutionsParameter.Value = value; }
65    }
66    #endregion
67
68    [StorableConstructor]
69    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
70      : base(deserializing) {
71    }
72
73    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(
74      SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
75      : base(original, cloner) {
76    }
77
78    public override IDeepCloneable Clone(Cloner cloner) {
79      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
80    }
81
82    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
83      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
84      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
85      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
86    }
87
88    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
89      : base(name, description) {
90      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
91      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
92    }
93
94    [StorableHook(HookType.AfterDeserialization)]
95    private void AfterDeserialization() {
96      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
97        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
98    }
99
100    #region IStatefulItem
101
102    public void InitializeState() {
103      EvaluatedSolutions.Value = 0;
104    }
105
106    public void ClearState() {
107    }
108
109    #endregion
110
111    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
112      if (CheckExpressionsWithIntervalArithmetic.Value)
113        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
114
115      lock (EvaluatedSolutions) {
116        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
117      }
118
119      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
120      PrepareInstructions(code, dataset);
121      return rows.Select(row => Evaluate(dataset, ref row, code));
122    }
123
124    private static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
125      for (int i = code.Length - 1; i >= 0; --i) {
126        var instr = code[i];
127        #region opcode switch
128        switch (instr.opCode) {
129          case OpCodes.Constant: {
130              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
131              instr.value = constTreeNode.Value;
132            }
133            break;
134          case OpCodes.Variable: {
135              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
136              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
137            }
138            break;
139          case OpCodes.LagVariable: {
140              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
141              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
142            }
143            break;
144          case OpCodes.VariableCondition: {
145              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
146              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
147            }
148            break;
149        }
150        #endregion
151      }
152    }
153
154    private static double Evaluate(Dataset dataset, ref int row, LinearInstruction[] code) {
155      for (int i = code.Length - 1; i >= 0; --i) {
156        if (code[i].opCode == OpCodes.Constant) continue;
157        #region opcode switch
158        var instr = code[i];
159        switch (instr.opCode) {
160          case OpCodes.Variable: {
161              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
162              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
163              instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
164            }
165            break;
166          case OpCodes.LagVariable: {
167              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
168              int actualRow = row + laggedVariableTreeNode.Lag;
169              if (actualRow < 0 || actualRow >= dataset.Rows) instr.value = double.NaN;
170              instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
171            }
172            break;
173          case OpCodes.VariableCondition: {
174              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
175              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
176              double variableValue = ((IList<double>)instr.iArg0)[row];
177              double x = variableValue - variableConditionTreeNode.Threshold;
178              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
179
180              double trueBranch = code[instr.childIndex].value;
181              double falseBranch = code[instr.childIndex + 1].value;
182
183              instr.value = trueBranch * p + falseBranch * (1 - p);
184            }
185            break;
186          case OpCodes.Add: {
187              double s = code[instr.childIndex].value;
188              for (int j = 1; j != instr.nArguments; ++j) {
189                s += code[instr.childIndex + j].value;
190              }
191              instr.value = s;
192            }
193            break;
194          case OpCodes.Sub: {
195              double s = code[instr.childIndex].value;
196              for (int j = 1; j != instr.nArguments; ++j) {
197                s -= code[instr.childIndex + j].value;
198              }
199              if (instr.nArguments == 1) s = -s;
200              instr.value = s;
201            }
202            break;
203          case OpCodes.Mul: {
204              double p = code[instr.childIndex].value;
205              for (int j = 1; j != instr.nArguments; ++j) {
206                p *= code[instr.childIndex + j].value;
207              }
208              instr.value = p;
209            }
210            break;
211          case OpCodes.Div: {
212              double p = code[instr.childIndex].value;
213              for (int j = 1; j != instr.nArguments; ++j) {
214                p /= code[instr.childIndex + j].value;
215              }
216              if (instr.nArguments == 1) p = 1.0 / p;
217              instr.value = p;
218            }
219            break;
220          case OpCodes.Average: {
221              double s = code[instr.childIndex].value;
222              for (int j = 1; j != instr.nArguments; ++j) {
223                s += code[instr.childIndex + j].value;
224              }
225              instr.value = s / instr.nArguments;
226            }
227            break;
228          case OpCodes.Cos: {
229              instr.value = Math.Cos(code[instr.childIndex].value);
230            }
231            break;
232          case OpCodes.Sin: {
233              instr.value = Math.Sin(code[instr.childIndex].value);
234            }
235            break;
236          case OpCodes.Tan: {
237              instr.value = Math.Tan(code[instr.childIndex].value);
238            }
239            break;
240          case OpCodes.Square: {
241              instr.value = Math.Pow(code[instr.childIndex].value, 2);
242            }
243            break;
244          case OpCodes.Power: {
245              double x = code[instr.childIndex].value;
246              double y = Math.Round(code[instr.childIndex + 1].value);
247              instr.value = Math.Pow(x, y);
248            }
249            break;
250          case OpCodes.SquareRoot: {
251              instr.value = Math.Sqrt(code[instr.childIndex].value);
252            }
253            break;
254          case OpCodes.Root: {
255              double x = code[instr.childIndex].value;
256              double y = code[instr.childIndex + 1].value;
257              instr.value = Math.Pow(x, 1 / y);
258            }
259            break;
260          case OpCodes.Exp: {
261              instr.value = Math.Exp(code[instr.childIndex].value);
262            }
263            break;
264          case OpCodes.Log: {
265              instr.value = Math.Log(code[instr.childIndex].value);
266            }
267            break;
268          case OpCodes.Gamma: {
269              var x = code[instr.childIndex].value;
270              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
271            }
272            break;
273          case OpCodes.Psi: {
274              var x = code[instr.childIndex].value;
275              if (double.IsNaN(x)) instr.value = double.NaN;
276              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
277              else instr.value = alglib.psi(x);
278            }
279            break;
280          case OpCodes.Dawson: {
281              var x = code[instr.childIndex].value;
282              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
283            }
284            break;
285          case OpCodes.ExponentialIntegralEi: {
286              var x = code[instr.childIndex].value;
287              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
288            }
289            break;
290          case OpCodes.SineIntegral: {
291              double si, ci;
292              var x = code[instr.childIndex].value;
293              if (double.IsNaN(x)) instr.value = double.NaN;
294              else {
295                alglib.sinecosineintegrals(x, out si, out ci);
296                instr.value = si;
297              }
298            }
299            break;
300          case OpCodes.CosineIntegral: {
301              double si, ci;
302              var x = code[instr.childIndex].value;
303              if (double.IsNaN(x)) instr.value = double.NaN;
304              else {
305                alglib.sinecosineintegrals(x, out si, out ci);
306                instr.value = si;
307              }
308            }
309            break;
310          case OpCodes.HyperbolicSineIntegral: {
311              double shi, chi;
312              var x = code[instr.childIndex].value;
313              if (double.IsNaN(x)) instr.value = double.NaN;
314              else {
315                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
316                instr.value = shi;
317              }
318            }
319            break;
320          case OpCodes.HyperbolicCosineIntegral: {
321              double shi, chi;
322              var x = code[instr.childIndex].value;
323              if (double.IsNaN(x)) instr.value = double.NaN;
324              else {
325                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
326                instr.value = chi;
327              }
328            }
329            break;
330          case OpCodes.FresnelCosineIntegral: {
331              double c = 0, s = 0;
332              var x = code[instr.childIndex].value;
333              if (double.IsNaN(x)) instr.value = double.NaN;
334              else {
335                alglib.fresnelintegral(x, ref c, ref s);
336                instr.value = c;
337              }
338            }
339            break;
340          case OpCodes.FresnelSineIntegral: {
341              double c = 0, s = 0;
342              var x = code[instr.childIndex].value;
343              if (double.IsNaN(x)) instr.value = double.NaN;
344              else {
345                alglib.fresnelintegral(x, ref c, ref s);
346                instr.value = s;
347              }
348            }
349            break;
350          case OpCodes.AiryA: {
351              double ai, aip, bi, bip;
352              var x = code[instr.childIndex].value;
353              if (double.IsNaN(x)) instr.value = double.NaN;
354              else {
355                alglib.airy(x, out ai, out aip, out bi, out bip);
356                instr.value = ai;
357              }
358            }
359            break;
360          case OpCodes.AiryB: {
361              double ai, aip, bi, bip;
362              var x = code[instr.childIndex].value;
363              if (double.IsNaN(x)) instr.value = double.NaN;
364              else {
365                alglib.airy(x, out ai, out aip, out bi, out bip);
366                instr.value = bi;
367              }
368            }
369            break;
370          case OpCodes.Norm: {
371              var x = code[instr.childIndex].value;
372              if (double.IsNaN(x)) instr.value = double.NaN;
373              else instr.value = alglib.normaldistribution(x);
374            }
375            break;
376          case OpCodes.Erf: {
377              var x = code[instr.childIndex].value;
378              if (double.IsNaN(x)) instr.value = double.NaN;
379              else instr.value = alglib.errorfunction(x);
380            }
381            break;
382          case OpCodes.Bessel: {
383              var x = code[instr.childIndex].value;
384              if (double.IsNaN(x)) instr.value = double.NaN;
385              else instr.value = alglib.besseli0(x);
386            }
387            break;
388          case OpCodes.IfThenElse: {
389              double condition = code[instr.childIndex].value;
390              double result;
391              if (condition > 0.0) {
392                result = code[instr.childIndex + 1].value;
393              } else {
394                result = code[instr.childIndex + 2].value;
395              }
396              instr.value = result;
397            }
398            break;
399          case OpCodes.AND: {
400              double result = code[instr.childIndex].value;
401              for (int j = 1; j < instr.nArguments; j++) {
402                if (result > 0.0) result = code[instr.childIndex + j].value;
403                else break;
404              }
405              instr.value = result > 0.0 ? 1.0 : -1.0;
406            }
407            break;
408          case OpCodes.OR: {
409              double result = code[instr.childIndex].value;
410              for (int j = 1; j < instr.nArguments; j++) {
411                if (result <= 0.0) result = code[instr.childIndex + j].value;
412                else break;
413              }
414              instr.value = result > 0.0 ? 1.0 : -1.0;
415            }
416            break;
417          case OpCodes.NOT: {
418              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
419            }
420            break;
421          case OpCodes.GT: {
422              double x = code[instr.childIndex].value;
423              double y = code[instr.childIndex + 1].value;
424              instr.value = x > y ? 1.0 : -1.0;
425            }
426            break;
427          case OpCodes.LT: {
428              double x = code[instr.childIndex].value;
429              double y = code[instr.childIndex + 1].value;
430              instr.value = x < y ? 1.0 : -1.0;
431            }
432            break;
433          default:
434            var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use another interpreter.", instr.dynamicNode.Symbol.Name);
435            throw new NotSupportedException(errorText);
436        }
437        #endregion
438      }
439      return code[0].value;
440    }
441  }
442}
Note: See TracBrowser for help on using the repository browser.