Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.EvolutionTracking/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 12394

Last change on this file since 12394 was 12155, checked in by bburlacu, 10 years ago

#1772: Merged trunk changes.

File size: 18.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Linear (non-recursive) interpreter for symbolic expression trees.
  /// </summary>
  /// <remarks>
  /// A tree is compiled by <see cref="SymbolicExpressionTreeLinearCompiler"/> into a flat array of
  /// <see cref="LinearInstruction"/>s. The array is evaluated from the last instruction to the first, and each
  /// instruction reads the values of its children from <c>code[childIndex]</c> .. <c>code[childIndex + nArguments - 1]</c>.
  /// This relies on the compiler placing children after their parent.
  /// ADFs are not supported. TimeLag, Derivative and Integral nodes are evaluated by an internal
  /// <see cref="SymbolicDataAnalysisExpressionTreeInterpreter"/>.
  /// </remarks>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Recursive interpreter that evaluates the sub-programs of TimeLag, Derivative and Integral nodes.
    // It is not [Storable]; AfterDeserialization recreates it when missing.
    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>
    /// Parameter holding the switch that requests interval-arithmetic checks.
    /// This interpreter does not support them (see <see cref="GetSymbolicExpressionTreeValues"/>).
    /// </summary>
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>
    /// Parameter holding the counter of trees evaluated through <see cref="GetSymbolicExpressionTreeValues"/>.
    /// </summary>
    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Shortcut to the value of <see cref="CheckExpressionsWithIntervalArithmeticParameter"/>.</summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }

    /// <summary>Shortcut to the value of <see cref="EvaluatedSolutionsParameter"/>.</summary>
    public IntValue EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value; }
      set { EvaluatedSolutionsParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
      : base(deserializing) {
    }

    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
      : base(original, cloner) {
      interpreter = cloner.Clone(original.interpreter);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
    }

    /// <summary>
    /// Creates an interpreter with interval-arithmetic checking disabled and the evaluation counter set to zero.
    /// </summary>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // the helper interpreter is not persisted, so rebuild it after loading
      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions.Value = 0;
    }

    /// <summary>Does nothing; this interpreter keeps no state that needs clearing.</summary>
    public void ClearState() { }
    #endregion

    /// <summary>
    /// Evaluates <paramref name="tree"/> on the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// The tree is compiled and prepared once per call, and <see cref="EvaluatedSolutions"/> is incremented immediately.
    /// The returned sequence is lazy: rows are evaluated when it is enumerated. All rows share one instruction buffer,
    /// so the returned sequence must not be enumerated concurrently.
    /// Rows outside the dataset yield <see cref="double.NaN"/> for variable-based nodes.
    /// </remarks>
    /// <exception cref="NotSupportedException">
    /// Thrown immediately if <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled. Thrown during enumeration
    /// if the tree contains a symbol this interpreter does not support.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (EvaluatedSolutions) {
        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
      }

      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates the subtree rooted at <paramref name="node"/> on the given rows.
    /// </summary>
    /// <remarks>
    /// NOTE: do not use this method when evaluating trees. This method is provided as a shortcut for evaluating subtrees ad-hoc.
    /// Unlike <see cref="GetSymbolicExpressionTreeValues"/>, it neither checks the interval-arithmetic switch
    /// nor increments <see cref="EvaluatedSolutions"/>. The returned sequence is lazy.
    /// </remarks>
    public IEnumerable<double> GetValues(ISymbolicExpressionTreeNode node, Dataset dataset, IEnumerable<int> rows) {
      var code = SymbolicExpressionTreeLinearCompiler.Compile(node, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates a prepared program for a single row and returns the value of its root instruction (<c>code[0]</c>).
    /// </summary>
    /// <remarks>
    /// Instructions are processed from the end of the array towards the start, so children are computed before
    /// their parent reads them. Instructions flagged <c>skip</c> keep the value assigned during preparation
    /// (constants), or belong to a TimeLag/Derivative/Integral subtree handled by the helper interpreter.
    /// </remarks>
    /// <exception cref="NotSupportedException">Thrown for an opcode this interpreter does not handle.</exception>
    private double Evaluate(Dataset dataset, int row, LinearInstruction[] code) {
      for (int i = code.Length - 1; i >= 0; --i) {
        if (code[i].skip) continue;
        #region opcode if
        var instr = code[i];
        if (instr.opCode == OpCodes.Variable) {
          // Out-of-range rows yield NaN, consistent with the LagVariable branch.
          // The column must not be indexed in that case.
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
          }
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          int actualRow = row + laggedVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows)
            instr.value = double.NaN;
          else
            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          // Same out-of-range handling as for Variable: yield NaN instead of indexing past the column.
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
            double variableValue = ((IList<double>)instr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            // logistic weight of the true branch; Slope controls how sharp the switch at Threshold is
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = code[instr.childIndex].value;
            double falseBranch = code[instr.childIndex + 1].value;

            instr.value = trueBranch * p + falseBranch * (1 - p);
          }
        } else if (instr.opCode == OpCodes.Add) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s;
        } else if (instr.opCode == OpCodes.Sub) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s -= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) s = -s; // unary minus
          instr.value = s;
        } else if (instr.opCode == OpCodes.Mul) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p *= code[instr.childIndex + j].value;
          }
          instr.value = p;
        } else if (instr.opCode == OpCodes.Div) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p /= code[instr.childIndex + j].value;
          }
          if (instr.nArguments == 1) p = 1.0 / p; // unary division is the reciprocal
          instr.value = p;
        } else if (instr.opCode == OpCodes.Average) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s / instr.nArguments;
        } else if (instr.opCode == OpCodes.Cos) {
          instr.value = Math.Cos(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Sin) {
          instr.value = Math.Sin(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tan) {
          instr.value = Math.Tan(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Square) {
          instr.value = Math.Pow(code[instr.childIndex].value, 2);
        } else if (instr.opCode == OpCodes.Power) {
          double x = code[instr.childIndex].value;
          double y = Math.Round(code[instr.childIndex + 1].value); // exponent is rounded to an integer
          instr.value = Math.Pow(x, y);
        } else if (instr.opCode == OpCodes.SquareRoot) {
          instr.value = Math.Sqrt(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Root) {
          double x = code[instr.childIndex].value;
          // NOTE(review): unlike Power, the root degree is not rounded here.
          // Confirm this matches the recursive interpreter.
          double y = code[instr.childIndex + 1].value;
          instr.value = Math.Pow(x, 1 / y);
        } else if (instr.opCode == OpCodes.Exp) {
          instr.value = Math.Exp(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Log) {
          instr.value = Math.Log(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Gamma) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
        } else if (instr.opCode == OpCodes.Psi) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN; // non-positive integers are excluded
          else instr.value = alglib.psi(x);
        } else if (instr.opCode == OpCodes.Dawson) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
        } else if (instr.opCode == OpCodes.SineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = si;
          }
        } else if (instr.opCode == OpCodes.CosineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = ci;
          }
        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = shi;
          }
        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = chi;
          }
        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = c;
          }
        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = s;
          }
        } else if (instr.opCode == OpCodes.AiryA) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = ai;
          }
        } else if (instr.opCode == OpCodes.AiryB) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = bi;
          }
        } else if (instr.opCode == OpCodes.Norm) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.normaldistribution(x);
        } else if (instr.opCode == OpCodes.Erf) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.errorfunction(x);
        } else if (instr.opCode == OpCodes.Bessel) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.besseli0(x);
        } else if (instr.opCode == OpCodes.IfThenElse) {
          // Both branches have already been evaluated; only the selection depends on the condition.
          double condition = code[instr.childIndex].value;
          double result;
          if (condition > 0.0) {
            result = code[instr.childIndex + 1].value;
          } else {
            result = code[instr.childIndex + 2].value;
          }
          instr.value = result;
        } else if (instr.opCode == OpCodes.AND) {
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result > 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.OR) {
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result <= 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.NOT) {
          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
        } else if (instr.opCode == OpCodes.XOR) {
          // true iff an odd number of arguments is positive
          int positiveSignals = 0;
          for (int j = 0; j < instr.nArguments; j++) {
            if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
          }
          instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.GT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x > y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.LT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x < y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
          // Delegate the whole subtree (prepared in PrepareInstructions) to the recursive interpreter.
          // `row` is passed by ref but is a local copy, so shifts made by the helper do not leak into this loop.
          var state = (InterpreterState)instr.data;
          state.Reset();
          instr.value = interpreter.Evaluate(dataset, ref row, state);
        } else if (instr.opCode == OpCodes.Passthrough) {
          instr.value = code[instr.childIndex].value;
        } else {
          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
          throw new NotSupportedException(errorText);
        }
        #endregion
      }
      return code[0].value;
    }

    /// <summary>
    /// Collects the instructions of the subtree rooted at <paramref name="startIndex"/> in prefix (pre-order) order:
    /// the root comes first, then each child's subtree from left to right.
    /// </summary>
    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
      var s = new Stack<int>();
      var list = new List<LinearInstruction>();
      s.Push(startIndex);
      while (s.Count > 0) {
        int i = s.Pop();
        var instr = code[i];
        // push children in reverse order so that the leftmost child is popped (visited) first
        for (int j = instr.nArguments - 1; j >= 0; j--) s.Push(instr.childIndex + j);
        list.Add(instr);
      }
      return list.ToArray();
    }

    /// <summary>
    /// Prepares a compiled program for evaluation against <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// <list type="bullet">
    /// <item>Constants: the value is stored in the instruction, which is then flagged <c>skip</c>.</item>
    /// <item>Variable, LagVariable, VariableCondition: the referenced dataset column is stored in <c>data</c>.</item>
    /// <item>TimeLag, Integral, Derivative: an <see cref="InterpreterState"/> over the subtree's prefix sequence is
    /// stored in <c>data</c>. All descendants are flagged <c>skip</c>, so the linear loop does not evaluate them;
    /// the helper interpreter evaluates them instead.</item>
    /// </list>
    /// </remarks>
    public static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
      for (int i = 0; i != code.Length; ++i) {
        var instr = code[i];
        #region opcode switch
        switch (instr.opCode) {
          case OpCodes.Constant: {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
            }
            break;
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
          case OpCodes.TimeLag:
          case OpCodes.Integral:
          case OpCodes.Derivative: {
              var seq = GetPrefixSequence(code, i);
              var interpreterState = new InterpreterState(seq, 0);
              instr.data = interpreterState;
              // seq[0] is this instruction itself; only its descendants are skipped
              for (int j = 1; j != seq.Length; ++j)
                seq[j].skip = true;
            }
            break;
        }
        #endregion
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.