Free cookie consent management tool by TermsFeed Policy Generator

source: branches/1772_HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 17578

Last change on this file since 17578 was 17434, checked in by bburlacu, 5 years ago

#1772: Merge trunk changes and fix all errors and compilation warnings.

File size: 22.8 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HEAL.Attic;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Interpreter that compiles a symbolic expression tree into a flat array of
  /// <see cref="LinearInstruction"/>s and evaluates that array with a single loop
  /// instead of recursing over the tree. Time-lag, derivative and integral symbols are
  /// delegated to a <see cref="SymbolicDataAnalysisExpressionTreeInterpreter"/>.
  /// Automatically defined functions (ADFs) are not supported.
  /// </summary>
  [StorableType("EF325166-E03A-44C4-83CE-7F07B836285E")]
  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string EvaluatedSolutionsParameterDescription = "A counter for the total number of solutions the interpreter has evaluated";

    // Interpreter used for the opcodes that are evaluated through an InterpreterState
    // (TimeLag, Derivative, Integral).
    private readonly SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

    /// <summary>The name of this item is fixed.</summary>
    public override bool CanChangeName {
      get { return false; }
    }

    /// <summary>The description of this item is fixed.</summary>
    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Parameter holding the interval-arithmetic switch.</summary>
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the evaluated-solutions counter.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// Value of the interval-arithmetic switch. When it is <c>true</c>,
    /// <see cref="GetSymbolicExpressionTreeValues"/> throws a <see cref="NotSupportedException"/>.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>
    /// Counter that is incremented once per call to <see cref="GetSymbolicExpressionTreeValues"/>
    /// that reaches the evaluation stage.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework.</summary>
    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(StorableConstructorFlag _) : base(_) {
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    /// <summary>Cloning constructor; the nested interpreter is cloned through <paramref name="cloner"/>.</summary>
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
      : base(original, cloner) {
      interpreter = cloner.Clone(original.interpreter);
    }

    /// <summary>Creates a deep clone of this interpreter.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
    }

    /// <summary>Creates an interpreter with the default name and description.</summary>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
    }

    /// <summary>
    /// Creates an interpreter with the given name and description. The interval-arithmetic
    /// switch starts as <c>false</c> and the evaluated-solutions counter as 0.
    /// </summary>
    /// <param name="name">Item name.</param>
    /// <param name="description">Item description.</param>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    /// <summary>
    /// Re-creates both parameters as <see cref="FixedValueParameter{T}"/> instances after
    /// deserialization. A value found in an already present parameter is carried over;
    /// otherwise the defaults (0 and <c>false</c>) are used.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      var evaluatedSolutions = new IntValue(0);
      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
        evaluatedSolutions = evaluatedSolutionsParameter.Value;
        Parameters.Remove(EvaluatedSolutionsParameterName);
      }
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, evaluatedSolutions));
      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
      }
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to 0.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    /// <summary>Does nothing; there is no state to release.</summary>
    public void ClearState() { }
    #endregion

    // Guards the increment of the evaluated-solutions counter.
    private readonly object syncRoot = new object();

    /// <summary>
    /// Compiles <paramref name="tree"/> to linear code and returns its value for each of the given rows.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate.</param>
    /// <param name="dataset">The dataset the variable symbols read from.</param>
    /// <param name="rows">Row indices to evaluate. Enumerated once by the emptiness check and again by the returned sequence.</param>
    /// <returns>
    /// An empty sequence when <paramref name="rows"/> is empty; otherwise a lazily evaluated
    /// sequence with one value per row. All values are computed on the same compiled instruction array.
    /// </returns>
    /// <exception cref="NotSupportedException">
    /// <see cref="CheckExpressionsWithIntervalArithmetic"/> is <c>true</c>, or (during enumeration)
    /// the tree contains a symbol this interpreter does not handle.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
      if (!rows.Any()) return Enumerable.Empty<double>();
      if (CheckExpressionsWithIntervalArithmetic)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }

      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates the subtree rooted at <paramref name="node"/> for each of the given rows.
    /// Unlike <see cref="GetSymbolicExpressionTreeValues"/> this neither checks the
    /// interval-arithmetic switch nor increments <see cref="EvaluatedSolutions"/>.
    /// </summary>
    /// <remarks>
    /// NOTE: do not use this method when evaluating trees. This method is provided as a
    /// shortcut for evaluating subtrees ad-hoc.
    /// </remarks>
    /// <param name="node">Root node of the subtree to evaluate.</param>
    /// <param name="dataset">The dataset the variable symbols read from.</param>
    /// <param name="rows">Row indices to evaluate.</param>
    /// <returns>A lazily evaluated sequence with one value per row.</returns>
    public IEnumerable<double> GetValues(ISymbolicExpressionTreeNode node, IDataset dataset, IEnumerable<int> rows) {
      var code = SymbolicExpressionTreeLinearCompiler.Compile(node, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates prepared linear code for a single row and returns the value of instruction 0.
    /// </summary>
    /// <param name="dataset">Dataset used for the row-range checks and by the delegated interpreter.</param>
    /// <param name="row">Row index; variable opcodes yield NaN when the (possibly lagged) row lies outside the dataset.</param>
    /// <param name="code">Instructions that have been passed through <see cref="PrepareInstructions"/>.</param>
    /// <returns>The value stored in <c>code[0]</c> after all non-skipped instructions were evaluated.</returns>
    /// <exception cref="NotSupportedException">A non-skipped instruction has an opcode that is not handled here.</exception>
    private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
      // Walk the array from the back; every instruction reads the values of the instructions at
      // childIndex .. childIndex + nArguments - 1 (presumably the compiler stores children at
      // higher indices than their parent, so they have been evaluated already - confirm in the compiler).
      for (int i = code.Length - 1; i >= 0; --i) {
        if (code[i].skip) continue;
        #region opcode if
        var instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
          else {
            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
          }
        } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
          else {
            var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<string>)instr.data)[row] == factorTreeNode.VariableValue ? factorTreeNode.Weight : 0;
          }
        } else if (instr.opCode == OpCodes.FactorVariable) {
          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
          else {
            var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
            instr.value = factorTreeNode.GetValue(((IList<string>)instr.data)[row]);
          }
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          int actualRow = row + laggedVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows)
            instr.value = double.NaN;
          else
            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          if (row < 0 || row >= dataset.Rows) {
            // Out-of-range rows yield NaN, like the other variable opcodes. The data must not be
            // indexed in this case, hence the else branch below.
            instr.value = double.NaN;
          } else {
            var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
            double variableValue = ((IList<double>)instr.data)[row];
            if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
              // soft switch: blend both branches with a logistic weight
              double x = variableValue - variableConditionTreeNode.Threshold;
              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

              double trueBranch = code[instr.childIndex].value;
              double falseBranch = code[instr.childIndex + 1].value;

              instr.value = trueBranch * p + falseBranch * (1 - p);
            } else {
              // hard switch on the threshold
              if (variableValue <= variableConditionTreeNode.Threshold) {
                instr.value = code[instr.childIndex].value;
              } else {
                instr.value = code[instr.childIndex + 1].value;
              }
            }
          }
        } else if (instr.opCode == OpCodes.Add) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s;
        } else if (instr.opCode == OpCodes.Sub) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s -= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) s = -s; // unary minus
          instr.value = s;
        } else if (instr.opCode == OpCodes.Mul) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p *= code[instr.childIndex + j].value;
          }
          instr.value = p;
        } else if (instr.opCode == OpCodes.Div) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p /= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) p = 1.0 / p; // unary division is the reciprocal
          instr.value = p;
        } else if (instr.opCode == OpCodes.AnalyticQuotient) {
          var x1 = code[instr.childIndex].value;
          var x2 = code[instr.childIndex + 1].value;
          instr.value = x1 / Math.Sqrt(1 + x2 * x2);
        } else if (instr.opCode == OpCodes.Average) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s / instr.nArguments;
        } else if (instr.opCode == OpCodes.Absolute) {
          instr.value = Math.Abs(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tanh) {
          instr.value = Math.Tanh(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Cos) {
          instr.value = Math.Cos(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Sin) {
          instr.value = Math.Sin(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tan) {
          instr.value = Math.Tan(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Square) {
          instr.value = Math.Pow(code[instr.childIndex].value, 2);
        } else if (instr.opCode == OpCodes.Cube) {
          instr.value = Math.Pow(code[instr.childIndex].value, 3);
        } else if (instr.opCode == OpCodes.Power) {
          // the exponent is rounded to the nearest integer
          double x = code[instr.childIndex].value;
          double y = Math.Round(code[instr.childIndex + 1].value);
          instr.value = Math.Pow(x, y);
        } else if (instr.opCode == OpCodes.SquareRoot) {
          instr.value = Math.Sqrt(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.CubeRoot) {
          // defined for negative arguments as well: cbrt(-a) = -cbrt(a)
          var arg = code[instr.childIndex].value;
          instr.value = arg < 0 ? -Math.Pow(-arg, 1.0 / 3.0) : Math.Pow(arg, 1.0 / 3.0);
        } else if (instr.opCode == OpCodes.Root) {
          // the root degree is rounded to the nearest integer
          double x = code[instr.childIndex].value;
          double y = Math.Round(code[instr.childIndex + 1].value);
          instr.value = Math.Pow(x, 1 / y);
        } else if (instr.opCode == OpCodes.Exp) {
          instr.value = Math.Exp(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Log) {
          instr.value = Math.Log(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Gamma) {
          // NaN arguments are not passed to alglib (here and for the special functions below)
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
        } else if (instr.opCode == OpCodes.Psi) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN; // non-positive integers yield NaN
          else instr.value = alglib.psi(x);
        } else if (instr.opCode == OpCodes.Dawson) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
        } else if (instr.opCode == OpCodes.SineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = si;
          }
        } else if (instr.opCode == OpCodes.CosineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = ci;
          }
        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = shi;
          }
        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = chi;
          }
        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = c;
          }
        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = s;
          }
        } else if (instr.opCode == OpCodes.AiryA) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = ai;
          }
        } else if (instr.opCode == OpCodes.AiryB) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = bi;
          }
        } else if (instr.opCode == OpCodes.Norm) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.normaldistribution(x);
        } else if (instr.opCode == OpCodes.Erf) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.errorfunction(x);
        } else if (instr.opCode == OpCodes.Bessel) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.besseli0(x);
        } else if (instr.opCode == OpCodes.IfThenElse) {
          // a condition value > 0 selects the second child, anything else the third
          double condition = code[instr.childIndex].value;
          double result;
          if (condition > 0.0) {
            result = code[instr.childIndex + 1].value;
          } else {
            result = code[instr.childIndex + 2].value;
          }
          instr.value = result;
        } else if (instr.opCode == OpCodes.AND) {
          // logical opcodes treat values > 0 as true and produce 1.0 (true) or -1.0 (false)
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result > 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.OR) {
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result <= 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.NOT) {
          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
        } else if (instr.opCode == OpCodes.XOR) {
          // true when an odd number of arguments is positive
          int positiveSignals = 0;
          for (int j = 0; j < instr.nArguments; j++) {
            if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
          }
          instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.GT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x > y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.LT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x < y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
          // evaluated by the nested interpreter on the state built in PrepareInstructions
          var state = (InterpreterState)instr.data;
          state.Reset();
          instr.value = interpreter.Evaluate(dataset, ref row, state);
        } else {
          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
          throw new NotSupportedException(errorText);
        }
        #endregion
      }
      return code[0].value;
    }

    /// <summary>
    /// Collects the instruction at <paramref name="startIndex"/> and all instructions reachable
    /// from it through childIndex/nArguments, in depth-first pre-order (parent first, children
    /// left to right).
    /// </summary>
    /// <param name="code">The linear code to traverse.</param>
    /// <param name="startIndex">Index of the root instruction of the sequence.</param>
    /// <returns>A new array that references the same instruction objects as <paramref name="code"/>.</returns>
    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
      var s = new Stack<int>();
      var list = new List<LinearInstruction>();
      s.Push(startIndex);
      while (s.Count > 0) {
        int i = s.Pop();
        var instr = code[i];
        // push instructions in reverse execution order
        for (int j = instr.nArguments - 1; j >= 0; j--) s.Push(instr.childIndex + j);
        list.Add(instr);
      }
      return list.ToArray();
    }

    /// <summary>
    /// Prepares compiled instructions for evaluation: constants get their value assigned and are
    /// marked as skipped, variable opcodes get the dataset column attached as <c>data</c>, and
    /// TimeLag/Integral/Derivative opcodes get an <see cref="InterpreterState"/> over their prefix
    /// sequence while all instructions below them are marked as skipped.
    /// </summary>
    /// <param name="code">The instructions to prepare; they are modified in place.</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
      for (int i = 0; i != code.Length; ++i) {
        var instr = code[i];
        #region opcode switch
        switch (instr.opCode) {
          case OpCodes.Constant: {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
            }
            break;
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.BinaryFactorVariable: {
              var factorVariableTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyStringValues(factorVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.FactorVariable: {
              var factorVariableTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyStringValues(factorVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
          case OpCodes.TimeLag:
          case OpCodes.Integral:
          case OpCodes.Derivative: {
              var seq = GetPrefixSequence(code, i);
              var interpreterState = new InterpreterState(seq, 0);
              instr.data = interpreterState;
              // seq[0] is this instruction; everything below it is evaluated through the
              // interpreter state and must be skipped by the linear evaluation loop
              for (int j = 1; j != seq.Length; ++j)
                seq[j].skip = true;
              break;
            }
        }
        #endregion
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.