Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 14304

Last change on this file since 14304 was 14304, checked in by gkronber, 6 years ago

#2668: merged r14282 from trunk to stable

File size: 17.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    // Keys under which the two parameters are registered in (and looked up from) the Parameters collection.
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Helper interpreter that Evaluate delegates to for TimeLag, Derivative and Integral instructions.
    private readonly SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
40
41    public override bool CanChangeName {
42      get { return false; }
43    }
44
45    public override bool CanChangeDescription {
46      get { return false; }
47    }
48
49    #region parameter properties
50    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
51      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
52    }
53
54    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
55      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
56    }
57    #endregion
58
59    #region properties
60    public BoolValue CheckExpressionsWithIntervalArithmetic {
61      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
62      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
63    }
64    public IntValue EvaluatedSolutions {
65      get { return EvaluatedSolutionsParameter.Value; }
66      set { EvaluatedSolutionsParameter.Value = value; }
67    }
68    #endregion
69
70    [StorableConstructor]
71    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
72      : base(deserializing) {
73      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
74    }
75
76    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
77      : base(original, cloner) {
78      interpreter = cloner.Clone(original.interpreter);
79    }
80
81    public override IDeepCloneable Clone(Cloner cloner) {
82      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
83    }
84
85    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
86      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
87      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
88      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
89      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
90    }
91
92    [StorableHook(HookType.AfterDeserialization)]
93    private void AfterDeserialization() {
94      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
95    }
96
97    #region IStatefulItem
98    public void InitializeState() {
99      EvaluatedSolutions.Value = 0;
100    }
101
102    public void ClearState() { }
103    #endregion
104
105    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
106      if (CheckExpressionsWithIntervalArithmetic.Value)
107        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
108
109      lock (EvaluatedSolutions) {
110        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
111      }
112
113      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
114      PrepareInstructions(code, dataset);
115      return rows.Select(row => Evaluate(dataset, row, code));
116    }
117
118    private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
119      for (int i = code.Length - 1; i >= 0; --i) {
120        if (code[i].skip) continue;
121        #region opcode if
122        var instr = code[i];
123        if (instr.opCode == OpCodes.Variable) {
124          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
125          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
126          instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
127        } else if (instr.opCode == OpCodes.LagVariable) {
128          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
129          int actualRow = row + laggedVariableTreeNode.Lag;
130          if (actualRow < 0 || actualRow >= dataset.Rows)
131            instr.value = double.NaN;
132          else
133            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
134        } else if (instr.opCode == OpCodes.VariableCondition) {
135          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
136          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
137          double variableValue = ((IList<double>)instr.data)[row];
138          double x = variableValue - variableConditionTreeNode.Threshold;
139          double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
140
141          double trueBranch = code[instr.childIndex].value;
142          double falseBranch = code[instr.childIndex + 1].value;
143
144          instr.value = trueBranch * p + falseBranch * (1 - p);
145        } else if (instr.opCode == OpCodes.Add) {
146          double s = code[instr.childIndex].value;
147          for (int j = 1; j != instr.nArguments; ++j) {
148            s += code[instr.childIndex + j].value;
149          }
150          instr.value = s;
151        } else if (instr.opCode == OpCodes.Sub) {
152          double s = code[instr.childIndex].value;
153          for (int j = 1; j != instr.nArguments; ++j) {
154            s -= code[instr.childIndex + j].value;
155          }
156          if (instr.nArguments == 1) s = -s;
157          instr.value = s;
158        } else if (instr.opCode == OpCodes.Mul) {
159          double p = code[instr.childIndex].value;
160          for (int j = 1; j != instr.nArguments; ++j) {
161            p *= code[instr.childIndex + j].value;
162          }
163          instr.value = p;
164        } else if (instr.opCode == OpCodes.Div) {
165          double p = code[instr.childIndex].value;
166          for (int j = 1; j != instr.nArguments; ++j) {
167            p /= code[instr.childIndex + j].value;
168          }
169          if (instr.nArguments == 1) p = 1.0 / p;
170          instr.value = p;
171        } else if (instr.opCode == OpCodes.Average) {
172          double s = code[instr.childIndex].value;
173          for (int j = 1; j != instr.nArguments; ++j) {
174            s += code[instr.childIndex + j].value;
175          }
176          instr.value = s / instr.nArguments;
177        } else if (instr.opCode == OpCodes.Cos) {
178          instr.value = Math.Cos(code[instr.childIndex].value);
179        } else if (instr.opCode == OpCodes.Sin) {
180          instr.value = Math.Sin(code[instr.childIndex].value);
181        } else if (instr.opCode == OpCodes.Tan) {
182          instr.value = Math.Tan(code[instr.childIndex].value);
183        } else if (instr.opCode == OpCodes.Square) {
184          instr.value = Math.Pow(code[instr.childIndex].value, 2);
185        } else if (instr.opCode == OpCodes.Power) {
186          double x = code[instr.childIndex].value;
187          double y = Math.Round(code[instr.childIndex + 1].value);
188          instr.value = Math.Pow(x, y);
189        } else if (instr.opCode == OpCodes.SquareRoot) {
190          instr.value = Math.Sqrt(code[instr.childIndex].value);
191        } else if (instr.opCode == OpCodes.Root) {
192          double x = code[instr.childIndex].value;
193          double y = code[instr.childIndex + 1].value;
194          instr.value = Math.Pow(x, 1 / y);
195        } else if (instr.opCode == OpCodes.Exp) {
196          instr.value = Math.Exp(code[instr.childIndex].value);
197        } else if (instr.opCode == OpCodes.Log) {
198          instr.value = Math.Log(code[instr.childIndex].value);
199        } else if (instr.opCode == OpCodes.Gamma) {
200          var x = code[instr.childIndex].value;
201          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
202        } else if (instr.opCode == OpCodes.Psi) {
203          var x = code[instr.childIndex].value;
204          if (double.IsNaN(x)) instr.value = double.NaN;
205          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
206          else instr.value = alglib.psi(x);
207        } else if (instr.opCode == OpCodes.Dawson) {
208          var x = code[instr.childIndex].value;
209          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
210        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
211          var x = code[instr.childIndex].value;
212          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
213        } else if (instr.opCode == OpCodes.SineIntegral) {
214          double si, ci;
215          var x = code[instr.childIndex].value;
216          if (double.IsNaN(x)) instr.value = double.NaN;
217          else {
218            alglib.sinecosineintegrals(x, out si, out ci);
219            instr.value = si;
220          }
221        } else if (instr.opCode == OpCodes.CosineIntegral) {
222          double si, ci;
223          var x = code[instr.childIndex].value;
224          if (double.IsNaN(x)) instr.value = double.NaN;
225          else {
226            alglib.sinecosineintegrals(x, out si, out ci);
227            instr.value = ci;
228          }
229        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
230          double shi, chi;
231          var x = code[instr.childIndex].value;
232          if (double.IsNaN(x)) instr.value = double.NaN;
233          else {
234            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
235            instr.value = shi;
236          }
237        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
238          double shi, chi;
239          var x = code[instr.childIndex].value;
240          if (double.IsNaN(x)) instr.value = double.NaN;
241          else {
242            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
243            instr.value = chi;
244          }
245        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
246          double c = 0, s = 0;
247          var x = code[instr.childIndex].value;
248          if (double.IsNaN(x)) instr.value = double.NaN;
249          else {
250            alglib.fresnelintegral(x, ref c, ref s);
251            instr.value = c;
252          }
253        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
254          double c = 0, s = 0;
255          var x = code[instr.childIndex].value;
256          if (double.IsNaN(x)) instr.value = double.NaN;
257          else {
258            alglib.fresnelintegral(x, ref c, ref s);
259            instr.value = s;
260          }
261        } else if (instr.opCode == OpCodes.AiryA) {
262          double ai, aip, bi, bip;
263          var x = code[instr.childIndex].value;
264          if (double.IsNaN(x)) instr.value = double.NaN;
265          else {
266            alglib.airy(x, out ai, out aip, out bi, out bip);
267            instr.value = ai;
268          }
269        } else if (instr.opCode == OpCodes.AiryB) {
270          double ai, aip, bi, bip;
271          var x = code[instr.childIndex].value;
272          if (double.IsNaN(x)) instr.value = double.NaN;
273          else {
274            alglib.airy(x, out ai, out aip, out bi, out bip);
275            instr.value = bi;
276          }
277        } else if (instr.opCode == OpCodes.Norm) {
278          var x = code[instr.childIndex].value;
279          if (double.IsNaN(x)) instr.value = double.NaN;
280          else instr.value = alglib.normaldistribution(x);
281        } else if (instr.opCode == OpCodes.Erf) {
282          var x = code[instr.childIndex].value;
283          if (double.IsNaN(x)) instr.value = double.NaN;
284          else instr.value = alglib.errorfunction(x);
285        } else if (instr.opCode == OpCodes.Bessel) {
286          var x = code[instr.childIndex].value;
287          if (double.IsNaN(x)) instr.value = double.NaN;
288          else instr.value = alglib.besseli0(x);
289        } else if (instr.opCode == OpCodes.IfThenElse) {
290          double condition = code[instr.childIndex].value;
291          double result;
292          if (condition > 0.0) {
293            result = code[instr.childIndex + 1].value;
294          } else {
295            result = code[instr.childIndex + 2].value;
296          }
297          instr.value = result;
298        } else if (instr.opCode == OpCodes.AND) {
299          double result = code[instr.childIndex].value;
300          for (int j = 1; j < instr.nArguments; j++) {
301            if (result > 0.0) result = code[instr.childIndex + j].value;
302            else break;
303          }
304          instr.value = result > 0.0 ? 1.0 : -1.0;
305        } else if (instr.opCode == OpCodes.OR) {
306          double result = code[instr.childIndex].value;
307          for (int j = 1; j < instr.nArguments; j++) {
308            if (result <= 0.0) result = code[instr.childIndex + j].value;
309            else break;
310          }
311          instr.value = result > 0.0 ? 1.0 : -1.0;
312        } else if (instr.opCode == OpCodes.NOT) {
313          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
314        } else if (instr.opCode == OpCodes.XOR) {
315          int positiveSignals = 0;
316          for (int j = 0; j < instr.nArguments; j++) {
317            if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
318          }
319          instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
320        } else if (instr.opCode == OpCodes.GT) {
321          double x = code[instr.childIndex].value;
322          double y = code[instr.childIndex + 1].value;
323          instr.value = x > y ? 1.0 : -1.0;
324        } else if (instr.opCode == OpCodes.LT) {
325          double x = code[instr.childIndex].value;
326          double y = code[instr.childIndex + 1].value;
327          instr.value = x < y ? 1.0 : -1.0;
328        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
329          var state = (InterpreterState)instr.data;
330          state.Reset();
331          instr.value = interpreter.Evaluate(dataset, ref row, state);
332        } else {
333          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
334          throw new NotSupportedException(errorText);
335        }
336        #endregion
337      }
338      return code[0].value;
339    }
340
341    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
342      var s = new Stack<int>();
343      var list = new List<LinearInstruction>();
344      s.Push(startIndex);
345      while (s.Any()) {
346        int i = s.Pop();
347        var instr = code[i];
348        // push instructions in reverse execution order
349        for (int j = instr.nArguments - 1; j >= 0; j--) s.Push(instr.childIndex + j);
350        list.Add(instr);
351      }
352      return list.ToArray();
353    }
354
355    public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
356      for (int i = 0; i != code.Length; ++i) {
357        var instr = code[i];
358        #region opcode switch
359        switch (instr.opCode) {
360          case OpCodes.Constant: {
361              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
362              instr.value = constTreeNode.Value;
363              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
364            }
365            break;
366          case OpCodes.Variable: {
367              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
368              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
369            }
370            break;
371          case OpCodes.LagVariable: {
372              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
373              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
374            }
375            break;
376          case OpCodes.VariableCondition: {
377              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
378              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
379            }
380            break;
381          case OpCodes.TimeLag:
382          case OpCodes.Integral:
383          case OpCodes.Derivative: {
384              var seq = GetPrefixSequence(code, i);
385              var interpreterState = new InterpreterState(seq, 0);
386              instr.data = interpreterState;
387              for (int j = 1; j != seq.Length; ++j)
388                seq[j].skip = true;
389            }
390            break;
391        }
392        #endregion
393      }
394    }
395  }
396}
Note: See TracBrowser for help on using the repository browser.