Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis/3.3/Symbolic/SymbolicTimeSeriesExpressionInterpreter.cs @ 4475

Last change on this file since 4475 was 4475, checked in by gkronber, 13 years ago

Fixed bugs in time series prognosis classes #1142.

File size: 16.3 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using System.Collections.Generic;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Interfaces;
33using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Symbols;
34
35namespace HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis {
36  [StorableClass]
37  [Item("SymbolicTimeSeriesExpressionInterpreter", "Interpreter for symbolic expression trees representing time series forecast models.")]
38  public class SymbolicTimeSeriesExpressionInterpreter : NamedItem, ISymbolicTimeSeriesExpressionInterpreter {
39    private class OpCodes {
40      public const byte Add = 1;
41      public const byte Sub = 2;
42      public const byte Mul = 3;
43      public const byte Div = 4;
44
45      public const byte Sin = 5;
46      public const byte Cos = 6;
47      public const byte Tan = 7;
48
49      public const byte Log = 8;
50      public const byte Exp = 9;
51
52      public const byte IfThenElse = 10;
53
54      public const byte GT = 11;
55      public const byte LT = 12;
56
57      public const byte AND = 13;
58      public const byte OR = 14;
59      public const byte NOT = 15;
60
61
62      public const byte Average = 16;
63
64      public const byte Call = 17;
65
66      public const byte Variable = 18;
67      public const byte LagVariable = 19;
68      public const byte Constant = 20;
69      public const byte Arg = 21;
70      public const byte Differential = 22;
71      public const byte Integral = 23;
72      public const byte MovingAverage = 24;
73    }
74
75    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
76      { typeof(Addition), OpCodes.Add },
77      { typeof(Subtraction), OpCodes.Sub },
78      { typeof(Multiplication), OpCodes.Mul },
79      { typeof(Division), OpCodes.Div },
80      { typeof(Sine), OpCodes.Sin },
81      { typeof(Cosine), OpCodes.Cos },
82      { typeof(Tangent), OpCodes.Tan },
83      { typeof(Logarithm), OpCodes.Log },
84      { typeof(Exponential), OpCodes.Exp },
85      { typeof(IfThenElse), OpCodes.IfThenElse },
86      { typeof(GreaterThan), OpCodes.GT },
87      { typeof(LessThan), OpCodes.LT },
88      { typeof(And), OpCodes.AND },
89      { typeof(Or), OpCodes.OR },
90      { typeof(Not), OpCodes.NOT},
91      { typeof(Average), OpCodes.Average},
92      { typeof(InvokeFunction), OpCodes.Call },
93      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
94      { typeof(LaggedVariable), OpCodes.LagVariable },
95      { typeof(IntegratedVariable), OpCodes.Integral },
96      { typeof(DerivativeVariable), OpCodes.Differential },
97      { typeof(MovingAverage), OpCodes.MovingAverage },
98      { typeof(Constant), OpCodes.Constant },
99      { typeof(Argument), OpCodes.Arg },
100    };
101    private const int ARGUMENT_STACK_SIZE = 1024;
102
103    private Dataset dataset;
104    private int row;
105    private Instruction[] code;
106    private int pc;
107    private double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
108    private int argStackPointer;
109    private Dictionary<int, double[]> estimatedTargetVariableValues;
110    private int currentPredictionHorizon;
111
112    public override bool CanChangeName {
113      get { return false; }
114    }
115    public override bool CanChangeDescription {
116      get { return false; }
117    }
118
119    public SymbolicTimeSeriesExpressionInterpreter()
120      : base() {
121    }
122    #region ITimeSeriesExpressionInterpreter Members
123
124    public IEnumerable<double[]> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon) {
125      this.dataset = dataset;
126      List<int> targetVariableIndexes = new List<int>();
127      estimatedTargetVariableValues = new Dictionary<int, double[]>();
128      foreach (string targetVariable in targetVariables) {
129        int index = dataset.GetVariableIndex(targetVariable);
130        targetVariableIndexes.Add(index);
131        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
132      }
133      var compiler = new SymbolicExpressionTreeCompiler();
134      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
135      code = compiler.Compile(tree, MapSymbolToOpCode);
136
137      foreach (var row in rows) {
138        ResetVariableValues(dataset, row);
139        for (int step = 0; step < predictionHorizon; step++) {
140          this.row = row + step;
141          this.currentPredictionHorizon = step;
142          pc = 0;
143          argStackPointer = 0;
144          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
145          int component = 0;
146          foreach (int targetVariableIndex in targetVariableIndexes) {
147            double estimatedValue = Evaluate();
148            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
149            estimatedValues[component] = estimatedValue;
150            component++;
151          }
152          yield return estimatedValues;
153        }
154      }
155    }
156
157    public IEnumerable<double[]> GetScaledSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon, double[] beta, double[] alpha) {
158      this.dataset = dataset;
159      List<int> targetVariableIndexes = new List<int>();
160      estimatedTargetVariableValues = new Dictionary<int, double[]>();
161      foreach (string targetVariable in targetVariables) {
162        int index = dataset.GetVariableIndex(targetVariable);
163        targetVariableIndexes.Add(index);
164        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
165      }
166      var compiler = new SymbolicExpressionTreeCompiler();
167      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
168      code = compiler.Compile(tree, MapSymbolToOpCode);
169
170      foreach (var row in rows) {
171        ResetVariableValues(dataset, row);
172        for (int step = 0; step < predictionHorizon; step++) {
173          this.row = row + step;
174          this.currentPredictionHorizon = step;
175          pc = 0;
176          argStackPointer = 0;
177          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
178          int component = 0;
179          foreach (int targetVariableIndex in targetVariableIndexes) {
180            double estimatedValue = Evaluate() * beta[component] + alpha[component];
181            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
182            estimatedValues[component] = estimatedValue;
183            component++;
184          }
185          yield return estimatedValues;
186        }
187      }
188    }
189
190    #endregion
191
192    private void ResetVariableValues(Dataset dataset, int start) {
193      foreach (var pair in estimatedTargetVariableValues) {
194        int targetVariableIndex = pair.Key;
195        double[] values = pair.Value;
196        for (int i = 0; i < values.Length; i++) {
197          values[i] = dataset[start + i, targetVariableIndex];
198        }
199      }
200    }
201
202    private Instruction PostProcessInstruction(Instruction instr) {
203      if (instr.opCode == OpCodes.Variable || instr.opCode == OpCodes.LagVariable ||
204        instr.opCode == OpCodes.Integral || instr.opCode == OpCodes.MovingAverage || instr.opCode == OpCodes.Differential) {
205        var variableTreeNode = instr.dynamicNode as VariableTreeNode;
206        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
207      }
208      return instr;
209    }
210
211    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
212      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
213        return symbolToOpcode[treeNode.Symbol.GetType()];
214      else
215        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
216    }
217
218    private double Evaluate() {
219      Instruction currentInstr = code[pc++];
220      switch (currentInstr.opCode) {
221        case OpCodes.Add: {
222            double s = Evaluate();
223            for (int i = 1; i < currentInstr.nArguments; i++) {
224              s += Evaluate();
225            }
226            return s;
227          }
228        case OpCodes.Sub: {
229            double s = Evaluate();
230            for (int i = 1; i < currentInstr.nArguments; i++) {
231              s -= Evaluate();
232            }
233            if (currentInstr.nArguments == 1) s = -s;
234            return s;
235          }
236        case OpCodes.Mul: {
237            double p = Evaluate();
238            for (int i = 1; i < currentInstr.nArguments; i++) {
239              p *= Evaluate();
240            }
241            return p;
242          }
243        case OpCodes.Div: {
244            double p = Evaluate();
245            for (int i = 1; i < currentInstr.nArguments; i++) {
246              p /= Evaluate();
247            }
248            if (currentInstr.nArguments == 1) p = 1.0 / p;
249            return p;
250          }
251        case OpCodes.Average: {
252            double sum = Evaluate();
253            for (int i = 1; i < currentInstr.nArguments; i++) {
254              sum += Evaluate();
255            }
256            return sum / currentInstr.nArguments;
257          }
258        case OpCodes.Cos: {
259            return Math.Cos(Evaluate());
260          }
261        case OpCodes.Sin: {
262            return Math.Sin(Evaluate());
263          }
264        case OpCodes.Tan: {
265            return Math.Tan(Evaluate());
266          }
267        case OpCodes.Exp: {
268            return Math.Exp(Evaluate());
269          }
270        case OpCodes.Log: {
271            return Math.Log(Evaluate());
272          }
273        case OpCodes.IfThenElse: {
274            double condition = Evaluate();
275            double result;
276            if (condition > 0.0) {
277              result = Evaluate(); SkipBakedCode();
278            } else {
279              SkipBakedCode(); result = Evaluate();
280            }
281            return result;
282          }
283        case OpCodes.AND: {
284            double result = Evaluate();
285            for (int i = 1; i < currentInstr.nArguments; i++) {
286              if (result <= 0.0) SkipBakedCode();
287              else {
288                result = Evaluate();
289              }
290            }
291            return result <= 0.0 ? -1.0 : 1.0;
292          }
293        case OpCodes.OR: {
294            double result = Evaluate();
295            for (int i = 1; i < currentInstr.nArguments; i++) {
296              if (result > 0.0) SkipBakedCode();
297              else {
298                result = Evaluate();
299              }
300            }
301            return result > 0.0 ? 1.0 : -1.0;
302          }
303        case OpCodes.NOT: {
304            return -Evaluate();
305          }
306        case OpCodes.GT: {
307            double x = Evaluate();
308            double y = Evaluate();
309            if (x > y) return 1.0;
310            else return -1.0;
311          }
312        case OpCodes.LT: {
313            double x = Evaluate();
314            double y = Evaluate();
315            if (x < y) return 1.0;
316            else return -1.0;
317          }
318        case OpCodes.Call: {
319            // evaluate sub-trees
320            // push on argStack in reverse order
321            for (int i = 0; i < currentInstr.nArguments; i++) {
322              argumentStack[argStackPointer + currentInstr.nArguments - i] = Evaluate();
323            }
324            argStackPointer += currentInstr.nArguments;
325
326            // save the pc
327            int nextPc = pc;
328            // set pc to start of function 
329            pc = currentInstr.iArg0;
330            // evaluate the function
331            double v = Evaluate();
332
333            // decrease the argument stack pointer by the number of arguments pushed
334            // to set the argStackPointer back to the original location
335            argStackPointer -= currentInstr.nArguments;
336
337            // restore the pc => evaluation will continue at point after my subtrees 
338            pc = nextPc;
339            return v;
340          }
341        case OpCodes.Arg: {
342            return argumentStack[argStackPointer - currentInstr.iArg0];
343          }
344        case OpCodes.Variable: {
345            var variableTreeNode = currentInstr.dynamicNode as VariableTreeNode;
346            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
347          }
348        case OpCodes.LagVariable: {
349            var lagVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
350            int actualRow = row + lagVariableTreeNode.Lag;
351            if (actualRow < 0 || actualRow >= dataset.Rows)
352              return double.NaN;
353            return GetVariableValue(currentInstr.iArg0, lagVariableTreeNode.Lag) * lagVariableTreeNode.Weight;
354          }
355        case OpCodes.MovingAverage: {
356            var movingAvgTreeNode = currentInstr.dynamicNode as MovingAverageTreeNode;
357            if (row + movingAvgTreeNode.MinTimeOffset < 0 || row + movingAvgTreeNode.MaxTimeOffset >= dataset.Rows)
358              return double.NaN;
359            double sum = 0.0;
360            for (int relativeRow = movingAvgTreeNode.MinTimeOffset; relativeRow < movingAvgTreeNode.MaxTimeOffset; relativeRow++) {
361              sum += GetVariableValue(currentInstr.iArg0, relativeRow);
362            }
363            return movingAvgTreeNode.Weight * sum / (movingAvgTreeNode.MaxTimeOffset - movingAvgTreeNode.MinTimeOffset);
364          }
365        case OpCodes.Differential: {
366            var diffTreeNode = currentInstr.dynamicNode as DerivativeVariableTreeNode;
367            if (row + diffTreeNode.Lag - 2 < 0 || row + diffTreeNode.Lag >= dataset.Rows)
368              return double.NaN;
369            double y_0 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag);
370            double y_1 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 1);
371            double y_2 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 2);
372            return diffTreeNode.Weight * (3 * y_0 - 4 * y_1 + 3 * y_2) / 2;
373          }
374        case OpCodes.Integral: {
375            var integralVarTreeNode = currentInstr.dynamicNode as IntegratedVariableTreeNode;
376            if (row + integralVarTreeNode.MinTimeOffset < 0 || row + integralVarTreeNode.MaxTimeOffset >= dataset.Rows)
377              return double.NaN;
378            double sum = 0;
379            for (int relativeRow = integralVarTreeNode.MinTimeOffset; relativeRow < integralVarTreeNode.MaxTimeOffset; relativeRow++) {
380              sum += GetVariableValue(currentInstr.iArg0, relativeRow);
381            }
382            return integralVarTreeNode.Weight * sum;
383          }
384        case OpCodes.Constant: {
385            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
386            return constTreeNode.Value;
387          }
388        default: throw new NotSupportedException();
389      }
390    }
391
392    private double GetVariableValue(int variableIndex, int timeoffset) {
393      if (estimatedTargetVariableValues.ContainsKey(variableIndex) &&
394                      currentPredictionHorizon + timeoffset >= 0) {
395        return estimatedTargetVariableValues[variableIndex][currentPredictionHorizon + timeoffset];
396      } else {
397        return dataset[row + timeoffset, variableIndex];
398      }
399    }
400
401    // skips a whole branch
402    protected void SkipBakedCode() {
403      int i = 1;
404      while (i > 0) {
405        i += code[pc++].nArguments;
406        i--;
407      }
408    }
409  }
410}
411
Note: See TracBrowser for help on using the repository browser.