Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 13818

Last change on this file since 13818 was 13268, checked in by bburlacu, 9 years ago

#2442: Fix potential out of bounds exception when getting variable values in the SymbolicDataAnalysisExpressionTreeLinearInterpreter

File size: 19.9 KB
RevLine 
[5571]1#region License Information
2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5571]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
[9739]24using System.Linq;
[5571]25using HeuristicLab.Common;
26using HeuristicLab.Core;
[6740]27using HeuristicLab.Data;
[5571]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[6740]29using HeuristicLab.Parameters;
[5571]30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
[9815]34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
[9758]35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
[5749]36    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
[13248]37    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
[7615]38    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
[5571]39
[9776]40    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;
41
[9732]42    public override bool CanChangeName {
43      get { return false; }
44    }
[5571]45
[9732]46    public override bool CanChangeDescription {
47      get { return false; }
48    }
49
[5749]50    #region parameter properties
[13248]51    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
52      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
[5749]53    }
[7615]54
[13248]55    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
56      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
[7615]57    }
[5749]58    #endregion
59
60    #region properties
[13248]61    public bool CheckExpressionsWithIntervalArithmetic {
62      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
63      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
[5749]64    }
[13248]65    public int EvaluatedSolutions {
66      get { return EvaluatedSolutionsParameter.Value.Value; }
67      set { EvaluatedSolutionsParameter.Value.Value = value; }
[7615]68    }
[5749]69    #endregion
70
[5571]71    [StorableConstructor]
[9758]72    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
[9732]73      : base(deserializing) {
74    }
75
[9828]76    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
[9732]77      : base(original, cloner) {
[9828]78      interpreter = cloner.Clone(original.interpreter);
[9732]79    }
80
[5571]81    public override IDeepCloneable Clone(Cloner cloner) {
[9734]82      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
[5571]83    }
84
[9734]85    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
[9758]86      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
[13248]87      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
88      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
[9776]89      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
[5571]90    }
91
[13248]92    public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
93      : base(name, description) {
94      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
95      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
96      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
97    }
98
[7615]99    [StorableHook(HookType.AfterDeserialization)]
100    private void AfterDeserialization() {
[13248]101      var evaluatedSolutions = new IntValue(0);
102      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
103      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
104        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
105        evaluatedSolutions = evaluatedSolutionsParameter.Value;
106        Parameters.Remove(EvaluatedSolutionsParameterName);
107      }
108      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", evaluatedSolutions));
109      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
110        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
111        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
112        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
113      }
114      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
[7615]115    }
116
117    #region IStatefulItem
118    public void InitializeState() {
[13248]119      EvaluatedSolutions = 0;
[7615]120    }
121
[9828]122    public void ClearState() { }
[7615]123    #endregion
124
[13251]125    private readonly object syncRoot = new object();
[12509]126    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
[13248]127      if (CheckExpressionsWithIntervalArithmetic)
[9734]128        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
[7120]129
[13251]130      lock (syncRoot) {
[13248]131        EvaluatedSolutions++; // increment the evaluated solutions counter
[9004]132      }
[8436]133
[9739]134      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
[9758]135      PrepareInstructions(code, dataset);
[9818]136      return rows.Select(row => Evaluate(dataset, row, code));
[9739]137    }
[9732]138
[12509]139    private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
[9732]140      for (int i = code.Length - 1; i >= 0; --i) {
[9776]141        if (code[i].skip) continue;
[9871]142        #region opcode if
[9732]143        var instr = code[i];
[9871]144        if (instr.opCode == OpCodes.Variable) {
145          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
[13268]146          else {
147            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
148            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
149          }
[9871]150        } else if (instr.opCode == OpCodes.LagVariable) {
151          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
152          int actualRow = row + laggedVariableTreeNode.Lag;
153          if (actualRow < 0 || actualRow >= dataset.Rows)
154            instr.value = double.NaN;
155          else
156            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
157        } else if (instr.opCode == OpCodes.VariableCondition) {
158          if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
159          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
160          double variableValue = ((IList<double>)instr.data)[row];
161          double x = variableValue - variableConditionTreeNode.Threshold;
162          double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
[9738]163
[9871]164          double trueBranch = code[instr.childIndex].value;
165          double falseBranch = code[instr.childIndex + 1].value;
[9738]166
[9871]167          instr.value = trueBranch * p + falseBranch * (1 - p);
168        } else if (instr.opCode == OpCodes.Add) {
169          double s = code[instr.childIndex].value;
170          for (int j = 1; j != instr.nArguments; ++j) {
171            s += code[instr.childIndex + j].value;
172          }
173          instr.value = s;
174        } else if (instr.opCode == OpCodes.Sub) {
175          double s = code[instr.childIndex].value;
176          for (int j = 1; j != instr.nArguments; ++j) {
177            s -= code[instr.childIndex + j].value;
178          }
179          if (instr.nArguments == 1) s = -s;
180          instr.value = s;
181        } else if (instr.opCode == OpCodes.Mul) {
182          double p = code[instr.childIndex].value;
183          for (int j = 1; j != instr.nArguments; ++j) {
184            p *= code[instr.childIndex + j].value;
185          }
186          instr.value = p;
187        } else if (instr.opCode == OpCodes.Div) {
188          double p = code[instr.childIndex].value;
189          for (int j = 1; j != instr.nArguments; ++j) {
190            p /= code[instr.childIndex + j].value;
191          }
192          if (instr.nArguments == 1) p = 1.0 / p;
193          instr.value = p;
194        } else if (instr.opCode == OpCodes.Average) {
195          double s = code[instr.childIndex].value;
196          for (int j = 1; j != instr.nArguments; ++j) {
197            s += code[instr.childIndex + j].value;
198          }
199          instr.value = s / instr.nArguments;
200        } else if (instr.opCode == OpCodes.Cos) {
201          instr.value = Math.Cos(code[instr.childIndex].value);
202        } else if (instr.opCode == OpCodes.Sin) {
203          instr.value = Math.Sin(code[instr.childIndex].value);
204        } else if (instr.opCode == OpCodes.Tan) {
205          instr.value = Math.Tan(code[instr.childIndex].value);
206        } else if (instr.opCode == OpCodes.Square) {
207          instr.value = Math.Pow(code[instr.childIndex].value, 2);
208        } else if (instr.opCode == OpCodes.Power) {
209          double x = code[instr.childIndex].value;
210          double y = Math.Round(code[instr.childIndex + 1].value);
211          instr.value = Math.Pow(x, y);
212        } else if (instr.opCode == OpCodes.SquareRoot) {
213          instr.value = Math.Sqrt(code[instr.childIndex].value);
214        } else if (instr.opCode == OpCodes.Root) {
215          double x = code[instr.childIndex].value;
[13254]216          double y = Math.Round(code[instr.childIndex + 1].value);
[9871]217          instr.value = Math.Pow(x, 1 / y);
218        } else if (instr.opCode == OpCodes.Exp) {
219          instr.value = Math.Exp(code[instr.childIndex].value);
220        } else if (instr.opCode == OpCodes.Log) {
221          instr.value = Math.Log(code[instr.childIndex].value);
222        } else if (instr.opCode == OpCodes.Gamma) {
223          var x = code[instr.childIndex].value;
224          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
225        } else if (instr.opCode == OpCodes.Psi) {
226          var x = code[instr.childIndex].value;
227          if (double.IsNaN(x)) instr.value = double.NaN;
228          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
229          else instr.value = alglib.psi(x);
230        } else if (instr.opCode == OpCodes.Dawson) {
231          var x = code[instr.childIndex].value;
232          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
233        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
234          var x = code[instr.childIndex].value;
235          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
236        } else if (instr.opCode == OpCodes.SineIntegral) {
237          double si, ci;
238          var x = code[instr.childIndex].value;
239          if (double.IsNaN(x)) instr.value = double.NaN;
240          else {
241            alglib.sinecosineintegrals(x, out si, out ci);
242            instr.value = si;
243          }
244        } else if (instr.opCode == OpCodes.CosineIntegral) {
245          double si, ci;
246          var x = code[instr.childIndex].value;
247          if (double.IsNaN(x)) instr.value = double.NaN;
248          else {
249            alglib.sinecosineintegrals(x, out si, out ci);
250            instr.value = ci;
251          }
252        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
253          double shi, chi;
254          var x = code[instr.childIndex].value;
255          if (double.IsNaN(x)) instr.value = double.NaN;
256          else {
257            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
258            instr.value = shi;
259          }
260        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
261          double shi, chi;
262          var x = code[instr.childIndex].value;
263          if (double.IsNaN(x)) instr.value = double.NaN;
264          else {
265            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
266            instr.value = chi;
267          }
268        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
269          double c = 0, s = 0;
270          var x = code[instr.childIndex].value;
271          if (double.IsNaN(x)) instr.value = double.NaN;
272          else {
273            alglib.fresnelintegral(x, ref c, ref s);
274            instr.value = c;
275          }
276        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
277          double c = 0, s = 0;
278          var x = code[instr.childIndex].value;
279          if (double.IsNaN(x)) instr.value = double.NaN;
280          else {
281            alglib.fresnelintegral(x, ref c, ref s);
282            instr.value = s;
283          }
284        } else if (instr.opCode == OpCodes.AiryA) {
285          double ai, aip, bi, bip;
286          var x = code[instr.childIndex].value;
287          if (double.IsNaN(x)) instr.value = double.NaN;
288          else {
289            alglib.airy(x, out ai, out aip, out bi, out bip);
290            instr.value = ai;
291          }
292        } else if (instr.opCode == OpCodes.AiryB) {
293          double ai, aip, bi, bip;
294          var x = code[instr.childIndex].value;
295          if (double.IsNaN(x)) instr.value = double.NaN;
296          else {
297            alglib.airy(x, out ai, out aip, out bi, out bip);
298            instr.value = bi;
299          }
300        } else if (instr.opCode == OpCodes.Norm) {
301          var x = code[instr.childIndex].value;
302          if (double.IsNaN(x)) instr.value = double.NaN;
303          else instr.value = alglib.normaldistribution(x);
304        } else if (instr.opCode == OpCodes.Erf) {
305          var x = code[instr.childIndex].value;
306          if (double.IsNaN(x)) instr.value = double.NaN;
307          else instr.value = alglib.errorfunction(x);
308        } else if (instr.opCode == OpCodes.Bessel) {
309          var x = code[instr.childIndex].value;
310          if (double.IsNaN(x)) instr.value = double.NaN;
311          else instr.value = alglib.besseli0(x);
312        } else if (instr.opCode == OpCodes.IfThenElse) {
313          double condition = code[instr.childIndex].value;
314          double result;
315          if (condition > 0.0) {
316            result = code[instr.childIndex + 1].value;
317          } else {
318            result = code[instr.childIndex + 2].value;
319          }
320          instr.value = result;
321        } else if (instr.opCode == OpCodes.AND) {
322          double result = code[instr.childIndex].value;
323          for (int j = 1; j < instr.nArguments; j++) {
324            if (result > 0.0) result = code[instr.childIndex + j].value;
325            else break;
326          }
327          instr.value = result > 0.0 ? 1.0 : -1.0;
328        } else if (instr.opCode == OpCodes.OR) {
329          double result = code[instr.childIndex].value;
330          for (int j = 1; j < instr.nArguments; j++) {
331            if (result <= 0.0) result = code[instr.childIndex + j].value;
332            else break;
333          }
334          instr.value = result > 0.0 ? 1.0 : -1.0;
335        } else if (instr.opCode == OpCodes.NOT) {
336          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
[10774]337        } else if (instr.opCode == OpCodes.XOR) {
[10788]338          int positiveSignals = 0;
339          for (int j = 0; j < instr.nArguments; j++) {
340            if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
[10774]341          }
[10788]342          instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
[9871]343        } else if (instr.opCode == OpCodes.GT) {
344          double x = code[instr.childIndex].value;
345          double y = code[instr.childIndex + 1].value;
346          instr.value = x > y ? 1.0 : -1.0;
347        } else if (instr.opCode == OpCodes.LT) {
348          double x = code[instr.childIndex].value;
349          double y = code[instr.childIndex + 1].value;
350          instr.value = x < y ? 1.0 : -1.0;
351        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
352          var state = (InterpreterState)instr.data;
353          state.Reset();
354          instr.value = interpreter.Evaluate(dataset, ref row, state);
355        } else {
356          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
357          throw new NotSupportedException(errorText);
[9271]358        }
[9739]359        #endregion
[5571]360      }
[9739]361      return code[0].value;
[5571]362    }
[9815]363
364    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
[9944]365      var s = new Stack<int>();
[9815]366      var list = new List<LinearInstruction>();
[9944]367      s.Push(startIndex);
368      while (s.Any()) {
369        int i = s.Pop();
[9815]370        var instr = code[i];
[9944]371        // push instructions in reverse execution order
372        for (int j = instr.nArguments - 1; j >= 0; j--) s.Push(instr.childIndex + j);
[9815]373        list.Add(instr);
374      }
375      return list.ToArray();
376    }
377
[12509]378    public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
[9815]379      for (int i = 0; i != code.Length; ++i) {
380        var instr = code[i];
381        #region opcode switch
382        switch (instr.opCode) {
383          case OpCodes.Constant: {
384              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
385              instr.value = constTreeNode.Value;
386              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
387            }
388            break;
389          case OpCodes.Variable: {
390              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
[9826]391              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
[9815]392            }
393            break;
394          case OpCodes.LagVariable: {
395              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
[9826]396              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
[9815]397            }
398            break;
399          case OpCodes.VariableCondition: {
400              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
[9826]401              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
[9815]402            }
403            break;
404          case OpCodes.TimeLag:
405          case OpCodes.Integral:
406          case OpCodes.Derivative: {
407              var seq = GetPrefixSequence(code, i);
408              var interpreterState = new InterpreterState(seq, 0);
[9826]409              instr.data = interpreterState;
[9815]410              for (int j = 1; j != seq.Length; ++j)
411                seq[j].skip = true;
412            }
413            break;
414        }
415        #endregion
416      }
417    }
[5571]418  }
419}
Note: See TracBrowser for help on using the repository browser.