Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeInterpreter.cs @ 14634

Last change on this file since 14634 was 14345, checked in by gkronber, 8 years ago

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

File size: 22.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Parameters;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Recursive interpreter that evaluates a compiled symbolic expression tree (including
  /// automatically defined functions) row by row on a dataset.
  /// </summary>
  [StorableClass]
  [Item("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
  public class SymbolicDataAnalysisExpressionTreeInterpreter : ParameterizedNamedItem,
    ISymbolicDataAnalysisExpressionTreeInterpreter {
    private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
    private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
    private const string EvaluatedSolutionsParameterDescription = "A counter for the total number of solutions the interpreter has evaluated";

    public override bool CanChangeName {
      get { return false; }
    }

    public override bool CanChangeDescription {
      get { return false; }
    }

    #region parameter properties
    /// <summary>Fixed-value parameter backing <see cref="CheckExpressionsWithIntervalArithmetic"/>.</summary>
    public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    }

    /// <summary>Fixed-value parameter backing <see cref="EvaluatedSolutions"/>.</summary>
    public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>
    /// When true, <see cref="GetSymbolicExpressionTreeValues"/> throws <see cref="NotSupportedException"/>,
    /// because interval arithmetic is not implemented by this interpreter.
    /// </summary>
    public bool CheckExpressionsWithIntervalArithmetic {
      get { return CheckExpressionsWithIntervalArithmeticParameter.Value.Value; }
      set { CheckExpressionsWithIntervalArithmeticParameter.Value.Value = value; }
    }

    /// <summary>
    /// Number of trees evaluated via <see cref="GetSymbolicExpressionTreeValues"/> since the last
    /// <see cref="InitializeState"/>.
    /// </summary>
    public int EvaluatedSolutions {
      get { return EvaluatedSolutionsParameter.Value.Value; }
      set { EvaluatedSolutionsParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicDataAnalysisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }

    protected SymbolicDataAnalysisExpressionTreeInterpreter(SymbolicDataAnalysisExpressionTreeInterpreter original,
      Cloner cloner)
      : base(original, cloner) { }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicDataAnalysisExpressionTreeInterpreter(this, cloner);
    }

    /// <summary>
    /// Creates an interpreter with interval-arithmetic checking disabled and the evaluation counter at zero.
    /// </summary>
    public SymbolicDataAnalysisExpressionTreeInterpreter()
      : this("SymbolicDataAnalysisExpressionTreeInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    }

    /// <summary>
    /// Creates an interpreter with the given item name and description. It registers the
    /// CheckExpressionsWithIntervalArithmetic parameter (false) and the EvaluatedSolutions parameter (0).
    /// </summary>
    protected SymbolicDataAnalysisExpressionTreeInterpreter(string name, string description)
      : base(name, description) {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, new IntValue(0)));
    }

    /// <summary>
    /// Re-creates both parameters as <see cref="FixedValueParameter{T}"/> instances after deserialization.
    /// Any stored value is carried over; otherwise the defaults (0 / false) are used.
    /// NOTE(review): presumably this migrates files in which these were persisted as other
    /// IValueParameter kinds — confirm against the persistence history.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      var evaluatedSolutions = new IntValue(0);
      var checkExpressionsWithIntervalArithmetic = new BoolValue(false);
      if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
        var evaluatedSolutionsParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
        evaluatedSolutions = evaluatedSolutionsParameter.Value;
        Parameters.Remove(EvaluatedSolutionsParameterName);
      }
      Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, EvaluatedSolutionsParameterDescription, evaluatedSolutions));
      if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
        var checkExpressionsWithIntervalArithmeticParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
        Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
        checkExpressionsWithIntervalArithmetic = checkExpressionsWithIntervalArithmeticParameter.Value;
      }
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, checkExpressionsWithIntervalArithmetic));
    }

    #region IStatefulItem
    /// <summary>Resets the evaluated-solutions counter to zero.</summary>
    public void InitializeState() {
      EvaluatedSolutions = 0;
    }

    public void ClearState() { }
    #endregion

    // guards the read-modify-write of the EvaluatedSolutions counter
    private readonly object syncRoot = new object();

    /// <summary>
    /// Evaluates <paramref name="tree"/> on each of the given <paramref name="rows"/> of <paramref name="dataset"/>.
    /// </summary>
    /// <remarks>
    /// This is an iterator method. The interval-arithmetic check, the counter increment and the tree
    /// compilation only run when enumeration of the returned sequence starts. They run again for every
    /// separate enumeration.
    /// </remarks>
    /// <returns>One value per row, in the order of <paramref name="rows"/>.</returns>
    /// <exception cref="NotSupportedException">
    /// <see cref="CheckExpressionsWithIntervalArithmetic"/> is enabled (raised on enumeration).
    /// </exception>
    public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset,
      IEnumerable<int> rows) {
      if (CheckExpressionsWithIntervalArithmetic) {
        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
      }

      lock (syncRoot) {
        EvaluatedSolutions++; // increment the evaluated solutions counter
      }
      var state = PrepareInterpreterState(tree, dataset);

      foreach (var rowEnum in rows) {
        // copy, because Evaluate takes the row by ref
        int row = rowEnum;
        yield return Evaluate(dataset, ref row, state);
        state.Reset();
      }
    }

    /// <summary>
    /// Compiles <paramref name="tree"/> into instructions. Variable-reading instructions get the dataset
    /// column they read stored in <c>data</c>. The method also sums the argument-stack size needed by all
    /// <see cref="OpCodes.Call"/> instructions.
    /// </summary>
    private static InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, IDataset dataset) {
      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
      int necessaryArgStackSize = 0;
      foreach (Instruction instr in code) {
        if (instr.opCode == OpCodes.Variable) {
          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.LagVariable) {
          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.VariableCondition) {
          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
          instr.data = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
        } else if (instr.opCode == OpCodes.Call) {
          necessaryArgStackSize += instr.nArguments + 1;
        }
      }
      return new InterpreterState(code, necessaryArgStackSize);
    }

    /// <summary>
    /// Recursively evaluates the instruction at the current program counter of <paramref name="state"/>,
    /// together with its sub-trees, for the given row. On return the program counter points just past
    /// the evaluated sub-tree.
    /// </summary>
    /// <param name="dataset">Dataset whose row count bounds the variable accesses.</param>
    /// <param name="row">
    /// Row to evaluate. Time-lag related opcodes shift it temporarily and restore it before returning.
    /// </param>
    /// <param name="state">Compiled program, program counter and argument stack.</param>
    /// <returns>
    /// The value of the evaluated sub-tree. Variable accesses outside the dataset rows yield
    /// <see cref="double.NaN"/>.
    /// </returns>
    /// <exception cref="NotSupportedException">The opcode is not supported by this interpreter.</exception>
    public virtual double Evaluate(IDataset dataset, ref int row, InterpreterState state) {
      Instruction currentInstr = state.NextInstruction();
      switch (currentInstr.opCode) {
        case OpCodes.Add: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s += Evaluate(dataset, ref row, state);
            }
            return s;
          }
        case OpCodes.Sub: {
            double s = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              s -= Evaluate(dataset, ref row, state);
            }
            // unary minus
            if (currentInstr.nArguments == 1) { s = -s; }
            return s;
          }
        case OpCodes.Mul: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p *= Evaluate(dataset, ref row, state);
            }
            return p;
          }
        case OpCodes.Div: {
            double p = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              p /= Evaluate(dataset, ref row, state);
            }
            // unary division is the reciprocal
            if (currentInstr.nArguments == 1) { p = 1.0 / p; }
            return p;
          }
        case OpCodes.Average: {
            double sum = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              sum += Evaluate(dataset, ref row, state);
            }
            return sum / currentInstr.nArguments;
          }
        case OpCodes.Cos: {
            return Math.Cos(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Sin: {
            return Math.Sin(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Tan: {
            return Math.Tan(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Square: {
            return Math.Pow(Evaluate(dataset, ref row, state), 2);
          }
        case OpCodes.Power: {
            double x = Evaluate(dataset, ref row, state);
            // the exponent is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, y);
          }
        case OpCodes.SquareRoot: {
            return Math.Sqrt(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Root: {
            double x = Evaluate(dataset, ref row, state);
            // the root index is rounded to the nearest integer
            double y = Math.Round(Evaluate(dataset, ref row, state));
            return Math.Pow(x, 1 / y);
          }
        case OpCodes.Exp: {
            return Math.Exp(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Log: {
            return Math.Log(Evaluate(dataset, ref row, state));
          }
        case OpCodes.Gamma: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; } else { return alglib.gammafunction(x); }
          }
        case OpCodes.Psi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            // psi has poles at zero and the negative integers
            else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) return double.NaN;
            return alglib.psi(x);
          }
        case OpCodes.Dawson: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.dawsonintegral(x);
          }
        case OpCodes.ExponentialIntegralEi: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) { return double.NaN; }
            return alglib.exponentialintegralei(x);
          }
        case OpCodes.SineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return si;
            }
          }
        case OpCodes.CosineIntegral: {
            double si, ci;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.sinecosineintegrals(x, out si, out ci);
              return ci;
            }
          }
        case OpCodes.HyperbolicSineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return shi;
            }
          }
        case OpCodes.HyperbolicCosineIntegral: {
            double shi, chi;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
              return chi;
            }
          }
        case OpCodes.FresnelCosineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return c;
            }
          }
        case OpCodes.FresnelSineIntegral: {
            double c = 0, s = 0;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.fresnelintegral(x, ref c, ref s);
              return s;
            }
          }
        case OpCodes.AiryA: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return ai;
            }
          }
        case OpCodes.AiryB: {
            double ai, aip, bi, bip;
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else {
              alglib.airy(x, out ai, out aip, out bi, out bip);
              return bi;
            }
          }
        case OpCodes.Norm: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.normaldistribution(x);
          }
        case OpCodes.Erf: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.errorfunction(x);
          }
        case OpCodes.Bessel: {
            var x = Evaluate(dataset, ref row, state);
            if (double.IsNaN(x)) return double.NaN;
            else return alglib.besseli0(x);
          }
        case OpCodes.IfThenElse: {
            // only the selected branch is evaluated; the other sub-tree is skipped
            double condition = Evaluate(dataset, ref row, state);
            double result;
            if (condition > 0.0) {
              result = Evaluate(dataset, ref row, state); state.SkipInstructions();
            } else {
              state.SkipInstructions(); result = Evaluate(dataset, ref row, state);
            }
            return result;
          }
        case OpCodes.AND: {
            // short-circuit: once an argument is <= 0 the remaining sub-trees are skipped
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result > 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.OR: {
            // short-circuit: once an argument is > 0 the remaining sub-trees are skipped
            double result = Evaluate(dataset, ref row, state);
            for (int i = 1; i < currentInstr.nArguments; i++) {
              if (result <= 0.0) result = Evaluate(dataset, ref row, state);
              else {
                state.SkipInstructions();
              }
            }
            return result > 0.0 ? 1.0 : -1.0;
          }
        case OpCodes.NOT: {
            return Evaluate(dataset, ref row, state) > 0.0 ? -1.0 : 1.0;
          }
        case OpCodes.XOR: {
            //mkommend: XOR on multiple inputs is defined as true if the number of positive signals is odd
            // this is equal to a consecutive execution of binary XOR operations.
            int positiveSignals = 0;
            for (int i = 0; i < currentInstr.nArguments; i++) {
              if (Evaluate(dataset, ref row, state) > 0.0) { positiveSignals++; }
            }
            return positiveSignals % 2 != 0 ? 1.0 : -1.0;
          }
        case OpCodes.GT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x > y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.LT: {
            double x = Evaluate(dataset, ref row, state);
            double y = Evaluate(dataset, ref row, state);
            if (x < y) { return 1.0; } else { return -1.0; }
          }
        case OpCodes.TimeLag: {
            // evaluate the sub-tree at row + Lag, then restore the row
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            row += timeLagTreeNode.Lag;
            double result = Evaluate(dataset, ref row, state);
            row -= timeLagTreeNode.Lag;
            return result;
          }
        case OpCodes.Integral: {
            // sums the sub-tree evaluated at every row between row and row + Lag (inclusive);
            // the program counter is rewound between evaluations so the same sub-tree is re-run
            int savedPc = state.ProgramCounter;
            var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
            double sum = 0.0;
            for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
              row += Math.Sign(timeLagTreeNode.Lag);
              sum += Evaluate(dataset, ref row, state);
              state.ProgramCounter = savedPc;
            }
            row -= timeLagTreeNode.Lag;
            sum += Evaluate(dataset, ref row, state);
            return sum;
          }

        //mkommend: derivative calculation taken from:
        //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
        //one-sided smooth differentiator, N = 4
        // y' = 1/(8h) * (f_i + 2 f_{i-1} - 2 f_{i-3} - f_{i-4}), evaluated here with h = 1
        case OpCodes.Derivative: {
            int savedPc = state.ProgramCounter;
            double f_0 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_1 = Evaluate(dataset, ref row, state); row -= 2;
            state.ProgramCounter = savedPc;
            double f_3 = Evaluate(dataset, ref row, state); row--;
            state.ProgramCounter = savedPc;
            double f_4 = Evaluate(dataset, ref row, state);
            // restore the original row
            row += 4;

            return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
          }
        case OpCodes.Call: {
            // evaluate sub-trees
            double[] argValues = new double[currentInstr.nArguments];
            for (int i = 0; i < currentInstr.nArguments; i++) {
              argValues[i] = Evaluate(dataset, ref row, state);
            }
            // push on argument values on stack
            state.CreateStackFrame(argValues);

            // save the pc
            int savedPc = state.ProgramCounter;
            // set pc to start of function
            state.ProgramCounter = (ushort)currentInstr.data;
            // evaluate the function
            double v = Evaluate(dataset, ref row, state);

            // delete the stack frame
            state.RemoveStackFrame();

            // restore the pc => evaluation will continue at point after my subtrees
            state.ProgramCounter = savedPc;
            return v;
          }
        case OpCodes.Arg: {
            return state.GetStackFrameValue((ushort)currentInstr.data);
          }
        case OpCodes.Variable: {
            if (row < 0 || row >= dataset.Rows) return double.NaN;
            var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
            return ((IList<double>)currentInstr.data)[row] * variableTreeNode.Weight;
          }
        case OpCodes.LagVariable: {
            var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
            int actualRow = row + laggedVariableTreeNode.Lag;
            if (actualRow < 0 || actualRow >= dataset.Rows) { return double.NaN; }
            return ((IList<double>)currentInstr.data)[actualRow] * laggedVariableTreeNode.Weight;
          }
        case OpCodes.Constant: {
            var constTreeNode = (ConstantTreeNode)currentInstr.dynamicNode;
            return constTreeNode.Value;
          }

        //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x) )
        //to determine the relative amounts of the true and false branch see http://en.wikipedia.org/wiki/Logistic_function
        case OpCodes.VariableCondition: {
            if (row < 0 || row >= dataset.Rows) {
              // The branches are not evaluated, but their sub-trees must still be consumed. Otherwise the
              // program counter would point into this node's children, and the caller would read them as
              // its own next arguments.
              for (int i = 0; i < currentInstr.nArguments; i++) {
                state.SkipInstructions();
              }
              return double.NaN;
            }
            var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
            double variableValue = ((IList<double>)currentInstr.data)[row];
            if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
              double x = variableValue - variableConditionTreeNode.Threshold;
              double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

              double trueBranch = Evaluate(dataset, ref row, state);
              double falseBranch = Evaluate(dataset, ref row, state);

              return trueBranch * p + falseBranch * (1 - p);
            } else {
              // Strict threshold: only the selected branch is evaluated, the other one is skipped.
              // NOTE(review): here the first child is taken for values <= Threshold. In the sigmoid mode
              // above with a positive Slope, the first child gets the larger weight for values ABOVE
              // Threshold. Confirm this ordering is intended.
              if (variableValue <= variableConditionTreeNode.Threshold) {
                var left = Evaluate(dataset, ref row, state);
                state.SkipInstructions();
                return left;
              } else {
                state.SkipInstructions();
                return Evaluate(dataset, ref row, state);
              }
            }
          }
        default:
          throw new NotSupportedException();
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.