Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis/3.3/Symbolic/SymbolicTimeSeriesExpressionInterpreter.cs @ 6231

Last change on this file since 6231 was 5275, checked in by gkronber, 14 years ago

Merged changes from trunk to data analysis exploration branch and added fractional distance metric evaluator. #1142

File size: 16.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using System.Collections.Generic;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
29using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
30using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Compiler;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Interfaces;
33using HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis.Symbolic.Symbols;
34
35namespace HeuristicLab.Problems.DataAnalysis.MultiVariate.TimeSeriesPrognosis {
36  [StorableClass]
37  [Item("SymbolicTimeSeriesExpressionInterpreter", "Interpreter for symbolic expression trees representing time series forecast models.")]
38  public class SymbolicTimeSeriesExpressionInterpreter : NamedItem, ISymbolicTimeSeriesExpressionInterpreter {
39    private class OpCodes {
40      public const byte Add = 1;
41      public const byte Sub = 2;
42      public const byte Mul = 3;
43      public const byte Div = 4;
44
45      public const byte Sin = 5;
46      public const byte Cos = 6;
47      public const byte Tan = 7;
48
49      public const byte Log = 8;
50      public const byte Exp = 9;
51
52      public const byte IfThenElse = 10;
53
54      public const byte GT = 11;
55      public const byte LT = 12;
56
57      public const byte AND = 13;
58      public const byte OR = 14;
59      public const byte NOT = 15;
60
61
62      public const byte Average = 16;
63
64      public const byte Call = 17;
65
66      public const byte Variable = 18;
67      public const byte LagVariable = 19;
68      public const byte Constant = 20;
69      public const byte Arg = 21;
70      public const byte Differential = 22;
71      public const byte Integral = 23;
72      public const byte MovingAverage = 24;
73    }
74
75    private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
76      { typeof(Addition), OpCodes.Add },
77      { typeof(Subtraction), OpCodes.Sub },
78      { typeof(Multiplication), OpCodes.Mul },
79      { typeof(Division), OpCodes.Div },
80      { typeof(Sine), OpCodes.Sin },
81      { typeof(Cosine), OpCodes.Cos },
82      { typeof(Tangent), OpCodes.Tan },
83      { typeof(Logarithm), OpCodes.Log },
84      { typeof(Exponential), OpCodes.Exp },
85      { typeof(IfThenElse), OpCodes.IfThenElse },
86      { typeof(GreaterThan), OpCodes.GT },
87      { typeof(LessThan), OpCodes.LT },
88      { typeof(And), OpCodes.AND },
89      { typeof(Or), OpCodes.OR },
90      { typeof(Not), OpCodes.NOT},
91      { typeof(Average), OpCodes.Average},
92      { typeof(InvokeFunction), OpCodes.Call },
93      { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols.Variable), OpCodes.Variable },
94      { typeof(LaggedVariable), OpCodes.LagVariable },
95      { typeof(IntegratedVariable), OpCodes.Integral },
96      { typeof(DerivativeVariable), OpCodes.Differential },
97      { typeof(MovingAverage), OpCodes.MovingAverage },
98      { typeof(Constant), OpCodes.Constant },
99      { typeof(Argument), OpCodes.Arg },
100    };
101    private const int ARGUMENT_STACK_SIZE = 1024;
102
103    private Dataset dataset;
104    private int row;
105    private Instruction[] code;
106    private int pc;
107    private double[] argumentStack = new double[ARGUMENT_STACK_SIZE];
108    private int argStackPointer;
109    private Dictionary<int, double[]> estimatedTargetVariableValues;
110    private int currentPredictionHorizon;
111
112    public override bool CanChangeName {
113      get { return false; }
114    }
115    public override bool CanChangeDescription {
116      get { return false; }
117    }
118    [StorableConstructor]
119    protected SymbolicTimeSeriesExpressionInterpreter(bool deserializing) : base(deserializing) { }
120    protected SymbolicTimeSeriesExpressionInterpreter(SymbolicTimeSeriesExpressionInterpreter original, Cloner cloner)
121      : base(original, cloner) {
122    }
123    public SymbolicTimeSeriesExpressionInterpreter()
124      : base() {
125    }
126    public override IDeepCloneable Clone(Cloner cloner) {
127      return new SymbolicTimeSeriesExpressionInterpreter(this, cloner);
128    }
129    #region ITimeSeriesExpressionInterpreter Members
130
131    public IEnumerable<double[]> GetSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon) {
132      this.dataset = dataset;
133      List<int> targetVariableIndexes = new List<int>();
134      estimatedTargetVariableValues = new Dictionary<int, double[]>();
135      foreach (string targetVariable in targetVariables) {
136        int index = dataset.GetVariableIndex(targetVariable);
137        targetVariableIndexes.Add(index);
138        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
139      }
140      var compiler = new SymbolicExpressionTreeCompiler();
141      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
142      code = compiler.Compile(tree, MapSymbolToOpCode);
143
144      foreach (var row in rows) {
145        // ResetVariableValues(dataset, row);
146        for (int step = 0; step < predictionHorizon; step++) {
147          this.row = row + step;
148          this.currentPredictionHorizon = step;
149          pc = 0;
150          argStackPointer = 0;
151          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
152          int component = 0;
153          foreach (int targetVariableIndex in targetVariableIndexes) {
154            double estimatedValue = Evaluate();
155            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
156            estimatedValues[component] = estimatedValue;
157            component++;
158          }
159          yield return estimatedValues;
160        }
161      }
162    }
163
164    public IEnumerable<double[]> GetScaledSymbolicExpressionTreeValues(SymbolicExpressionTree tree, Dataset dataset, IEnumerable<string> targetVariables, IEnumerable<int> rows, int predictionHorizon, double[] beta, double[] alpha) {
165      this.dataset = dataset;
166      List<int> targetVariableIndexes = new List<int>();
167      estimatedTargetVariableValues = new Dictionary<int, double[]>();
168      foreach (string targetVariable in targetVariables) {
169        int index = dataset.GetVariableIndex(targetVariable);
170        targetVariableIndexes.Add(index);
171        estimatedTargetVariableValues.Add(index, new double[predictionHorizon]);
172      }
173      var compiler = new SymbolicExpressionTreeCompiler();
174      compiler.AddInstructionPostProcessingHook(PostProcessInstruction);
175      code = compiler.Compile(tree, MapSymbolToOpCode);
176
177      foreach (var row in rows) {
178        // ResetVariableValues(dataset, row);
179        for (int step = 0; step < predictionHorizon; step++) {
180          this.row = row + step;
181          this.currentPredictionHorizon = step;
182          pc = 0;
183          argStackPointer = 0;
184          double[] estimatedValues = new double[tree.Root.SubTrees[0].SubTrees.Count];
185          int component = 0;
186          foreach (int targetVariableIndex in targetVariableIndexes) {
187            double estimatedValue = Evaluate() * beta[component] + alpha[component];
188            estimatedTargetVariableValues[targetVariableIndex][step] = estimatedValue;
189            estimatedValues[component] = estimatedValue;
190            component++;
191          }
192          yield return estimatedValues;
193        }
194      }
195    }
196
197    #endregion
198
199    //private void ResetVariableValues(Dataset dataset, int start) {
200    //  foreach (var pair in estimatedTargetVariableValues) {
201    //    int targetVariableIndex = pair.Key;
202    //    double[] values = pair.Value;
203    //    for (int i = 0; i < values.Length; i++) {
204    //      values[i] = dataset[start + i, targetVariableIndex];
205    //    }
206    //  }
207    //}
208
209    private Instruction PostProcessInstruction(Instruction instr) {
210      if (instr.opCode == OpCodes.Variable || instr.opCode == OpCodes.LagVariable ||
211        instr.opCode == OpCodes.Integral || instr.opCode == OpCodes.MovingAverage || instr.opCode == OpCodes.Differential) {
212        var variableTreeNode = instr.dynamicNode as VariableTreeNode;
213        instr.iArg0 = (ushort)dataset.GetVariableIndex(variableTreeNode.VariableName);
214      }
215      return instr;
216    }
217
218    private byte MapSymbolToOpCode(SymbolicExpressionTreeNode treeNode) {
219      if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
220        return symbolToOpcode[treeNode.Symbol.GetType()];
221      else
222        throw new NotSupportedException("Symbol: " + treeNode.Symbol);
223    }
224
225    private double Evaluate() {
226      Instruction currentInstr = code[pc++];
227      switch (currentInstr.opCode) {
228        case OpCodes.Add: {
229            double s = Evaluate();
230            for (int i = 1; i < currentInstr.nArguments; i++) {
231              s += Evaluate();
232            }
233            return s;
234          }
235        case OpCodes.Sub: {
236            double s = Evaluate();
237            for (int i = 1; i < currentInstr.nArguments; i++) {
238              s -= Evaluate();
239            }
240            if (currentInstr.nArguments == 1) s = -s;
241            return s;
242          }
243        case OpCodes.Mul: {
244            double p = Evaluate();
245            for (int i = 1; i < currentInstr.nArguments; i++) {
246              p *= Evaluate();
247            }
248            return p;
249          }
250        case OpCodes.Div: {
251            double p = Evaluate();
252            for (int i = 1; i < currentInstr.nArguments; i++) {
253              p /= Evaluate();
254            }
255            if (currentInstr.nArguments == 1) p = 1.0 / p;
256            return p;
257          }
258        case OpCodes.Average: {
259            double sum = Evaluate();
260            for (int i = 1; i < currentInstr.nArguments; i++) {
261              sum += Evaluate();
262            }
263            return sum / currentInstr.nArguments;
264          }
265        case OpCodes.Cos: {
266            return Math.Cos(Evaluate());
267          }
268        case OpCodes.Sin: {
269            return Math.Sin(Evaluate());
270          }
271        case OpCodes.Tan: {
272            return Math.Tan(Evaluate());
273          }
274        case OpCodes.Exp: {
275            return Math.Exp(Evaluate());
276          }
277        case OpCodes.Log: {
278            return Math.Log(Evaluate());
279          }
280        case OpCodes.IfThenElse: {
281            double condition = Evaluate();
282            double result;
283            if (condition > 0.0) {
284              result = Evaluate(); SkipBakedCode();
285            } else {
286              SkipBakedCode(); result = Evaluate();
287            }
288            return result;
289          }
290        case OpCodes.AND: {
291            double result = Evaluate();
292            for (int i = 1; i < currentInstr.nArguments; i++) {
293              if (result <= 0.0) SkipBakedCode();
294              else {
295                result = Evaluate();
296              }
297            }
298            return result <= 0.0 ? -1.0 : 1.0;
299          }
300        case OpCodes.OR: {
301            double result = Evaluate();
302            for (int i = 1; i < currentInstr.nArguments; i++) {
303              if (result > 0.0) SkipBakedCode();
304              else {
305                result = Evaluate();
306              }
307            }
308            return result > 0.0 ? 1.0 : -1.0;
309          }
310        case OpCodes.NOT: {
311            return -Evaluate();
312          }
313        case OpCodes.GT: {
314            double x = Evaluate();
315            double y = Evaluate();
316            if (x > y) return 1.0;
317            else return -1.0;
318          }
319        case OpCodes.LT: {
320            double x = Evaluate();
321            double y = Evaluate();
322            if (x < y) return 1.0;
323            else return -1.0;
324          }
325        case OpCodes.Call: {
326            // evaluate sub-trees
327            // push on argStack in reverse order
328            for (int i = 0; i < currentInstr.nArguments; i++) {
329              argumentStack[argStackPointer + currentInstr.nArguments - i] = Evaluate();
330            }
331            argStackPointer += currentInstr.nArguments;
332
333            // save the pc
334            int nextPc = pc;
335            // set pc to start of function 
336            pc = currentInstr.iArg0;
337            // evaluate the function
338            double v = Evaluate();
339
340            // decrease the argument stack pointer by the number of arguments pushed
341            // to set the argStackPointer back to the original location
342            argStackPointer -= currentInstr.nArguments;
343
344            // restore the pc => evaluation will continue at point after my subtrees 
345            pc = nextPc;
346            return v;
347          }
348        case OpCodes.Arg: {
349            return argumentStack[argStackPointer - currentInstr.iArg0];
350          }
351        case OpCodes.Variable: {
352            var variableTreeNode = currentInstr.dynamicNode as VariableTreeNode;
353            return dataset[row, currentInstr.iArg0] * variableTreeNode.Weight;
354          }
355        case OpCodes.LagVariable: {
356            var lagVariableTreeNode = currentInstr.dynamicNode as LaggedVariableTreeNode;
357            int actualRow = row + lagVariableTreeNode.Lag;
358            if (actualRow < 0 || actualRow >= dataset.Rows + currentPredictionHorizon)
359              return double.NaN;
360            return GetVariableValue(currentInstr.iArg0, lagVariableTreeNode.Lag) * lagVariableTreeNode.Weight;
361          }
362        case OpCodes.MovingAverage: {
363            var movingAvgTreeNode = currentInstr.dynamicNode as MovingAverageTreeNode;
364            if (row + movingAvgTreeNode.MinTimeOffset < 0 || row + movingAvgTreeNode.MaxTimeOffset >= dataset.Rows + currentPredictionHorizon)
365              return double.NaN;
366            double sum = 0.0;
367            for (int relativeRow = movingAvgTreeNode.MinTimeOffset; relativeRow < movingAvgTreeNode.MaxTimeOffset; relativeRow++) {
368              sum += GetVariableValue(currentInstr.iArg0, relativeRow);
369            }
370            return movingAvgTreeNode.Weight * sum / (movingAvgTreeNode.MaxTimeOffset - movingAvgTreeNode.MinTimeOffset);
371          }
372        case OpCodes.Differential: {
373            var diffTreeNode = currentInstr.dynamicNode as DerivativeVariableTreeNode;
374            if (row + diffTreeNode.Lag - 2 < 0 || row + diffTreeNode.Lag >= dataset.Rows + currentPredictionHorizon)
375              return double.NaN;
376            double y_0 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag);
377            double y_1 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 1);
378            double y_2 = GetVariableValue(currentInstr.iArg0, diffTreeNode.Lag - 2);
379            return diffTreeNode.Weight * (y_0 - 4 * y_1 + 3 * y_2) / 2;
380          }
381        case OpCodes.Integral: {
382            var integralVarTreeNode = currentInstr.dynamicNode as IntegratedVariableTreeNode;
383            if (row + integralVarTreeNode.MinTimeOffset < 0 || row + integralVarTreeNode.MaxTimeOffset >= dataset.Rows + currentPredictionHorizon)
384              return double.NaN;
385            double sum = 0;
386            for (int relativeRow = integralVarTreeNode.MinTimeOffset; relativeRow < integralVarTreeNode.MaxTimeOffset; relativeRow++) {
387              sum += GetVariableValue(currentInstr.iArg0, relativeRow);
388            }
389            return integralVarTreeNode.Weight * sum;
390          }
391        case OpCodes.Constant: {
392            var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
393            return constTreeNode.Value;
394          }
395        default: throw new NotSupportedException();
396      }
397    }
398
399    private double GetVariableValue(int variableIndex, int timeoffset) {
400      if (currentPredictionHorizon + timeoffset >= 0) {
401        double[] values;
402        estimatedTargetVariableValues.TryGetValue(variableIndex, out values);
403        if (values != null) {
404          return values[currentPredictionHorizon + timeoffset];
405        }
406      }
407      if (row + timeoffset < 0 || row + timeoffset >= dataset.Rows) {
408        return double.NaN;
409      } else {
410        return dataset[row + timeoffset, variableIndex];
411      }
412    }
413
414    // skips a whole branch
415    protected void SkipBakedCode() {
416      int i = 1;
417      while (i > 0) {
418        i += code[pc++].nArguments;
419        i--;
420      }
421    }
422  }
423}
424
Note: See TracBrowser for help on using the repository browser.