Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Interpreter/SymbolicDataAnalysisExpressionTreeLinearInterpreter.cs @ 14345

Last change on this file since 14345 was 14345, checked in by gkronber, 7 years ago

#2690: implemented methods to generate symbolic expression tree solutions for decision tree models (random forest and gradient boosted) as well as views which make it possible to inspect each of the individual trees in a GBT and RF solution

File size: 20.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31
32namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
33  [StorableClass]
34  [Item("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Fast linear (non-recursive) interpreter for symbolic expression trees. Does not support ADFs.")]
35  public sealed class SymbolicDataAnalysisExpressionTreeLinearInterpreter : ParameterizedNamedItem, ISymbolicDataAnalysisExpressionTreeInterpreter {
// Recursive interpreter used by Evaluate to handle TimeLag, Derivative and Integral instructions,
// which the linear evaluation loop does not implement itself.
private readonly SymbolicDataAnalysisExpressionTreeInterpreter interpreter;

// Parameter names/descriptions shared by the constructors and the AfterDeserialization hook.
private const string EvaluatedSolutionsParameterName = "EvaluatedSolutions";
private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
private const string CheckExpressionsWithIntervalArithmeticParameterDescription = "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.";
41
// The item's description cannot be edited.
public override bool CanChangeDescription {
  get { return false; }
}

// The item's name cannot be edited.
public override bool CanChangeName {
  get { return false; }
}
49
#region parameter properties
/// <summary>
/// Fixed-value parameter holding the interval-arithmetic switch. When it is true,
/// GetSymbolicExpressionTreeValues throws a NotSupportedException for non-empty row sets.
/// </summary>
public IFixedValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
  get {
    var parameter = Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}

/// <summary>
/// Fixed-value parameter holding the counter of evaluated solutions (incremented once per
/// non-empty GetSymbolicExpressionTreeValues call).
/// </summary>
public IFixedValueParameter<IntValue> EvaluatedSolutionsParameter {
  get {
    var parameter = Parameters[EvaluatedSolutionsParameterName];
    return (IFixedValueParameter<IntValue>)parameter;
  }
}
#endregion
59
#region properties
/// <summary>Shortcut for the boolean stored in CheckExpressionsWithIntervalArithmeticParameter.</summary>
public bool CheckExpressionsWithIntervalArithmetic {
  get {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    return flag.Value;
  }
  set {
    var flag = CheckExpressionsWithIntervalArithmeticParameter.Value;
    flag.Value = value;
  }
}

/// <summary>Shortcut for the integer stored in EvaluatedSolutionsParameter.</summary>
public int EvaluatedSolutions {
  get {
    var counter = EvaluatedSolutionsParameter.Value;
    return counter.Value;
  }
  set {
    var counter = EvaluatedSolutionsParameter.Value;
    counter.Value = value;
  }
}
#endregion
70
/// <summary>
/// Creates an interpreter with the default name and description. The parameters are created by the
/// (name, description) constructor this one delegates to.
/// </summary>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter()
  : this("SymbolicDataAnalysisExpressionTreeLinearInterpreter", "Linear (non-recursive) interpreter for symbolic expression trees (does not support ADFs).") {
}

/// <summary>
/// Creates an interpreter with the given name and description, the interval-arithmetic switch set to
/// false, and the evaluated-solutions counter set to 0.
/// </summary>
public SymbolicDataAnalysisExpressionTreeLinearInterpreter(string name, string description)
  : base(name, description) {
  Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, new BoolValue(false)));
  Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", new IntValue(0)));
  interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}

/// <summary>Deserialization constructor.</summary>
[StorableConstructor]
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(bool deserializing)
  : base(deserializing) {
  // The interpreter field is not marked [Storable], so it has to be recreated here.
  interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
}

/// <summary>Cloning constructor; the wrapped recursive interpreter is cloned through the same cloner.</summary>
private SymbolicDataAnalysisExpressionTreeLinearInterpreter(SymbolicDataAnalysisExpressionTreeLinearInterpreter original, Cloner cloner)
  : base(original, cloner) {
  interpreter = cloner.Clone(original.interpreter);
}

/// <summary>Creates a deep copy via the cloning constructor.</summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SymbolicDataAnalysisExpressionTreeLinearInterpreter(this, cloner);
  return copy;
}
99
/// <summary>
/// Re-creates both parameters as FixedValueParameters after deserialization, carrying over any stored
/// values. Missing parameters are added with their defaults (0 and false). Presumably this migrates files
/// saved by older versions that used other parameter types; confirm against the history.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  IntValue storedCount;
  if (Parameters.ContainsKey(EvaluatedSolutionsParameterName)) {
    var oldCountParameter = (IValueParameter<IntValue>)Parameters[EvaluatedSolutionsParameterName];
    storedCount = oldCountParameter.Value;
    Parameters.Remove(EvaluatedSolutionsParameterName);
  } else {
    storedCount = new IntValue(0);
  }
  Parameters.Add(new FixedValueParameter<IntValue>(EvaluatedSolutionsParameterName, "A counter for the total number of solutions the interpreter has evaluated", storedCount));

  BoolValue storedFlag;
  if (Parameters.ContainsKey(CheckExpressionsWithIntervalArithmeticParameterName)) {
    var oldFlagParameter = (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName];
    storedFlag = oldFlagParameter.Value;
    Parameters.Remove(CheckExpressionsWithIntervalArithmeticParameterName);
  } else {
    storedFlag = new BoolValue(false);
  }
  Parameters.Add(new FixedValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, CheckExpressionsWithIntervalArithmeticParameterDescription, storedFlag));
}
117
#region IStatefulItem
/// <summary>Resets the evaluated-solutions counter to zero.</summary>
public void InitializeState() {
  EvaluatedSolutions = 0;
}

/// <summary>Does nothing; this interpreter keeps no other state to clear.</summary>
public void ClearState() {
  // no-op
}
#endregion
125
// Lock object that serializes increments of the EvaluatedSolutions counter.
private readonly object syncRoot = new object();

/// <summary>
/// Evaluates <paramref name="tree"/> on <paramref name="dataset"/> for every index in <paramref name="rows"/>.
/// </summary>
/// <remarks>
/// The tree is compiled to a linear instruction array and prepared once per call. The returned sequence is
/// lazy: each row is evaluated only when the sequence is enumerated. All rows share the same instruction
/// array, so the result should not be enumerated concurrently from several threads.
/// EvaluatedSolutions is incremented once per call with a non-empty <paramref name="rows"/>.
/// </remarks>
/// <returns>One value per row; an empty sequence if <paramref name="rows"/> is empty.</returns>
/// <exception cref="NotSupportedException">CheckExpressionsWithIntervalArithmetic is true and rows is non-empty.</exception>
public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, IDataset dataset, IEnumerable<int> rows) {
  // Empty input: skip compilation and do not count an evaluation.
  if (!rows.Any()) {
    return Enumerable.Empty<double>();
  }
  if (CheckExpressionsWithIntervalArithmetic) {
    throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
  }

  lock (syncRoot) {
    EvaluatedSolutions += 1;
  }

  var instructions = SymbolicExpressionTreeLinearCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
  PrepareInstructions(instructions, dataset);
  return rows.Select(r => Evaluate(dataset, r, instructions));
}
140
/// <summary>
/// Evaluates a prepared linear instruction array for a single row and returns the value of the root
/// instruction (index 0).
/// </summary>
/// <remarks>
/// Instructions are visited from the last index down to 0. Each one reads its operands from
/// code[childIndex .. childIndex + nArguments - 1] and stores its result in its own value field, so this
/// relies on children being placed at higher indices than their parent (presumably guaranteed by
/// SymbolicExpressionTreeLinearCompiler; verify). Instructions flagged with skip are not re-evaluated. These
/// are constants, whose value is set up front, and the subtrees of TimeLag/Derivative/Integral nodes, which
/// the recursive interpreter evaluates. PrepareInstructions sets these flags.
/// </remarks>
/// <param name="dataset">Dataset whose column values were bound to the instructions by PrepareInstructions.</param>
/// <param name="row">Row to evaluate. Variable and VariableCondition reads outside [0, dataset.Rows) yield NaN;
/// LagVariable yields NaN when row + lag is out of range.</param>
/// <param name="code">Instruction array prepared by PrepareInstructions.</param>
/// <returns>code[0].value after evaluation.</returns>
/// <exception cref="NotSupportedException">An instruction has an opcode this interpreter does not handle.</exception>
private double Evaluate(IDataset dataset, int row, LinearInstruction[] code) {
  for (int i = code.Length - 1; i >= 0; --i) {
    if (code[i].skip) continue;
    #region opcode if
    var instr = code[i];
    if (instr.opCode == OpCodes.Variable) {
      if (row < 0 || row >= dataset.Rows) instr.value = double.NaN;
      else {
        var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
        instr.value = ((IList<double>)instr.data)[row] * variableTreeNode.Weight;
      }
    } else if (instr.opCode == OpCodes.LagVariable) {
      var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
      int actualRow = row + laggedVariableTreeNode.Lag;
      if (actualRow < 0 || actualRow >= dataset.Rows)
        instr.value = double.NaN;
      else
        instr.value = ((IList<double>)instr.data)[actualRow] * laggedVariableTreeNode.Weight;
    } else if (instr.opCode == OpCodes.VariableCondition) {
      if (row < 0 || row >= dataset.Rows) {
        // Out-of-range row: the result is NaN (as for Variable). The column must not be indexed here,
        // because data[row] would be out of bounds.
        instr.value = double.NaN;
      } else {
        var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
        double variableValue = ((IList<double>)instr.data)[row];
        if (!variableConditionTreeNode.Symbol.IgnoreSlope) {
          // Smooth split: the logistic weight p blends the first (true) and second (false) child.
          double x = variableValue - variableConditionTreeNode.Threshold;
          double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));

          double trueBranch = code[instr.childIndex].value;
          double falseBranch = code[instr.childIndex + 1].value;

          instr.value = trueBranch * p + falseBranch * (1 - p);
        } else {
          // Hard split on the threshold.
          if (variableValue <= variableConditionTreeNode.Threshold) {
            instr.value = code[instr.childIndex].value;
          } else {
            instr.value = code[instr.childIndex + 1].value;
          }
        }
      }
    } else if (instr.opCode == OpCodes.Add) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s;
    } else if (instr.opCode == OpCodes.Sub) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s -= code[instr.childIndex + j].value;
      }
      // unary minus
      if (instr.nArguments == 1) s = -s;
      instr.value = s;
    } else if (instr.opCode == OpCodes.Mul) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p *= code[instr.childIndex + j].value;
      }
      instr.value = p;
    } else if (instr.opCode == OpCodes.Div) {
      double p = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        p /= code[instr.childIndex + j].value;
      }
      // unary division is the reciprocal
      if (instr.nArguments == 1) p = 1.0 / p;
      instr.value = p;
    } else if (instr.opCode == OpCodes.Average) {
      double s = code[instr.childIndex].value;
      for (int j = 1; j != instr.nArguments; ++j) {
        s += code[instr.childIndex + j].value;
      }
      instr.value = s / instr.nArguments;
    } else if (instr.opCode == OpCodes.Cos) {
      instr.value = Math.Cos(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Sin) {
      instr.value = Math.Sin(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Tan) {
      instr.value = Math.Tan(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Square) {
      instr.value = Math.Pow(code[instr.childIndex].value, 2);
    } else if (instr.opCode == OpCodes.Power) {
      // the exponent is rounded to the nearest integer
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, y);
    } else if (instr.opCode == OpCodes.SquareRoot) {
      instr.value = Math.Sqrt(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Root) {
      // the root index is rounded to the nearest integer
      double x = code[instr.childIndex].value;
      double y = Math.Round(code[instr.childIndex + 1].value);
      instr.value = Math.Pow(x, 1 / y);
    } else if (instr.opCode == OpCodes.Exp) {
      instr.value = Math.Exp(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Log) {
      instr.value = Math.Log(code[instr.childIndex].value);
    } else if (instr.opCode == OpCodes.Gamma) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.gammafunction(x);
    } else if (instr.opCode == OpCodes.Psi) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      // NaN for non-positive integers, where psi is undefined
      else if (x <= 0 && (Math.Floor(x) - x).IsAlmost(0)) instr.value = double.NaN;
      else instr.value = alglib.psi(x);
    } else if (instr.opCode == OpCodes.Dawson) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.dawsonintegral(x);
    } else if (instr.opCode == OpCodes.ExponentialIntegralEi) {
      var x = code[instr.childIndex].value;
      instr.value = double.IsNaN(x) ? double.NaN : alglib.exponentialintegralei(x);
    } else if (instr.opCode == OpCodes.SineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = si;
      }
    } else if (instr.opCode == OpCodes.CosineIntegral) {
      double si, ci;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.sinecosineintegrals(x, out si, out ci);
        instr.value = ci;
      }
    } else if (instr.opCode == OpCodes.HyperbolicSineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = shi;
      }
    } else if (instr.opCode == OpCodes.HyperbolicCosineIntegral) {
      double shi, chi;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
        instr.value = chi;
      }
    } else if (instr.opCode == OpCodes.FresnelCosineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = c;
      }
    } else if (instr.opCode == OpCodes.FresnelSineIntegral) {
      double c = 0, s = 0;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.fresnelintegral(x, ref c, ref s);
        instr.value = s;
      }
    } else if (instr.opCode == OpCodes.AiryA) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = ai;
      }
    } else if (instr.opCode == OpCodes.AiryB) {
      double ai, aip, bi, bip;
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else {
        alglib.airy(x, out ai, out aip, out bi, out bip);
        instr.value = bi;
      }
    } else if (instr.opCode == OpCodes.Norm) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.normaldistribution(x);
    } else if (instr.opCode == OpCodes.Erf) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.errorfunction(x);
    } else if (instr.opCode == OpCodes.Bessel) {
      var x = code[instr.childIndex].value;
      if (double.IsNaN(x)) instr.value = double.NaN;
      else instr.value = alglib.besseli0(x);
    } else if (instr.opCode == OpCodes.IfThenElse) {
      // a condition value > 0 selects the second child, otherwise the third
      double condition = code[instr.childIndex].value;
      double result;
      if (condition > 0.0) {
        result = code[instr.childIndex + 1].value;
      } else {
        result = code[instr.childIndex + 2].value;
      }
      instr.value = result;
    } else if (instr.opCode == OpCodes.AND) {
      // Boolean operators treat values > 0 as true and produce 1.0 (true) / -1.0 (false).
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result > 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.OR) {
      double result = code[instr.childIndex].value;
      for (int j = 1; j < instr.nArguments; j++) {
        if (result <= 0.0) result = code[instr.childIndex + j].value;
        else break;
      }
      instr.value = result > 0.0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.NOT) {
      instr.value = code[instr.childIndex].value > 0.0 ? -1.0 : 1.0;
    } else if (instr.opCode == OpCodes.XOR) {
      // true iff an odd number of arguments are positive
      int positiveSignals = 0;
      for (int j = 0; j < instr.nArguments; j++) {
        if (code[instr.childIndex + j].value > 0.0) positiveSignals++;
      }
      instr.value = positiveSignals % 2 != 0 ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.GT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x > y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.LT) {
      double x = code[instr.childIndex].value;
      double y = code[instr.childIndex + 1].value;
      instr.value = x < y ? 1.0 : -1.0;
    } else if (instr.opCode == OpCodes.TimeLag || instr.opCode == OpCodes.Derivative || instr.opCode == OpCodes.Integral) {
      // PrepareInstructions packaged this node's subtree into an InterpreterState; delegate it to the
      // recursive interpreter. row is passed by ref (presumably shifted and restored by the callee; verify).
      var state = (InterpreterState)instr.data;
      state.Reset();
      instr.value = interpreter.Evaluate(dataset, ref row, state);
    } else {
      var errorText = string.Format("The {0} symbol is not supported by the linear interpreter. To support this symbol, please use the SymbolicDataAnalysisExpressionTreeInterpreter.", instr.dynamicNode.Symbol.Name);
      throw new NotSupportedException(errorText);
    }
    #endregion
  }
  return code[0].value;
}
374
/// <summary>
/// Collects the subtree rooted at code[startIndex] in prefix (pre-order) order. Each instruction appears
/// before its arguments, and arguments appear in childIndex order. The first element is code[startIndex].
/// </summary>
private static LinearInstruction[] GetPrefixSequence(LinearInstruction[] code, int startIndex) {
  var pending = new Stack<int>();
  var prefix = new List<LinearInstruction>();
  pending.Push(startIndex);
  while (pending.Count > 0) {
    var current = code[pending.Pop()];
    prefix.Add(current);
    // Push the arguments last-to-first so the first argument is popped (and emitted) next.
    for (int arg = current.nArguments - 1; arg >= 0; arg--) {
      pending.Push(current.childIndex + arg);
    }
  }
  return prefix.ToArray();
}
388
/// <summary>
/// Binds a compiled instruction array to <paramref name="dataset"/> before evaluation:
/// constants receive their value and are flagged skip;
/// Variable, LagVariable and VariableCondition instructions receive the column values of their variable as data;
/// TimeLag, Integral and Derivative instructions receive an InterpreterState over their prefix-ordered subtree,
/// and every other instruction of that subtree is flagged skip so that Evaluate leaves it to the recursive
/// interpreter.
/// </summary>
public static void PrepareInstructions(LinearInstruction[] code, IDataset dataset) {
  for (int pos = 0; pos < code.Length; pos++) {
    var current = code[pos];
    #region opcode switch
    switch (current.opCode) {
      case OpCodes.Constant: {
          var constantNode = (ConstantTreeNode)current.dynamicNode;
          current.value = constantNode.Value;
          current.skip = true; // value is fixed now, nothing to compute during evaluation
          break;
        }
      case OpCodes.Variable: {
          var variableNode = (VariableTreeNode)current.dynamicNode;
          current.data = dataset.GetReadOnlyDoubleValues(variableNode.VariableName);
          break;
        }
      case OpCodes.LagVariable: {
          var laggedNode = (LaggedVariableTreeNode)current.dynamicNode;
          current.data = dataset.GetReadOnlyDoubleValues(laggedNode.VariableName);
          break;
        }
      case OpCodes.VariableCondition: {
          var conditionNode = (VariableConditionTreeNode)current.dynamicNode;
          current.data = dataset.GetReadOnlyDoubleValues(conditionNode.VariableName);
          break;
        }
      case OpCodes.TimeLag:
      case OpCodes.Integral:
      case OpCodes.Derivative: {
          var subtree = GetPrefixSequence(code, pos);
          current.data = new InterpreterState(subtree, 0);
          // subtree[0] is the current instruction itself; only its descendants are skipped.
          for (int k = 1; k < subtree.Length; k++) {
            subtree[k].skip = true;
          }
          break;
        }
    }
    #endregion
  }
}
429  }
430}
Note: See TracBrowser for help on using the repository browser.