Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 16130

Last change on this file since 16130 was 16130, checked in by bburlacu, 6 years ago

#1772: Merge trunk changes

File size: 22.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
// Key under which the evaluated-solutions counter is stored in the Parameters collection.
private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

// Key and description of the interval-arithmetic switch stored in the Parameters collection.
private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";

// Interpreter that Evaluate delegates to for the TimeLag, Derivative and Integral opcodes.
private readonly SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
41
/// <summary>
/// Always <c>false</c>.
/// </summary>
public override bool CanChangeName {
  get {
    return false;
  }
}

/// <summary>
/// Always <c>false</c>.
/// </summary>
public override bool CanChangeDescription {
  get {
    return false;
  }
}
49
#region parameter properties
/// <summary>
/// The parameter registered under <c>CheckExpressionsWithIntervalArithmeticParameterName</c>.
/// </summary>
public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
  get {
    var parameter = Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}

/// <summary>
/// The parameter registered under <c>EvaluatedSolutionsParameterName</c>.
/// </summary>
public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
  get {
    var parameter = Parameters[EvaluatedSolutionsParameterName];
    return (IFixedValueParameter<IntValue>)parameter;
  }
}
#endregion
59
#region properties
/// <summary>
/// Reads or writes the boolean held by <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>.
/// </summary>
public bool CheckExpressionsWithIntervalArithmetic {
  get {
    BoolValue flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    return flag.Value;
  }
  set {
    BoolValue flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    flag.Value = value;
  }
}

/// <summary>
/// Reads or writes the integer held by <see cref="EvaluatedSolutionsParameter"/>.
/// </summary>
public int EvaluatedSolutions {
  get {
    IntValue counter = EvaluatedSolutionsParameter.Value;
    return counter.Value;
  }
  set {
    IntValue counter = EvaluatedSolutionsParameter.Value;
    counter.Value = value;
  }
}
#endregion
70
/// <summary>
/// Constructor marked with <c>StorableConstructor</c>; forwards the flag to the base class
/// and creates a fresh helper interpreter.
/// </summary>
[StorableConstructor]
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing) : base(deserializing) {
  this.interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}
76
/// <summary>
/// Cloning constructor: lets the base class copy its state and clones the helper interpreter
/// of <paramref name="original"/> through <paramref name="cloner"/>.
/// </summary>
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner) : base(original, cloner) {
  var originalInterpreter = original.interpreter;
  this.interpreter = cloner.Clone(originalInterpreter);
}
81
/// <summary>
/// Creates a copy of this interpreter via the private cloning constructor.
/// </summary>
/// <param name="cloner">The cloner passed on to the cloning constructor.</param>
/// <returns>The newly created copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
  return copy;
}
85
/// <summary>
/// Creates an interpreter with the default item name and description.
/// Parameter registration and helper-interpreter creation are done by the
/// (name, description) overload this constructor delegates to.
/// </summary>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
  : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
}
92
/// <summary>
/// Creates an interpreter with the given item name and description, registers the
/// interval-arithmetic switch (initially false) and the evaluated-solutions counter
/// (initially 0), and creates the helper interpreter.
/// </summary>
/// <param name="name">Item name passed to the base class.</param>
/// <param name="description">Item description passed to the base class.</param>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
  : base(name, description) {
  var intervalCheckParameter = new FixedValueParameter<BoolValue>(
    CheckExpressionsWithIntervalArithmeticParameterName,
    CheckExpressionsWithIntervalArithmeticParameterDescription,
    new BoolValue(false));
  var evaluatedSolutionsParameter = new FixedValueParameter<IntValue>(
    EvaluatedSolutionsParameterName,
    "A counter for the total number of solutions the interpreter has evaluated",
    new IntValue(0));

  // Registration order is kept: switch first, counter second.
  Parameters.Add(intervalCheckParameter);
  Parameters.Add(evaluatedSolutionsParameter);

  this.interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}
99
/// <summary>
/// After-deserialization hook. Re-registers both parameters as fixed-value parameters:
/// if a parameter with the expected name is present, its current value is carried over
/// and the old parameter is removed; otherwise a default value (0 / false) is used.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  // --- evaluated-solutions counter ---
  var solutionCounter = new IntValue(0);
  if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
    var storedCounter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    solutionCounter = storedCounter.Value;
    Parameters.Remove(EvaluatedSolutionsParameterName);
  }
  Parameters.Add(new FixedValueParameter<IntValue>(
    EvaluatedSolutionsParameterName,
    "A counter for the total number of solutions the interpreter has evaluated",
    solutionCounter));

  // --- interval-arithmetic switch ---
  var intervalCheck = new BoolValue(false);
  if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
    var storedSwitch = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    intervalCheck = storedSwitch.Value;
    Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
  }
  Parameters.Add(new FixedValueParameter<BoolValue>(
    CheckExpressionsWithIntervalArithmeticParameterName,
    CheckExpressionsWithIntervalArithmeticParameterDescription,
    intervalCheck));
}
117
#region IStatefulItem
/// <summary>
/// Sets the evaluated-solutions counter back to zero.
/// </summary>
public void InitializeState() {
  EvaluatedSolutionsParameter.Value.Value = 0;
}

/// <summary>
/// Does nothing.
/// </summary>
public void ClearState() {
  // intentionally empty
}
#endregion
125
// Guards the increment of the evaluated-solutions counter.
private readonly object syncRoot = new object();

/// <summary>
/// Compiles <paramref name="tree"/> into linear instructions and returns a lazily evaluated
/// sequence with one value per entry of <paramref name="rows"/>.
/// </summary>
/// <remarks>
/// The evaluated-solutions counter is incremented once per call (not per row). Compilation and
/// instruction preparation happen immediately; the per-row evaluation is deferred until the
/// returned sequence is enumerated. Note that <paramref name="rows"/> is enumerated once up
/// front by the emptiness check.
/// </remarks>
/// <param name="tree">The expression tree to evaluate.</param>
/// <param name="dataset">The dataset the variables are read from.</param>
/// <param name="rows">The row indices to evaluate.</param>
/// <returns>An empty sequence when <paramref name="rows"/> is empty, otherwise one value per row.</returns>
/// <exception cref="NotSupportedException">
/// <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled and <paramref name="rows"/> is not empty.
/// </exception>
public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
  bool hasRows = rows.Any();
  if (!hasRows) {
    return Enumerable.Empty<double>();
  }

  if (CheckExpressionsWithIntervalArithmetic) {
    throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
  }

  lock (syncRoot) {
    EvaluatedSolutions = EvaluatedSolutions + 1;
  }

  var instructions = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
  PrepareInstructions(instructions, dataset);
  return from row in rows
         select Evaluate(dataset, row, instructions);
}
140
/// <summary>
/// Compiles the subtree rooted at <paramref name="node"/> and returns a lazily evaluated
/// sequence with one value per entry of <paramref name="rows"/>.
/// </summary>
/// <remarks>
/// NOTE: do not use this method when evaluating trees. It is provided as a shortcut for
/// evaluating subtrees ad-hoc. Unlike <c>GetSymbolicExpressionTreeValues</c> it neither
/// touches the evaluated-solutions counter nor checks the interval-arithmetic switch.
/// </remarks>
/// <param name="node">Root node of the subtree to evaluate.</param>
/// <param name="dataset">The dataset the variables are read from.</param>
/// <param name="rows">The row indices to evaluate.</param>
/// <returns>One value per row, computed when the sequence is enumerated.</returns>
public IEnumerable<double> GetValues(ISymbolicExpressionTreeNode node, IDataset dataset, IEnumerable<int> rows) {
  var instructions = SymbolicExpressionTreeLinearCompiler.Compile(node, OpCodes.MapSymbolToOpCode);
  PrepareInstructions(instructions, dataset);
  return from row in rows
         select Evaluate(dataset, row, instructions);
}
147
/// <summary>
/// Evaluates a prepared instruction array for a single dataset row.
/// </summary>
/// <remarks>
/// Instructions are processed from the last array index down to index 0. Each instruction reads the
/// stored <c>value</c> of its arguments at <c>childIndex</c> .. <c>childIndex + nArguments - 1</c>, so
/// arguments are expected to sit at higher indices than the instruction using them (presumably
/// guaranteed by the linear compiler - confirm there). Instructions flagged with <c>skip</c> keep
/// whatever value they already hold.
/// Variable-based opcodes (Variable, BinaryFactorVariable, FactorVariable, LagVariable,
/// VariableCondition) produce NaN when the (possibly lagged) row lies outside the dataset.
/// Boolean opcodes treat values greater than zero as true and produce 1.0 (true) or -1.0 (false).
/// TimeLag, Derivative and Integral are delegated to the helper interpreter using the
/// interpreter state stored in <c>instr.data</c>.
/// </remarks>
/// <param name="dataset">Dataset used for the row-range check and passed to the helper interpreter.</param>
/// <param name="row">Index of the row to evaluate.</param>
/// <param name="code">Instructions that were already processed by <c>PrepareInstructions</c>.</param>
/// <returns>The value stored in <c>code[0]</c> after all instructions have been processed.</returns>
/// <exception cref="NotSupportedException">An instruction has an opcode this interpreter does not handle.</exception>
private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
  for (int i = code.Length - 1; i >= 0; --i) {
    if (code[i].skip) continue;
    #region opcode if
    var instr = code[i];
    if (instr.opCode == OpCodes.Variable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
      }
    } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<string>)instr.data)[row] == factorTreeNode.VariableValue ? factorTreeNode.Weight : 0;
      }
    } else if (instr.opCode == OpCodes.FactorVariable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
        instr.value = factorTreeNode.GetValue(((IList<string>)instr.data)[row]);
      }
    } else if (instr.opCode == OpCodes.LagVariable) {
      var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
      int actualRow = row + laggedVariableTreeNode.Lag;
      if (actualRow < 0 || actualRow >= dataset.Rows)
        instr.value = double.NaN;
      else
        instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
    } else if (instr.opCode == OpCodes.VariableCondition) {
      if (row < 0 || row >= dataset.Rows) {
        // Out-of-range rows yield NaN like the other variable opcodes; the data must not be
        // indexed in this case (previously the code fell through and indexed it anyway).
        instr.value = double.NaN;
      } else {
        var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
        double variableValue = ((IList<double>)instr.data)[row];
        if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
          // Soft switch: blend both branches with a logistic weight around the threshold.
          double x = variableValue - variableConditionTreeNode.Threshold;
          double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

          double trueBranch = code[instr.childIndex].value;
          double falseBranch = code[instr.childIndex + 1].value;

          instr.value = trueBranch * p + falseBranch * (1 - p);
        } else {
          // Hard switch: first child when value <= threshold, second child otherwise.
          if (variableValue <= variableConditionTreeNode.Threshold) {
            instr.value = code[instr.childIndex].value;
          } else {
            instr.value = code[instr.childIndex + 1].value;
          }
        }
      }
    } else if (instr.opCode == OpCodes.Add) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s;
    } else if (instr.opCode == OpCodes.Sub) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s -= code[instr.childIndex + j].value;
      }
      // unary minus
      if (instr.nArguments == 1) s = -s;
      instr.value = s;
    } else if (instr.opCode == OpCodes.Mul) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p *= code[instr.childIndex + j].value;
      }
      instr.value = p;
    } else if (instr.opCode == OpCodes.Div) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p /= code[instr.childIndex + j].value;
      }
      // a single argument means reciprocal
      if (instr.nArguments == 1) p = 1.0 / p;
      instr.value = p;
    } else if (instr.opCode == OpCodes.Average) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s / instr.nArguments;
    } else if (instr.opCode == OpCodes.Cos) {
      instr.value = Math.Cos(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Sin) {
      instr.value = Math.Sin(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Tan) {
      instr.value = Math.Tan(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Square) {
      instr.value = Math.Pow(code[instr.childIndex].value, 2);
    } else if (instr.opCode == OpCodes.Power) {
      // the exponent is rounded to the nearest integer value
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, y);
    } else if (instr.opCode == OpCodes.SquareRoot) {
      instr.value = Math.Sqrt(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Root) {
      // the root degree is rounded to the nearest integer value
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, 1 / y);
    } else if (instr.opCode == OpCodes.Exp) {
      instr.value = Math.Exp(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Log) {
      instr.value = Math.Log(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Gamma) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
    } else if (instr.opCode == OpCodes.Psi) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      // non-positive integers are mapped to NaN instead of being passed to alglib.psi
      else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
      else instr.value = alglib.psi(x);
    } else if (instr.opCode == OpCodes.Dawson) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
    } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
    } else if (instr.opCode == OpCodes.SineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = si;
      }
    } else if (instr.opCode == OpCodes.CosineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = ci;
      }
    } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = shi;
      }
    } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = chi;
      }
    } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = c;
      }
    } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = s;
      }
    } else if (instr.opCode == OpCodes.AiryA) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = ai;
      }
    } else if (instr.opCode == OpCodes.AiryB) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = bi;
      }
    } else if (instr.opCode == OpCodes.Norm) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.normaldistribution(x);
    } else if (instr.opCode == OpCodes.Erf) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.errorfunction(x);
    } else if (instr.opCode == OpCodes.Bessel) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.besseli0(x);
    } else if (instr.opCode == OpCodes.IfThenElse) {
      double condition = code[instr.childIndex].value;
      double result;
      if (condition > 0.0) {
        result = code[instr.childIndex + 1].value;
      } else {
        result = code[instr.childIndex + 2].value;
      }
      instr.value = result;
    } else if (instr.opCode == OpCodes.AND) {
      // stop at the first argument that is not true
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result > 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.OR) {
      // stop at the first argument that is true
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result <= 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.NOT) {
      instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
    } else if (instr.opCode == OpCodes.XOR) {
      // true when an odd number of arguments is true
      int positiveSignals = 0;
      for (int j = 0; j < instr.nArguments; j++) {
        if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
      }
      instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.GT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x > y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.LT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x < y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
      // these subtrees are evaluated by the helper interpreter from the state built in PrepareInstructions
      var state = (InterpreterState)instr.data;
      state.Reset();
      instr.value = interpreter.Evaluate(dataset, ref row, state);
    } else {
      var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
      throw new NotSupportedException(errorText);
    }
    #endregion
  }
  return code[0].value;
}
393
/// <summary>
/// Collects the instruction at <paramref name="startIndex"/> and all instructions reachable
/// through <c>childIndex</c>/<c>nArguments</c> in prefix order (an instruction first, then its
/// arguments from first to last, each followed by its own arguments).
/// </summary>
/// <param name="code">The complete instruction array.</param>
/// <param name="startIndex">Index of the instruction the sequence starts with.</param>
/// <returns>A new array referencing the collected instructions.</returns>
private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
  var pending = new Stack<int>();
  var prefix = new List<LinearInstruction>();

  pending.Push(startIndex);
  while (pending.Count > 0) {
    var current = code[pending.Pop()];
    prefix.Add(current);

    // Push the last argument first so that the first argument is popped next.
    int firstChild = current.childIndex;
    for (int child = firstChild + current.nArguments - 1; child >= firstChild; child--) {
      pending.Push(child);
    }
  }

  return prefix.ToArray();
}
407
/// <summary>
/// Prepares compiled instructions for evaluation against <paramref name="dataset"/>.
/// </summary>
/// <remarks>
/// Constants get their value stored directly and are flagged with <c>skip</c>.
/// Variable-based instructions get the read-only value list of their variable stored in <c>data</c>
/// (double values for Variable, LagVariable and VariableCondition; string values for
/// BinaryFactorVariable and FactorVariable).
/// TimeLag, Integral and Derivative instructions get an interpreter state built from their prefix
/// sequence stored in <c>data</c>, and every instruction of that sequence except the first is
/// flagged with <c>skip</c>. All other opcodes are left unchanged.
/// </remarks>
/// <param name="code">The instructions to prepare; they are modified in place.</param>
/// <param name="dataset">The dataset the variable values are taken from.</param>
public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
  for (int index = 0; index < code.Length; index++) {
    var current = code[index];
    var opCode = current.opCode;

    if (opCode == OpCodes.Constant) {
      var constantNode = (ConstantTreeNode)current.dynamicNode;
      current.value = constantNode.Value;
      // the value is already set so this instruction should be skipped in the evaluation phase
      current.skip = true;
    } else if (opCode == OpCodes.Variable) {
      var variableNode = (VariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(variableNode.VariableName);
    } else if (opCode == OpCodes.BinaryFactorVariable) {
      var binaryFactorNode = current.dynamicNode as BinaryFactorVariableTreeNode;
      current.data = dataset.GetReadOnlyStringValues(binaryFactorNode.VariableName);
    } else if (opCode == OpCodes.FactorVariable) {
      var factorNode = current.dynamicNode as FactorVariableTreeNode;
      current.data = dataset.GetReadOnlyStringValues(factorNode.VariableName);
    } else if (opCode == OpCodes.LagVariable) {
      var laggedNode = (LaggedVariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(laggedNode.VariableName);
    } else if (opCode == OpCodes.VariableCondition) {
      var conditionNode = (VariableConditionTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(conditionNode.VariableName);
    } else if (opCode == OpCodes.TimeLag || opCode == OpCodes.Integral || opCode == OpCodes.Derivative) {
      var sequence = GetPrefixSequence(code, index);
      current.data = new InterpreterState(sequence, 0);
      // sequence[0] is the current instruction itself; only the nested instructions are skipped
      for (int nested = 1; nested < sequence.Length; nested++) {
        sequence[nested].skip = true;
      }
    }
  }
}
458  }
459}
Note: See TracBrowser for help on using the repository browser.