Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2915-AbsoluteSymbol/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 15944

Last change on this file since 15944 was 15944, checked in by gkronber, 6 years ago

#2915 added support for Abs() symbol to tree interpreter and linear interpreter as well as to the infix parser

File size: 21.7 KB
RevLine 
[5571]1#region License Information
2/* HeuristicLab
[15583]3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5571]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
[9739]24using System.Linq;
[5571]25using HeuristicLab.Common;
26using HeuristicLab.Core;
[6740]27using HeuristicLab.Data;
[5571]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[6740]29using HeuristicLab.Parameters;
[5571]30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
[9815]34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
[9758]35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
[5749]36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
[13248]37    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
[7615]38    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
[5571]39
[14282]40    private readonly SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
[9776]41
/// <summary>
/// The name of this interpreter is fixed; always returns <c>false</c>.
/// </summary>
public override bool CanChangeName {
  get {
    return false;
  }
}

/// <summary>
/// The description of this interpreter is fixed; always returns <c>false</c>.
/// </summary>
public override bool CanChangeDescription {
  get {
    return false;
  }
}
49
#region parameter properties
/// <summary>
/// Parameter holding the switch that requests an interval-arithmetic validity check
/// before an expression is evaluated.
/// </summary>
public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
  get {
    var parameter = Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}

/// <summary>
/// Parameter holding the counter of solutions evaluated by this interpreter.
/// </summary>
public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
  get {
    var parameter = Parameters[EvaluatedSolutionsParameterName];
    return (IFixedValueParameter<IntValue>)parameter;
  }
}
#endregion
59
#region properties
/// <summary>
/// Gets or sets the flag stored in <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>.
/// </summary>
public bool CheckExpressionsWithIntervalArithmetic {
  get {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    return flag.Value;
  }
  set {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    flag.Value = value;
  }
}

/// <summary>
/// Gets or sets the counter stored in <see cref="EvaluatedSolutionsParameter"/>.
/// </summary>
public int EvaluatedSolutions {
  get {
    var counter = EvaluatedSolutionsParameter.Value;
    return counter.Value;
  }
  set {
    var counter = EvaluatedSolutionsParameter.Value;
    counter.Value = value;
  }
}
#endregion
70
/// <summary>
/// Constructor used by the persistence layer. A fresh helper interpreter is created
/// because the <c>interpreter</c> field is assigned here rather than restored.
/// </summary>
[StorableConstructor]
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
  : base(deserializing) {
  this.interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}

/// <summary>
/// Cloning constructor: copies the base state through <paramref name="cloner"/> and
/// clones the helper interpreter of <paramref name="original"/>.
/// </summary>
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
  : base(original, cloner) {
  this.interpreter = cloner.Clone(original.interpreter);
}

/// <summary>
/// Creates a deep clone of this interpreter using the cloning constructor.
/// </summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
  return copy;
}
85
/// <summary>
/// Creates an interpreter with the default name and description.
/// Delegates to the (name, description) constructor so the parameter setup lives in one place.
/// </summary>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
  : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
}

/// <summary>
/// Creates an interpreter with the given name and description, registers the
/// interval-arithmetic switch (initially false) and the evaluated-solutions counter
/// (initially 0), and creates the helper interpreter.
/// </summary>
/// <param name="name">Item name passed to the base class.</param>
/// <param name="description">Item description passed to the base class.</param>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
  : base(name, description) {
  var intervalCheckParameter = new FixedValueParameter<BoolValue>(
    CheckExpressionsWithIntervalArithmeticParameterName,
    CheckExpressionsWithIntervalArithmeticParameterDescription,
    new BoolValue(false));
  Parameters.Add(intervalCheckParameter);

  var evaluatedSolutionsParameter = new FixedValueParameter<IntValue>(
    EvaluatedSolutionsParameterName,
    "A counter for the total number of solutions the interpreter has evaluated",
    new IntValue(0));
  Parameters.Add(evaluatedSolutionsParameter);

  interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}
99
/// <summary>
/// Post-deserialization hook. Replaces both parameters with freshly created
/// <c>FixedValueParameter</c> instances, reusing the stored values when the parameters
/// exist and falling back to 0 / false when they do not.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  // evaluated-solutions counter: keep the stored value if there is one
  var counterValue = new IntValue(0);
  bool hasCounter = Parameters.ContainsKey(EvaluatedSolutionsParameterName);
  if (hasCounter) {
    var storedCounter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    counterValue = storedCounter.Value;
    Parameters.Remove(EvaluatedSolutionsParameterName);
  }
  var counterParameter = new FixedValueParameter<IntValue>(
    EvaluatedSolutionsParameterName,
    "A counter for the total number of solutions the interpreter has evaluated",
    counterValue);
  Parameters.Add(counterParameter);

  // interval-arithmetic switch: same procedure
  var switchValue = new BoolValue(false);
  bool hasSwitch = Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName);
  if (hasSwitch) {
    var storedSwitch = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
    switchValue = storedSwitch.Value;
  }
  var switchParameter = new FixedValueParameter<BoolValue>(
    CheckExpressionsWithIntervalArithmeticParameterName,
    CheckExpressionsWithIntervalArithmeticParameterDescription,
    switchValue);
  Parameters.Add(switchParameter);
}
117
#region IStatefulItem
/// <summary>
/// Resets the evaluated-solutions counter to zero.
/// </summary>
public void InitializeState() {
  EvaluatedSolutionsParameter.Value.Value = 0;
}

/// <summary>
/// Intentionally empty: this method performs no work.
/// </summary>
public void ClearState() {
}
#endregion
125
// guards the increment of the evaluated-solutions counter
private readonly object syncRoot = new object();

/// <summary>
/// Compiles <paramref name="tree"/> into linear instructions and returns a lazily
/// evaluated sequence with one value per entry of <paramref name="rows"/>.
/// </summary>
/// <param name="tree">The symbolic expression tree to evaluate.</param>
/// <param name="dataset">The dataset the variable values are read from.</param>
/// <param name="rows">Row indices to evaluate; an empty sequence yields an empty result.</param>
/// <returns>The tree's value for each requested row, computed on enumeration.</returns>
/// <exception cref="NotSupportedException">
/// Thrown when <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled.
/// </exception>
public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
  bool hasRows = rows.Any();
  if (!hasRows) {
    return Enumerable.Empty<double>();
  }

  if (CheckExpressionsWithIntervalArithmetic) {
    throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
  }

  // count this call as one evaluated solution
  lock (syncRoot) {
    EvaluatedSolutions = EvaluatedSolutions + 1;
  }

  var instructions = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
  PrepareInstructions(instructions, dataset);

  // deferred: each row is evaluated only when the result is enumerated
  return rows.Select(rowIndex => Evaluate(dataset, rowIndex, instructions));
}
[9732]140
/// <summary>
/// Evaluates the prepared instruction array for a single dataset row and returns
/// the value of the instruction at index 0.
/// </summary>
/// <remarks>
/// The array is processed from the last instruction down to the first. Each instruction
/// reads its argument values from <c>code[childIndex] .. code[childIndex + nArguments - 1]</c>,
/// which presumes argument instructions are stored at higher indices than the instruction
/// that consumes them. Instructions flagged <c>skip</c> keep whatever value they already hold.
/// Variable-based instructions yield <c>double.NaN</c> when the (possibly lagged) row lies
/// outside the dataset.
/// </remarks>
/// <param name="dataset">Dataset used for row-bounds checks and passed to the helper interpreter.</param>
/// <param name="row">Index of the row to evaluate.</param>
/// <param name="code">Instructions previously set up by <c>PrepareInstructions</c>.</param>
/// <returns>The value computed for <c>code[0]</c>.</returns>
/// <exception cref="NotSupportedException">Thrown for an opcode that has no branch below.</exception>
private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
  for (int i = code.Length - 1; i >= 0; --i) {
    if (code[i].skip) continue;
    #region opcode if
    var instr = code[i];
    if (instr.opCode == OpCodes.Variable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
      }
    } else if (instr.opCode == OpCodes.BinaryFactorVariable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        // direct cast (instead of 'as') so a wrong node type fails with a clear InvalidCastException
        var factorTreeNode = (BinaryFactorVariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<string>)instr.data)[row] == factorTreeNode.VariableValue ? factorTreeNode.Weight : 0;
      }
    } else if (instr.opCode == OpCodes.FactorVariable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var factorTreeNode = (FactorVariableTreeNode)instr.dynamicNode;
        instr.value = factorTreeNode.GetValue(((IList<string>)instr.data)[row]);
      }
    } else if (instr.opCode == OpCodes.LagVariable) {
      var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
      int actualRow = row + laggedVariableTreeNode.Lag;
      if (actualRow < 0 || actualRow >= dataset.Rows)
        instr.value = double.NaN;
      else
        instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
    } else if (instr.opCode == OpCodes.VariableCondition) {
      if (row < 0 || row >= dataset.Rows) {
        // out-of-range rows must stop here; previously execution fell through and
        // indexed the data column with the invalid row
        instr.value = double.NaN;
      } else {
        var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
        double variableValue = ((IList<double>)instr.data)[row];
        if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
          // soft decision: logistic blend of the two branches around the threshold
          double x = variableValue - variableConditionTreeNode.Threshold;
          double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

          double trueBranch = code[instr.childIndex].value;
          double falseBranch = code[instr.childIndex + 1].value;

          instr.value = trueBranch * p + falseBranch * (1 - p);
        } else {
          // hard decision: pick one branch depending on the threshold
          if (variableValue <= variableConditionTreeNode.Threshold) {
            instr.value = code[instr.childIndex].value;
          } else {
            instr.value = code[instr.childIndex + 1].value;
          }
        }
      }
    } else if (instr.opCode == OpCodes.Add) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s;
    } else if (instr.opCode == OpCodes.Sub) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s -= code[instr.childIndex + j].value;
      }
      // a single argument means unary negation
      if (instr.nArguments == 1) s = -s;
      instr.value = s;
    } else if (instr.opCode == OpCodes.Mul) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p *= code[instr.childIndex + j].value;
      }
      instr.value = p;
    } else if (instr.opCode == OpCodes.Div) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p /= code[instr.childIndex + j].value;
      }
      // a single argument means the reciprocal
      if (instr.nArguments == 1) p = 1.0 / p;
      instr.value = p;
    } else if (instr.opCode == OpCodes.Average) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s / instr.nArguments;
    } else if (instr.opCode == OpCodes.Absolute) {
      instr.value = Math.Abs(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Cos) {
      instr.value = Math.Cos(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Sin) {
      instr.value = Math.Sin(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Tan) {
      instr.value = Math.Tan(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Square) {
      instr.value = Math.Pow(code[instr.childIndex].value, 2);
    } else if (instr.opCode == OpCodes.Power) {
      // the exponent is rounded to the nearest integer
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, y);
    } else if (instr.opCode == OpCodes.SquareRoot) {
      instr.value = Math.Sqrt(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Root) {
      // the root degree is rounded to the nearest integer
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, 1 / y);
    } else if (instr.opCode == OpCodes.Exp) {
      instr.value = Math.Exp(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Log) {
      instr.value = Math.Log(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Gamma) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
    } else if (instr.opCode == OpCodes.Psi) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      // non-positive integers are mapped to NaN instead of being passed to alglib.psi
      else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
      else instr.value = alglib.psi(x);
    } else if (instr.opCode == OpCodes.Dawson) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
    } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
    } else if (instr.opCode == OpCodes.SineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = si;
      }
    } else if (instr.opCode == OpCodes.CosineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = ci;
      }
    } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = shi;
      }
    } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = chi;
      }
    } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = c;
      }
    } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = s;
      }
    } else if (instr.opCode == OpCodes.AiryA) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = ai;
      }
    } else if (instr.opCode == OpCodes.AiryB) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = bi;
      }
    } else if (instr.opCode == OpCodes.Norm) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.normaldistribution(x);
    } else if (instr.opCode == OpCodes.Erf) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.errorfunction(x);
    } else if (instr.opCode == OpCodes.Bessel) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.besseli0(x);
    } else if (instr.opCode == OpCodes.IfThenElse) {
      // values greater than zero count as "true"
      double condition = code[instr.childIndex].value;
      double result;
      if (condition > 0.0) {
        result = code[instr.childIndex + 1].value;
      } else {
        result = code[instr.childIndex + 2].value;
      }
      instr.value = result;
    } else if (instr.opCode == OpCodes.AND) {
      // boolean results are encoded as 1.0 (true) and -1.0 (false)
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result > 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.OR) {
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result <= 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.NOT) {
      instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
    } else if (instr.opCode == OpCodes.XOR) {
      // true when an odd number of arguments is positive
      int positiveSignals = 0;
      for (int j = 0; j < instr.nArguments; j++) {
        if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
      }
      instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.GT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x > y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.LT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x < y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
      // these subtrees are delegated to the helper interpreter, using the
      // InterpreterState stored in instr.data
      var state = (InterpreterState)instr.data;
      state.Reset();
      instr.value = interpreter.Evaluate(dataset, ref row, state);
    } else {
      var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
      throw new NotSupportedException(errorText);
    }
    #endregion
  }
  return code[0].value;
}
[9815]388
/// <summary>
/// Collects the instruction at <paramref name="startIndex"/> and all instructions reachable
/// through its arguments, in prefix order (an instruction first, then its arguments from
/// first to last).
/// </summary>
/// <param name="code">The full instruction array.</param>
/// <param name="startIndex">Index of the instruction that roots the sequence.</param>
/// <returns>A new array referencing the collected instructions.</returns>
private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
  var pending = new Stack<int>();
  var prefix = new List<LinearInstruction>();
  pending.Push(startIndex);

  while (pending.Count > 0) {
    var current = code[pending.Pop()];
    prefix.Add(current);

    // push the last argument first so that the first argument is popped next
    int firstChild = current.childIndex;
    for (int offset = current.nArguments; offset > 0; offset--) {
      pending.Push(firstChild + offset - 1);
    }
  }

  return prefix.ToArray();
}
402
/// <summary>
/// Prepares compiled instructions for evaluation against <paramref name="dataset"/>:
/// constants get their value assigned and are flagged to be skipped, variable-based
/// instructions get the read-only column of their variable stored in <c>data</c>, and
/// TimeLag / Integral / Derivative instructions get an <c>InterpreterState</c> over their
/// prefix sequence while the remaining instructions of that sequence are flagged to be skipped.
/// </summary>
/// <param name="code">Instructions to prepare; modified in place.</param>
/// <param name="dataset">Dataset the variable columns are fetched from.</param>
public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
  for (int index = 0; index < code.Length; index++) {
    var current = code[index];
    var op = current.opCode;
    #region opcode switch
    if (op == OpCodes.Constant) {
      var constantNode = (ConstantTreeNode)current.dynamicNode;
      current.value = constantNode.Value;
      // the value is already set so this instruction should be skipped in the evaluation phase
      current.skip = true;
    } else if (op == OpCodes.Variable) {
      var variableNode = (VariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(variableNode.VariableName);
    } else if (op == OpCodes.BinaryFactorVariable) {
      var binaryFactorNode = current.dynamicNode as BinaryFactorVariableTreeNode;
      current.data = dataset.GetReadOnlyStringValues(binaryFactorNode.VariableName);
    } else if (op == OpCodes.FactorVariable) {
      var factorNode = current.dynamicNode as FactorVariableTreeNode;
      current.data = dataset.GetReadOnlyStringValues(factorNode.VariableName);
    } else if (op == OpCodes.LagVariable) {
      var laggedNode = (LaggedVariableTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(laggedNode.VariableName);
    } else if (op == OpCodes.VariableCondition) {
      var conditionNode = (VariableConditionTreeNode)current.dynamicNode;
      current.data = dataset.GetReadOnlyDoubleValues(conditionNode.VariableName);
    } else if (op == OpCodes.TimeLag || op == OpCodes.Integral || op == OpCodes.Derivative) {
      var subSequence = GetPrefixSequence(code, index);
      current.data = new InterpreterState(subSequence, 0);
      // element 0 is the instruction itself; everything below it is skipped in the evaluation phase
      for (int k = 1; k < subSequence.Length; k++) {
        subSequence[k].skip = true;
      }
    }
    #endregion
  }
}
[5571]453  }
454}
Note: See TracBrowser for help on using the repository browser.