Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 10538

Last change on this file since 10538 was 9944, checked in by gkronber, 11 years ago

#2021 fixed the way the breadth-first linear representation of the tree is translated to the prefix representation.

File size: 17.6 KB
RevLine 
[5571]1#region License Information
2/* HeuristicLab
[9734]3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5571]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
[9739]24using System.Linq;
[5571]25using HeuristicLab.Common;
26using HeuristicLab.Core;
[6740]27using HeuristicLab.Data;
[5571]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[6740]29using HeuristicLab.Parameters;
[5571]30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Non-recursive interpreter for symbolic data analysis expression trees.
  /// The tree is compiled into a flat <see cref="LinearInstruction"/> array which is evaluated
  /// from the last instruction to the first; the value of the root (<c>code[0]</c>) is the result.
  /// ADFs are not supported. TimeLag, Derivative and Integral subtrees are delegated to a
  /// <see cref="SymbolicDataAnalysisExpressionTreeInterpreter"/>.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";

    // Recursive interpreter used for the TimeLag, Derivative and Integral symbols.
    // This field is not [Storable]; it is recreated in AfterDeserialization.
    private SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Parameter holding the interval-arithmetic switch (not supported by this interpreter).</summary>
    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Parameter holding the counter of evaluated solutions.</summary>
    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// If true, <see cref="GetSymbolicExpressionTreeValues"/> throws
    /// <see cref="NotSupportedException"/>, because interval checks are not implemented here.
    /// </summary>
    public BoolValue CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    }

    /// <summary>Number of trees passed to <see cref="GetSymbolicExpressionTreeValues"/> since the last <see cref="InitializeState"/>.</summary>
    public IntValue EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value; }
      set { EvaluatedSolutionsParameter.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
      : base(deserializing) {
    }

    private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
      : base(original, cloner) {
      interpreter = cloner.Clone(original.interpreter);
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
    }

    /// <summary>
    /// Creates an interpreter with interval checking disabled and the evaluation counter at zero.
    /// </summary>
    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
      interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // The helper interpreter is not persisted, so recreate it after loading.
      if (interpreter == null) interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions.Value = 0;
    }

    /// <summary>No state needs to be cleared.</summary>
    public void ClearState() { }
    #endregion

    /// <summary>
    /// Evaluates <paramref name="tree"/> on each of the given rows of <paramref name="dataset"/>.
    /// </summary>
    /// <param name="tree">The expression tree to evaluate (ADFs are not supported).</param>
    /// <param name="dataset">The dataset providing the variable values.</param>
    /// <param name="rows">The row indices to evaluate; out-of-range rows yield <see cref="double.NaN"/> for variable reads.</param>
    /// <returns>
    /// A lazily evaluated sequence with one value per row. All rows share one compiled
    /// instruction array, so the sequence should not be enumerated concurrently.
    /// </returns>
    /// <exception cref="NotSupportedException">
    /// Thrown immediately if interval arithmetic is requested. Thrown during enumeration if the
    /// tree contains a symbol the linear interpreter cannot evaluate.
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic.Value)
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");

      lock (EvaluatedSolutions) {
        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
      }

      var code = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      PrepareInstructions(code, dataset);
      return rows.Select(row => Evaluate(dataset, row, code));
    }

    /// <summary>
    /// Evaluates the prepared instruction array for a single row.
    /// </summary>
    /// <param name="dataset">The dataset; used for its row count and by delegated subtrees.</param>
    /// <param name="row">The row to evaluate.</param>
    /// <param name="code">Instructions prepared by <see cref="PrepareInstructions"/>.</param>
    /// <returns>The value of the root instruction <c>code[0]</c>.</returns>
    /// <exception cref="NotSupportedException">Thrown for symbols without a linear implementation.</exception>
    private double Evaluate(Dataset dataset, int row, LinearInstruction[] code) {
      // Walk the array backwards. Each instruction reads its arguments from
      // code[childIndex .. childIndex + nArguments - 1], so this order relies on children
      // being stored at higher indices than their parent.
      for (int i = code.Length - 1; i >= 0; --i) {
        var instr = code[i];
        // Constants (value preset) and instructions inside delegated subtrees are marked as skipped.
        if (instr.skip) continue;
        #region opcode if
        if (instr.opCode == OpCodes.Variable) {
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
            instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
          }
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          int actualRow = row + laggedVariableTreeNode.Lag;
          if (actualRow < 0 || actualRow >= dataset.Rows)
            instr.value = double.NaN;
          else
            instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
        } else if (instr.opCode == OpCodes.VariableCondition) {
          if (row < 0 || row >= dataset.Rows) {
            instr.value = double.NaN;
          } else {
            var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
            double variableValue = ((IList<double>)instr.data)[row];
            double x = variableValue - variableConditionTreeNode.Threshold;
            // Logistic weight in (0, 1) that blends the two branches.
            double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

            double trueBranch = code[instr.childIndex].value;
            double falseBranch = code[instr.childIndex + 1].value;

            instr.value = trueBranch * p + falseBranch * (1 - p);
          }
        } else if (instr.opCode == OpCodes.Add) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s;
        } else if (instr.opCode == OpCodes.Sub) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s -= code[instr.childIndex + j].value;
          }
          // With a single argument, subtraction means unary negation.
          if (instr.nArguments == 1) s = -s;
          instr.value = s;
        } else if (instr.opCode == OpCodes.Mul) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p *= code[instr.childIndex + j].value;
          }
          instr.value = p;
        } else if (instr.opCode == OpCodes.Div) {
          double p = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            p /= code[instr.childIndex + j].value;
          }
          // With a single argument, division means the reciprocal.
          if (instr.nArguments == 1) p = 1.0 / p;
          instr.value = p;
        } else if (instr.opCode == OpCodes.Average) {
          double s = code[instr.childIndex].value;
          for (int j = 1; j != instr.nArguments; ++j) {
            s += code[instr.childIndex + j].value;
          }
          instr.value = s / instr.nArguments;
        } else if (instr.opCode == OpCodes.Cos) {
          instr.value = Math.Cos(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Sin) {
          instr.value = Math.Sin(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Tan) {
          instr.value = Math.Tan(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Square) {
          instr.value = Math.Pow(code[instr.childIndex].value, 2);
        } else if (instr.opCode == OpCodes.Power) {
          double x = code[instr.childIndex].value;
          // The exponent is rounded to the nearest integer.
          double y = Math.Round(code[instr.childIndex + 1].value);
          instr.value = Math.Pow(x, y);
        } else if (instr.opCode == OpCodes.SquareRoot) {
          instr.value = Math.Sqrt(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Root) {
          double x = code[instr.childIndex].value;
          // The root index is rounded to the nearest integer, like the exponent of Power.
          // NOTE(review): intended to match the recursive SymbolicDataAnalysisExpressionTreeInterpreter — confirm.
          double y = Math.Round(code[instr.childIndex + 1].value);
          instr.value = Math.Pow(x, 1 / y);
        } else if (instr.opCode == OpCodes.Exp) {
          instr.value = Math.Exp(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Log) {
          instr.value = Math.Log(code[instr.childIndex].value);
        } else if (instr.opCode == OpCodes.Gamma) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
        } else if (instr.opCode == OpCodes.Psi) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          // Non-positive integers are mapped to NaN instead of being passed to alglib.
          else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
          else instr.value = alglib.psi(x);
        } else if (instr.opCode == OpCodes.Dawson) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
        } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
          var x = code[instr.childIndex].value;
          instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
        } else if (instr.opCode == OpCodes.SineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = si;
          }
        } else if (instr.opCode == OpCodes.CosineIntegral) {
          double si, ci;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.sinecosineintegrals(x, out si, out ci);
            instr.value = ci;
          }
        } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = shi;
          }
        } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
          double shi, chi;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
            instr.value = chi;
          }
        } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = c;
          }
        } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
          double c = 0, s = 0;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.fresnelintegral(x, ref c, ref s);
            instr.value = s;
          }
        } else if (instr.opCode == OpCodes.AiryA) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = ai;
          }
        } else if (instr.opCode == OpCodes.AiryB) {
          double ai, aip, bi, bip;
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else {
            alglib.airy(x, out ai, out aip, out bi, out bip);
            instr.value = bi;
          }
        } else if (instr.opCode == OpCodes.Norm) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.normaldistribution(x);
        } else if (instr.opCode == OpCodes.Erf) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.errorfunction(x);
        } else if (instr.opCode == OpCodes.Bessel) {
          var x = code[instr.childIndex].value;
          if (double.IsNaN(x)) instr.value = double.NaN;
          else instr.value = alglib.besseli0(x);
        } else if (instr.opCode == OpCodes.IfThenElse) {
          // A condition value > 0 selects the second child, otherwise the third.
          double condition = code[instr.childIndex].value;
          double result;
          if (condition > 0.0) {
            result = code[instr.childIndex + 1].value;
          } else {
            result = code[instr.childIndex + 2].value;
          }
          instr.value = result;
        } else if (instr.opCode == OpCodes.AND) {
          // Booleans are encoded as 1.0 (true) / -1.0 (false); stop at the first non-positive argument.
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result > 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.OR) {
          // Stop at the first positive argument.
          double result = code[instr.childIndex].value;
          for (int j = 1; j < instr.nArguments; j++) {
            if (result <= 0.0) result = code[instr.childIndex + j].value;
            else break;
          }
          instr.value = result > 0.0 ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.NOT) {
          instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
        } else if (instr.opCode == OpCodes.GT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x > y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.LT) {
          double x = code[instr.childIndex].value;
          double y = code[instr.childIndex + 1].value;
          instr.value = x < y ? 1.0 : -1.0;
        } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
          // Delegated to the recursive interpreter. Its state holds this subtree in prefix order
          // (built in PrepareInstructions) and is reset before every evaluation.
          var state = (InterpreterState)instr.data;
          state.Reset();
          instr.value = interpreter.Evaluate(dataset, ref row, state);
        } else {
          var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
          throw new NotSupportedException(errorText);
        }
        #endregion
      }
      return code[0].value;
    }

    /// <summary>
    /// Collects the subtree rooted at <paramref name="startIndex"/> in prefix (pre-order) order.
    /// </summary>
    /// <param name="code">The linear instruction array.</param>
    /// <param name="startIndex">Index of the subtree root within <paramref name="code"/>.</param>
    /// <returns>The subtree's instructions with the root first, followed by its children depth-first.</returns>
    private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
      var pending = new Stack<int>();
      var list = new List<LinearInstruction>();
      pending.Push(startIndex);
      while (pending.Count > 0) {
        int i = pending.Pop();
        var instr = code[i];
        // Push children in reverse order so that the first child is popped (and emitted) first.
        for (int j = instr.nArguments - 1; j >= 0; j--) pending.Push(instr.childIndex + j);
        list.Add(instr);
      }
      return list.ToArray();
    }

    /// <summary>
    /// One-time preparation of the compiled instructions before row-wise evaluation.
    /// Constants get their value and are marked skipped. Variable symbols get a read-only view
    /// of their column. TimeLag, Integral and Derivative nodes get an <see cref="InterpreterState"/>
    /// over their prefix-ordered subtree, and that subtree is marked skipped for the linear pass.
    /// </summary>
    /// <param name="code">The compiled instruction array; modified in place.</param>
    /// <param name="dataset">The dataset providing the variable columns.</param>
    private static void PrepareInstructions(LinearInstruction[] code, Dataset dataset) {
      for (int i = 0; i != code.Length; ++i) {
        var instr = code[i];
        #region opcode switch
        switch (instr.opCode) {
          case OpCodes.Constant: {
              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
              instr.value = constTreeNode.Value;
              instr.skip = true; // the value is already set so this instruction should be skipped in the evaluation phase
            }
            break;
          case OpCodes.Variable: {
              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
            }
            break;
          case OpCodes.LagVariable: {
              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
            }
            break;
          case OpCodes.VariableCondition: {
              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
              instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
            }
            break;
          case OpCodes.TimeLag:
          case OpCodes.Integral:
          case OpCodes.Derivative: {
              var seq = GetPrefixSequence(code, i);
              var interpreterState = new InterpreterState(seq, 0);
              instr.data = interpreterState;
              // The descendants (seq[1..]) are evaluated by the delegated interpreter, not by the linear loop.
              for (int j = 1; j != seq.Length; ++j)
                seq[j].skip = true;
            }
            break;
        }
        #endregion
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.