Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis/3.3/Symbolic/SymbolicTimeSeriesExpressionInterpreter.cs @ 4113

Last change on this file since 4113 was 4113, checked in by gkronber, 14 years ago

Added plugin for time series prognosis. #1081

File size: 16.4 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using System.Collections.Generic;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Interfaces;
33using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Symbols;
34
35namespace HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis {
36  [StorableClass]
37  [Item("SymbolicTimeSeriesExpressionInterpreter", "Interpreter for symbolic expression trees representing time series forecast models.")]
38  public class SymbolicTimeSeriesExpressionInterpreter : NamedItem, ISymbolicTimeSeriesExpressionInterpreter {
39    private class OpCodes {
40      public const byte Add = 1;
41      public const byte Sub = 2;
42      public const byte Mul = 3;
43      public const byte Div = 4;
44
45      public const byte Sin = 5;
46      public const byte Cos = 6;
47      public const byte Tan = 7;
48
49      public const byte Log = 8;
50      public const byte Exp = 9;
51
52      public const byte IfThenElse = 10;
53
54      public const byte GT = 11;
55      public const byte LT = 12;
56
57      public const byte AND = 13;
58      public const byte OR = 14;
59      public const byte NOT = 15;
60
61
62      public const byte Average = 16;
63
64      public const byte Call = 17;
65
66      public const byte Variable = 18;
67      public const byte LagVariable = 19;
68      public const byte Constant = 20;
69      public const byte Arg = 21;
70      public const byte Differential = 22;
71      public const byte Integral = 23;
72      public const byte MovingAverage = 24;
73    }
74
75    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
76      { typeof(Addition), OpCodes.Add },
77      { typeof(Subtraction), OpCodes.Sub },
78      { typeof(Multiplication), OpCodes.Mul },
79      { typeof(Division), OpCodes.Div },
80      { typeof(Sine), OpCodes.Sin },
81      { typeof(Cosine), OpCodes.Cos },
82      { typeof(Tangent), OpCodes.Tan },
83      { typeof(Logarithm), OpCodes.Log },
84      { typeof(Exponential), OpCodes.Exp },
85      { typeof(IfThenElse), OpCodes.IfThenElse },
86      { typeof(GreaterThan), OpCodes.GT },
87      { typeof(LessThan), OpCodes.LT },
88      { typeof(And), OpCodes.AND },
89      { typeof(Or), OpCodes.OR },
90      { typeof(Not), OpCodes.NOT},
91      { typeof(Average), OpCodes.Average},
92      { typeof(InvokeFunction), OpCodes.Call },
93      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
94      { typeof(LaggedVariable), OpCodes.LagVariable },
95      { typeof(IntegratedVariable), OpCodes.Integral },
96      { typeof(DerivativeVariable), OpCodes.Differential },
97      { typeof(MovingAverage), OpCodes.MovingAverage },
98      { typeof(Constant), OpCodes.Constant },
99      { typeof(Argument), OpCodes.Arg },
100    };
101    private const int ARGUMENT_STACK_SIZE = 1024;
102
103    private Dataset dataset;
104    private int row;
105    private Instruction[] code;
106    private int pc;
107    private double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
108    private int argStackPointer;
109    private Dictionary<int, double[]> estimatedTargetVariableValues;
110    private int currentPredictionHorizon;
111
112    public override bool CanChangeName {
113      get { return false; }
114    }
115    public override bool CanChangeDescription {
116      get { return false; }
117    }
118
119    public SymbolicTimeSeriesExpressionInterpreter()
120      : base() {
121    }
122    #region ITimeSeriesExpressionInterpreter Members
123
124    public IEnumerable<double[]> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon) {
125      this.dataset = dataset;
126      List<int> targetVariableIndexes = new List<int>();
127      estimatedTargetVariableValues = new Dictionary<int, double[]>();
128      foreach (string targetVariable in targetVariables) {
129        int index = dataset.GetVariableIndex(targetVariable);
130        targetVariableIndexes.Add(index);
131        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
132      }
133      var compiler = new SymbolicExpressionTreeCompiler();
134      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
135      code = compiler.Compile(tree, MapSymbolToOpCode);
136
137      foreach (var row in rows) {
138        ResetVariableValues(dataset, row);
139        for (int step = 0; step < predictionHorizon; step++) {
140          this.row = row + step;
141          this.currentPredictionHorizon = step;
142          pc = 0;
143          argStackPointer = 0;
144          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
145          int component = 0;
146          foreach (int targetVariableIndex in targetVariableIndexes) {
147            double estimatedValue = Evaluate();
148            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
149            estimatedValues[component] = estimatedValue;
150            component++;
151          }
152          yield return estimatedValues;
153        }
154      }
155    }
156
157    public IEnumerable<double[]> GetScaledSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon, double[] beta, double[] alpha) {
158      this.dataset = dataset;
159      List<int> targetVariableIndexes = new List<int>();
160      estimatedTargetVariableValues = new Dictionary<int, double[]>();
161      foreach (string targetVariable in targetVariables) {
162        int index = dataset.GetVariableIndex(targetVariable);
163        targetVariableIndexes.Add(index);
164        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
165      }
166      var compiler = new SymbolicExpressionTreeCompiler();
167      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
168      code = compiler.Compile(tree, MapSymbolToOpCode);
169
170      foreach (var row in rows) {
171        ResetVariableValues(dataset, row);
172        for (int step = 0; step < predictionHorizon; step++) {
173          this.row = row + step;
174          this.currentPredictionHorizon = step;
175          pc = 0;
176          argStackPointer = 0;
177          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
178          int component = 0;
179          foreach (int targetVariableIndex in targetVariableIndexes) {
180            double estimatedValue = Evaluate() * beta[component] + alpha[component];
181            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
182            estimatedValues[component] = estimatedValue;
183            component++;
184          }
185          yield return estimatedValues;
186        }
187      }
188    }
189
190    #endregion
191
192    private void ResetVariableValues(Dataset dataset, int start) {
193      foreach (var pair in estimatedTargetVariableValues) {
194        int targetVariableIndex = pair.Key;
195        double[] values = pair.Value;
196        for (int i = 0; i < values.Length; i++) {
197          values[i] = dataset[start + i, targetVariableIndex];
198        }
199      }
200    }
201
202    private Instruction PostProcessInstruction(Instruction instr) {
203      if (instr.opCode == OpCodes.Variable) {
204        var variableTreeNode = instr.dynamicNode as VariableTreeNode;
205        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
206      } else if (instr.opCode == OpCodes.LagVariable) {
207        var variableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
208        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
209      }
210      return instr;
211    }
212
213    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
214      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
215        return symbolToOpcode[treeNode.Symbol.GetType()];
216      else
217        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
218    }
219
220    private double Evaluate() {
221      Instruction currentInstr = code[pc++];
222      switch (currentInstr.opCode) {
223        case OpCodes.Add: {
224            double s = Evaluate();
225            for (int i = 1; i < currentInstr.nArguments; i++) {
226              s += Evaluate();
227            }
228            return s;
229          }
230        case OpCodes.Sub: {
231            double s = Evaluate();
232            for (int i = 1; i < currentInstr.nArguments; i++) {
233              s -= Evaluate();
234            }
235            if (currentInstr.nArguments == 1) s = -s;
236            return s;
237          }
238        case OpCodes.Mul: {
239            double p = Evaluate();
240            for (int i = 1; i < currentInstr.nArguments; i++) {
241              p *= Evaluate();
242            }
243            return p;
244          }
245        case OpCodes.Div: {
246            double p = Evaluate();
247            for (int i = 1; i < currentInstr.nArguments; i++) {
248              p /= Evaluate();
249            }
250            if (currentInstr.nArguments == 1) p = 1.0 / p;
251            return p;
252          }
253        case OpCodes.Average: {
254            double sum = Evaluate();
255            for (int i = 1; i < currentInstr.nArguments; i++) {
256              sum += Evaluate();
257            }
258            return sum / currentInstr.nArguments;
259          }
260        case OpCodes.Cos: {
261            return Math.Cos(Evaluate());
262          }
263        case OpCodes.Sin: {
264            return Math.Sin(Evaluate());
265          }
266        case OpCodes.Tan: {
267            return Math.Tan(Evaluate());
268          }
269        case OpCodes.Exp: {
270            return Math.Exp(Evaluate());
271          }
272        case OpCodes.Log: {
273            return Math.Log(Evaluate());
274          }
275        case OpCodes.IfThenElse: {
276            double condition = Evaluate();
277            double result;
278            if (condition > 0.0) {
279              result = Evaluate(); SkipBakedCode();
280            } else {
281              SkipBakedCode(); result = Evaluate();
282            }
283            return result;
284          }
285        case OpCodes.AND: {
286            double result = Evaluate();
287            for (int i = 1; i < currentInstr.nArguments; i++) {
288              if (result <= 0.0) SkipBakedCode();
289              else {
290                result = Evaluate();
291              }
292            }
293            return result <= 0.0 ? -1.0 : 1.0;
294          }
295        case OpCodes.OR: {
296            double result = Evaluate();
297            for (int i = 1; i < currentInstr.nArguments; i++) {
298              if (result > 0.0) SkipBakedCode();
299              else {
300                result = Evaluate();
301              }
302            }
303            return result > 0.0 ? 1.0 : -1.0;
304          }
305        case OpCodes.NOT: {
306            return -Evaluate();
307          }
308        case OpCodes.GT: {
309            double x = Evaluate();
310            double y = Evaluate();
311            if (x > y) return 1.0;
312            else return -1.0;
313          }
314        case OpCodes.LT: {
315            double x = Evaluate();
316            double y = Evaluate();
317            if (x < y) return 1.0;
318            else return -1.0;
319          }
320        case OpCodes.Call: {
321            // evaluate sub-trees
322            // push on argStack in reverse order
323            for (int i = 0; i < currentInstr.nArguments; i++) {
324              argumentStack[argStackPointer + currentInstr.nArguments - i] = Evaluate();
325            }
326            argStackPointer += currentInstr.nArguments;
327
328            // save the pc
329            int nextPc = pc;
330            // set pc to start of function 
331            pc = currentInstr.iArg0;
332            // evaluate the function
333            double v = Evaluate();
334
335            // decrease the argument stack pointer by the number of arguments pushed
336            // to set the argStackPointer back to the original location
337            argStackPointer -= currentInstr.nArguments;
338
339            // restore the pc => evaluation will continue at point after my subtrees 
340            pc = nextPc;
341            return v;
342          }
343        case OpCodes.Arg: {
344            return argumentStack[argStackPointer - currentInstr.iArg0];
345          }
346        case OpCodes.Variable: {
347            var variableTreeNode = currentInstr.dynamicNode as VariableTreeNode;
348            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
349          }
350        case OpCodes.LagVariable: {
351            var lagVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
352            int actualRow = row + lagVariableTreeNode.Lag;
353            if (actualRow < 0 || actualRow >= dataset.Rows)
354              return double.NaN;
355            return GetVariableValue(currentInstr.iArg0, lagVariableTreeNode.Lag) * lagVariableTreeNode.Weight;
356          }
357        case OpCodes.MovingAverage: {
358            var movingAvgTreeNode = currentInstr.dynamicNode as MovingAverageTreeNode;
359            if (row + movingAvgTreeNode.MinTimeOffset < 0 || row + movingAvgTreeNode.MaxTimeOffset >= dataset.Rows)
360              return double.NaN;
361            double sum = 0.0;
362            for (int relativeRow = movingAvgTreeNode.MinTimeOffset; relativeRow < movingAvgTreeNode.MaxTimeOffset; relativeRow++) {
363              sum += GetVariableValue(currentInstr.iArg0, relativeRow) * movingAvgTreeNode.Weight;
364            }
365            return sum / (movingAvgTreeNode.MaxTimeOffset - movingAvgTreeNode.MinTimeOffset);
366          }
367        case OpCodes.Differential: {
368            var diffTreeNode = currentInstr.dynamicNode as DerivativeVariableTreeNode;
369            if (row + diffTreeNode.Lag - 2 < 0 || row + diffTreeNode.Lag >= dataset.Rows)
370              return double.NaN;
371            double y_0 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag) * diffTreeNode.Weight;
372            double y_1 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 1) * diffTreeNode.Weight;
373            double y_2 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 2) * diffTreeNode.Weight;
374            return (3 * y_0 - 4 * y_1 + 3 * y_2) / 2;
375          }
376        case OpCodes.Integral: {
377            var integralVarTreeNode = currentInstr.dynamicNode as IntegratedVariableTreeNode;
378            if (row + integralVarTreeNode.MinTimeOffset < 0 || row + integralVarTreeNode.MaxTimeOffset >= dataset.Rows)
379              return double.NaN;
380            double sum = 0;
381            for (int relativeRow = integralVarTreeNode.MinTimeOffset; relativeRow < integralVarTreeNode.MaxTimeOffset; relativeRow++) {
382              sum += GetVariableValue(currentInstr.iArg0, relativeRow) * integralVarTreeNode.Weight;
383            }
384            return sum;
385          }
386        case OpCodes.Constant: {
387            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
388            return constTreeNode.Value;
389          }
390        default: throw new NotSupportedException();
391      }
392    }
393
394    private double GetVariableValue(int variableIndex, int timeoffset) {
395      if (estimatedTargetVariableValues.ContainsKey(variableIndex) &&
396                      currentPredictionHorizon + timeoffset >= 0) {
397        return estimatedTargetVariableValues[variableIndex][currentPredictionHorizon + timeoffset];
398      } else {
399        return dataset[row + timeoffset, variableIndex];
400      }
401    }
402
403    // skips a whole branch
404    protected void SkipBakedCode() {
405      int i = 1;
406      while (i > 0) {
407        i += code[pc++].nArguments;
408        i--;
409      }
410    }
411  }
412}
413
Note: See TracBrowser for help on using the repository browser.