Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.DataAnalysis.Symbolic.LinearInterpreter/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 9734

Last change on this file since 9734 was 9734, checked in by bburlacu, 11 years ago

#2021: Updated license year, fixed interpreter name (SymbolicDataAnalysisExpressionTreeLinearInterpreter) and updated description. Replaced tabs with spaces in Instruction.cs.

File size: 18.4 KB
RevLine 
[5571]1#region License Information
2/* HeuristicLab
[9734]3 * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[5571]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
[6740]26using HeuristicLab.Data;
[5571]27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[6740]28using HeuristicLab.Parameters;
[5571]29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
31namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
32  [StorableClass]
[9734]33  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).")]
34  public class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
[5749]35    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
[7615]36    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
[5571]37
[9732]38    public override bool CanChangeName {
39      get { return false; }
40    }
[5571]41
[9732]42    public override bool CanChangeDescription {
43      get { return false; }
44    }
45
[5749]46    #region parameter properties
[9732]47
[5749]48    public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
49      get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
50    }
[7615]51
52    public IValueParameter<IntValue> EvaluatedSolutionsParameter {
53      get { return (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
54    }
[9732]55
[5749]56    #endregion
57
58    #region properties
[9732]59
[5749]60    public BoolValue CheckExpressionsWithIntervalArithmetic {
61      get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
62      set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
63    }
[7615]64
65    public IntValue EvaluatedSolutions {
66      get { return EvaluatedSolutionsParameter.Value; }
67      set { EvaluatedSolutionsParameter.Value = value; }
68    }
[9732]69
[5749]70    #endregion
71
[5571]72    [StorableConstructor]
[9734]73    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
[9732]74      : base(deserializing) {
75    }
76
[9734]77    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(
78      SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
[9732]79      : base(original, cloner) {
80    }
81
[5571]82    public override IDeepCloneable Clone(Cloner cloner) {
[9734]83      return new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
[5571]84    }
85
[9734]86    public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
87      : base("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
88      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
89      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
[5571]90    }
91
[9734]92    protected SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
[8436]93      : base(name, description) {
[9734]94      Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
95      Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
[8436]96    }
97
[7615]98    [StorableHook(HookType.AfterDeserialization)]
99    private void AfterDeserialization() {
100      if (!Parameters.ContainsKey(EvaluatedSolutionsParameterName))
[9734]101        Parameters.Add(new ValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
[7615]102    }
103
104    #region IStatefulItem
[9732]105
[7615]106    public void InitializeState() {
107      EvaluatedSolutions.Value = 0;
108    }
109
110    public void ClearState() {
111    }
[9732]112
[7615]113    #endregion
114
[9734]115    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
[8436]116      if (CheckExpressionsWithIntervalArithmetic.Value)
[9734]117        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
[7120]118
[9004]119      lock (EvaluatedSolutions) {
120        EvaluatedSolutions.Value++; // increment the evaluated solutions counter
121      }
[8436]122
[9732]123      var root = tree.Root.GetSubtree(0).GetSubtree(0);
124      var nodes = new List<ISymbolicExpressionTreeNode> { root };
125      var code = new List<Instruction>{
126        new Instruction { dynamicNode = root,
127        nArguments = (byte) root.SubtreeCount,
128        opCode = OpCodes.MapSymbolToOpCode(root)
129        }
130      };
131
132      // iterate breadth-wise over tree nodes and produce an array of instructions
133      int i = 0;
134      while (i != nodes.Count) {
135        if (nodes[i].SubtreeCount > 0) {
136          // save index of the first child in the instructions array
137          code[i].childIndex = code.Count;
138          for (int j = 0; j != nodes[i].SubtreeCount; ++j) {
139            var s = nodes[i].GetSubtree(j);
140            nodes.Add(s);
141            code.Add(new Instruction {
142              dynamicNode = s,
143              nArguments = (byte)s.SubtreeCount,
144              opCode = OpCodes.MapSymbolToOpCode(s)
145            });
146          }
147        }
148        ++i;
149      }
150      // fill in iArg0 value for terminal nodes
151      foreach (var instr in code) {
[9271]152        switch (instr.opCode) {
153          case OpCodes.Variable: {
154              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
155              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
156            }
157            break;
158          case OpCodes.LagVariable: {
159              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
160              instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
161            }
162            break;
163          case OpCodes.VariableCondition: {
164              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
165              instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
166            }
167            break;
168        }
169      }
170
[9732]171      var array = code.ToArray();
172
[8436]173      foreach (var rowEnum in rows) {
174        int row = rowEnum;
[9732]175        EvaluateFast(dataset, ref row, array);
176        yield return code[0].value;
[8436]177      }
[7154]178    }
179
[9732]180    private void EvaluateFast(Dataset dataset, ref int row, Instruction[] code) {
181      for (int i = code.Length - 1; i >= 0; --i) {
182        var instr = code[i];
[5571]183
[9732]184        switch (instr.opCode) {
[9271]185          case OpCodes.Add: {
[9732]186              double s = code[instr.childIndex].value;
187              for (int j = 1; j != instr.nArguments; ++j) {
188                s += code[instr.childIndex + j].value;
[9271]189              }
[9732]190              instr.value = s;
[5571]191            }
[9271]192            break;
193          case OpCodes.Sub: {
[9732]194              double s = code[instr.childIndex].value;
195              for (int j = 1; j != instr.nArguments; ++j) {
196                s -= code[instr.childIndex + j].value;
[9271]197              }
[9732]198              if (instr.nArguments == 1) s = -s;
199              instr.value = s;
[5571]200            }
[9271]201            break;
202          case OpCodes.Mul: {
[9732]203              double p = code[instr.childIndex].value;
204              for (int j = 1; j != instr.nArguments; ++j) {
205                p *= code[instr.childIndex + j].value;
[9271]206              }
[9732]207              instr.value = p;
[5571]208            }
[9271]209            break;
210          case OpCodes.Div: {
[9732]211              double p = code[instr.childIndex].value;
212              for (int j = 1; j != instr.nArguments; ++j) {
213                p /= code[instr.childIndex + j].value;
[9271]214              }
[9732]215              if (instr.nArguments == 1) p = 1.0 / p;
216              instr.value = p;
[5571]217            }
[9271]218            break;
219          case OpCodes.Average: {
[9732]220              double s = code[instr.childIndex].value;
221              for (int j = 1; j != instr.nArguments; ++j) {
222                s += code[instr.childIndex + j].value;
[9271]223              }
[9732]224              instr.value = s / instr.nArguments;
[5571]225            }
[9271]226            break;
227          case OpCodes.Cos: {
[9732]228              instr.value = Math.Cos(code[instr.childIndex].value);
[7842]229            }
[9271]230            break;
231          case OpCodes.Sin: {
[9732]232              instr.value = Math.Sin(code[instr.childIndex].value);
[7842]233            }
[9732]234            break;
[9271]235          case OpCodes.Tan: {
[9732]236              instr.value = Math.Tan(code[instr.childIndex].value);
[7842]237            }
[9271]238            break;
239          case OpCodes.Root: {
[9732]240              double x = code[instr.childIndex].value;
241              double y = code[instr.childIndex + 1].value;
242              instr.value = Math.Pow(x, 1 / y);
[7842]243            }
[9271]244            break;
245          case OpCodes.Exp: {
[9732]246              instr.value = Math.Exp(code[instr.childIndex].value);
[7842]247            }
[9271]248            break;
249          case OpCodes.Log: {
[9732]250              instr.value = Math.Log(code[instr.childIndex].value);
[5571]251            }
[9271]252            break;
253          case OpCodes.Gamma: {
[9732]254              var x = code[instr.childIndex].value;
255              instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
[9271]256            }
257            break;
258          case OpCodes.Psi: {
[9732]259              var x = code[instr.childIndex].value;
260              if (double.IsNaN(x)) instr.value = double.NaN;
261              else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
262              else instr.value = alglib.psi(x);
[9271]263            }
264            break;
265          case OpCodes.Dawson: {
[9732]266              var x = code[instr.childIndex].value;
267              instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
[9271]268            }
269            break;
270          case OpCodes.ExponentialIntegralEi: {
[9732]271              var x = code[instr.childIndex].value;
272              instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
[9271]273            }
274            break;
275          case OpCodes.SineIntegral: {
276              double si, ci;
[9732]277              var x = code[instr.childIndex].value;
278              if (double.IsNaN(x)) instr.value = double.NaN;
[5571]279              else {
[9271]280                alglib.sinecosineintegrals(x, out si, out ci);
[9732]281                instr.value = si;
[5571]282              }
283            }
[9271]284            break;
285          case OpCodes.CosineIntegral: {
286              double si, ci;
[9732]287              var x = code[instr.childIndex].value;
288              if (double.IsNaN(x)) instr.value = double.NaN;
[5571]289              else {
[9271]290                alglib.sinecosineintegrals(x, out si, out ci);
[9732]291                instr.value = si;
[5571]292              }
293            }
[9271]294            break;
295          case OpCodes.HyperbolicSineIntegral: {
296              double shi, chi;
[9732]297              var x = code[instr.childIndex].value;
298              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]299              else {
300                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
[9732]301                instr.value = shi;
[9271]302              }
[5571]303            }
[9271]304            break;
305          case OpCodes.HyperbolicCosineIntegral: {
306              double shi, chi;
[9732]307              var x = code[instr.childIndex].value;
308              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]309              else {
310                alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
[9732]311                instr.value = chi;
[9271]312              }
313            }
314            break;
315          case OpCodes.FresnelCosineIntegral: {
316              double c = 0, s = 0;
[9732]317              var x = code[instr.childIndex].value;
318              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]319              else {
320                alglib.fresnelintegral(x, ref c, ref s);
[9732]321                instr.value = c;
[9271]322              }
323            }
324            break;
325          case OpCodes.FresnelSineIntegral: {
326              double c = 0, s = 0;
[9732]327              var x = code[instr.childIndex].value;
328              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]329              else {
330                alglib.fresnelintegral(x, ref c, ref s);
[9732]331                instr.value = s;
[9271]332              }
333            }
334            break;
335          case OpCodes.AiryA: {
336              double ai, aip, bi, bip;
[9732]337              var x = code[instr.childIndex].value;
338              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]339              else {
340                alglib.airy(x, out ai, out aip, out bi, out bip);
[9732]341                instr.value = ai;
[9271]342              }
343            }
344            break;
345          case OpCodes.AiryB: {
346              double ai, aip, bi, bip;
[9732]347              var x = code[instr.childIndex].value;
348              if (double.IsNaN(x)) instr.value = double.NaN;
[9271]349              else {
350                alglib.airy(x, out ai, out aip, out bi, out bip);
[9732]351                instr.value = bi;
[9271]352              }
353            }
354            break;
355          case OpCodes.Norm: {
[9732]356              var x = code[instr.childIndex].value;
357              if (double.IsNaN(x)) instr.value = double.NaN;
358              else instr.value = alglib.normaldistribution(x);
[9271]359            }
360            break;
361          case OpCodes.Erf: {
[9732]362              var x = code[instr.childIndex].value;
363              if (double.IsNaN(x)) instr.value = double.NaN;
364              else instr.value = alglib.errorfunction(x);
[9271]365            }
366            break;
367          case OpCodes.Bessel: {
[9732]368              var x = code[instr.childIndex].value;
369              if (double.IsNaN(x)) instr.value = double.NaN;
370              else instr.value = alglib.besseli0(x);
[9271]371            }
372            break;
373          case OpCodes.IfThenElse: {
[9732]374              double condition = code[instr.childIndex].value;
375              double result;
376              if (condition > 0.0) {
377                result = code[instr.childIndex + 1].value;
378              } else {
379                result = code[instr.childIndex + 2].value;
380              }
381              instr.value = result;
[9271]382            }
383            break;
384          case OpCodes.AND: {
[9732]385              double result = code[instr.childIndex].value;
386              for (int j = 1; j < instr.nArguments; j++) {
387                if (result > 0.0) result = code[instr.childIndex + j].value;
388                else break;
[9271]389              }
[9732]390              instr.value = result > 0.0 ? 1.0 : -1.0;
[9271]391            }
392            break;
393          case OpCodes.OR: {
[9732]394              double result = code[instr.childIndex].value;
395              for (int j = 1; j < instr.nArguments; j++) {
396                if (result <= 0.0) result = code[instr.childIndex + j].value;
397                else break;
[9271]398              }
[9732]399              instr.value = result > 0.0 ? 1.0 : -1.0;
[9271]400            }
401            break;
402          case OpCodes.NOT: {
[9732]403              instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
[9271]404            }
405            break;
406          case OpCodes.GT: {
[9732]407              double x = code[instr.childIndex].value;
408              double y = code[instr.childIndex + 1].value;
409              instr.value = x > y ? 1.0 : -1.0;
[9271]410            }
411            break;
412          case OpCodes.LT: {
[9732]413              double x = code[instr.childIndex].value;
414              double y = code[instr.childIndex + 1].value;
415              instr.value = x < y ? 1.0 : -1.0;
[9271]416            }
417            break;
418          case OpCodes.TimeLag: {
419              throw new NotSupportedException();
420            }
421          case OpCodes.Integral: {
422              throw new NotSupportedException();
423            }
424          case OpCodes.Derivative: {
425              throw new NotSupportedException();
[5571]426            }
[9271]427          case OpCodes.Arg: {
428              throw new NotSupportedException();
429            }
430          case OpCodes.Variable: {
[9732]431              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
432              var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
433              instr.value = ((IList<double>)instr.iArg0)[row] * variableTreeNode.Weight;
[9271]434            }
435            break;
436          case OpCodes.LagVariable: {
[9732]437              var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
[9271]438              int actualRow = row + laggedVariableTreeNode.Lag;
[9732]439              if (actualRow < 0 || actualRow >= dataset.Rows) instr.value = double.NaN;
440              instr.value = ((IList<double>)instr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
[9271]441            }
442            break;
443          case OpCodes.Constant: {
[9732]444              var constTreeNode = (ConstantTreeNode)instr.dynamicNode;
445              instr.value = constTreeNode.Value;
[9271]446            }
447            break;
448          case OpCodes.VariableCondition: {
[9732]449              if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
450              var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
451              double variableValue = ((IList<double>)instr.iArg0)[row];
452              double x = variableValue - variableConditionTreeNode.Threshold;
453              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
[5571]454
[9732]455              double trueBranch = code[instr.childIndex].value;
456              double falseBranch = code[instr.childIndex + 1].value;
[5571]457
[9732]458              instr.value = trueBranch * p + falseBranch * (1 - p);
[9271]459            }
460            break;
461          default:
462            throw new NotSupportedException();
463        }
[5571]464      }
465    }
466  }
467}
Note: See TracBrowser for help on using the repository browser.