Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/12/12 10:31:56 (12 years ago)
Author:
mkommend
Message:

#1081: Improved performance of time series prognosis.

Location:
branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4
Files:
7 edited
2 copied

Legend:

Unmodified
Added
Removed
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis-3.4.csproj

    r7886 r7989  
    123123  </ItemGroup>
    124124  <ItemGroup>
     125    <Compile Include="Interfaces\ISymbolicTimeSeriesPrognogisExpressionTreeInterpreter.cs" />
    125126    <Compile Include="Interfaces\ISymbolicTimeSeriesPrognosisInterpreterOperator.cs" />
    126127    <Compile Include="Interfaces\ISymbolicTimeSeriesPrognosisEvaluator.cs" />
     
    137138    <Compile Include="SingleObjective\SymbolicTimeSeriesPrognosisSingleObjectiveTrainingBestSolutionAnalyzer.cs" />
    138139    <Compile Include="SingleObjective\SymbolicTimeSeriesPrognosisSingleObjectiveValidationBestSolutionAnalyzer.cs" />
     140    <Compile Include="SymbolicTimeSeriesPrognosisExpressionTreeInterpreter.cs" />
    139141    <Compile Include="SymbolicTimeSeriesPrognosisModel.cs" />
    140142    <Compile Include="SymbolicTimeSeriesPrognosisSolution.cs" />
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/Interfaces/ISymbolicTimeSeriesPrognogisExpressionTreeInterpreter.cs

    r7929 r7989  
    2121
    2222using System.Collections.Generic;
    23 using HeuristicLab.Core;
    2423using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2524
    2625namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    27   public interface ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter : INamedItem {
    28     IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset,
    29                                                         string[] targetVariables, IEnumerable<int> rows);
    30     IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, string[] targetVariables, IEnumerable<int> rows, int horizon);
     26  public interface ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter : ISymbolicDataAnalysisExpressionTreeInterpreter {
     27    string TargetVariable { get; set; }
     28    IEnumerable<IEnumerable<double>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon);
    3129  }
    3230}
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SingleObjective/SymbolicTimeSeriesPrognosisSingleObjectiveEvaluator.cs

    r7120 r7989  
    2121
    2222
     23using System;
     24using System.Collections.Generic;
    2325using HeuristicLab.Common;
    2426using HeuristicLab.Core;
     
    2729using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2830namespace HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis {
     31  [StorableClass]
    2932  public abstract class SymbolicTimeSeriesPrognosisSingleObjectiveEvaluator : SymbolicDataAnalysisSingleObjectiveEvaluator<ITimeSeriesPrognosisProblemData>, ISymbolicTimeSeriesPrognosisSingleObjectiveEvaluator {
    3033    private const string HorizonParameterName = "Horizon";
     34    private const string ApplyLinearScalingParameterName = "ApplyLinearScaling";
    3135
     36    public IFixedValueParameter<BoolValue> ApplyLinearScalingParameter {
     37      get { return (IFixedValueParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
     38    }
     39    public bool ApplyLinearScaling {
     40      get { return ApplyLinearScalingParameter.Value.Value; }
     41      set { ApplyLinearScalingParameter.Value.Value = value; }
     42    }
    3243    public IValueLookupParameter<IntValue> HorizonParameter {
    3344      get { return (IValueLookupParameter<IntValue>)Parameters[HorizonParameterName]; }
     
    4253    protected SymbolicTimeSeriesPrognosisSingleObjectiveEvaluator()
    4354      : base() {
     55      Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Flag that indicates if the individual should be linearly scaled before evaluating.", new BoolValue(true)));
    4456      Parameters.Add(new ValueLookupParameter<IntValue>(HorizonParameterName, "The time interval for which the prognosis should be calculated.", new IntValue(1)));
     57      ApplyLinearScalingParameter.Hidden = true;
     58    }
     59
     60
     61    [ThreadStatic]
     62    private static double[] cache;
     63    protected static void CalculateWithScaling(IEnumerable<double> targetValues, IEnumerable<double> estimatedValues, IOnlineCalculator calculator, int maxRows) {
     64      if (cache == null || cache.GetLength(0) < maxRows) {
     65        cache = new double[maxRows];
     66      }
     67
     68      //calculate linear scaling
     69      //the static methods of the calculator cannot be used because they check that both enumerations contain the same number of elements;
     70      //that does not hold when the cache is used, since the cache array may be longer than the number of target values
     71      int i = 0;
     72      var linearScalingCalculator = new OnlineLinearScalingParameterCalculator();
     73      var targetValuesEnumerator = targetValues.GetEnumerator();
     74      var estimatedValuesEnumerator = estimatedValues.GetEnumerator();
     75      while (targetValuesEnumerator.MoveNext() && estimatedValuesEnumerator.MoveNext()) {
     76        double target = targetValuesEnumerator.Current;
     77        double estimated = estimatedValuesEnumerator.Current;
     78        linearScalingCalculator.Add(estimated, target);
     79        cache[i] = estimated;
     80        i++;
     81      }
     82      double alpha = linearScalingCalculator.Alpha;
     83      double beta = linearScalingCalculator.Beta;
     84
     85      //calculate the quality by using the passed online calculator
     86      targetValuesEnumerator = targetValues.GetEnumerator();
     87      i = 0;
     88      while (targetValuesEnumerator.MoveNext()) {
     89        double target = targetValuesEnumerator.Current;
     90        double estimated = cache[i] * beta + alpha;
     91        calculator.Add(target, estimated);
     92        i++;
     93      }
    4594    }
    4695  }
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SingleObjective/SymbolicTimeSeriesPrognosisSingleObjectiveMeanSquaredErrorEvaluator.cs

    r7183 r7989  
    2222using System;
    2323using System.Collections.Generic;
    24 using System.Drawing.Printing;
    2524using System.Linq;
    2625using HeuristicLab.Common;
     
    5150      IEnumerable<int> rows = GenerateRowsToEvaluate();
    5251
    53       double quality = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue,
    54         solution,
     52      var interpreter = (ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter)SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
     53
     54      double quality = Calculate(interpreter, solution,
    5555        EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper,
    5656        ProblemDataParameter.ActualValue,
    57         rows, HorizonParameter.ActualValue.Value);
     57        rows, HorizonParameter.ActualValue.Value, ApplyLinearScaling);
    5858      QualityParameter.ActualValue = new DoubleValue(quality);
    5959
     
    6161    }
    6262
    63     public static double Calculate(ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, ITimeSeriesPrognosisProblemData problemData, IEnumerable<int> rows, int horizon) {
    64       double[] alpha;
    65       double[] beta;
    66       DetermineScalingFactors(solution, problemData, interpreter, rows, out alpha, out beta);
    67       var scaledSolution = Scale(solution, alpha, beta);
    68       string[] targetVariables = problemData.TargetVariables.ToArray();
    69       var meanSquaredErrorCalculators = Enumerable.Range(0, problemData.TargetVariables.Count())
    70         .Select(i => new OnlineMeanSquaredErrorCalculator()).ToArray();
     63    public static double Calculate(ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree solution, double lowerEstimationLimit, double upperEstimationLimit, ITimeSeriesPrognosisProblemData problemData, IEnumerable<int> rows, int horizon, bool applyLinearScaling) {
     64      IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows.SelectMany(r => Enumerable.Range(r, horizon)));
     65      IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, rows, horizon).SelectMany(x => x);
     66      IEnumerable<double> boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit);
     67      OnlineCalculatorError errorState;
    7168
    72       var allContinuationsEnumerator = interpreter.GetSymbolicExpressionTreeValues(scaledSolution, problemData.Dataset,
    73                                                                                   targetVariables,
    74                                                                                   rows, horizon).GetEnumerator();
    75       allContinuationsEnumerator.MoveNext();
    76       // foreach row
    77       foreach (var row in rows) {
    78         // foreach horizon
    79         for (int h = 0; h < horizon; h++) {
    80           // foreach component
    81           for (int i = 0; i < meanSquaredErrorCalculators.Length; i++) {
    82             double e = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, allContinuationsEnumerator.Current));
    83             meanSquaredErrorCalculators[i].Add(problemData.Dataset.GetDoubleValue(targetVariables[i], row + h), e);
    84             if (meanSquaredErrorCalculators[i].ErrorState == OnlineCalculatorError.InvalidValueAdded)
    85               return double.MaxValue;
    86             allContinuationsEnumerator.MoveNext();
    87           }
    88         }
    89       }
    90       var meanCalculator = new OnlineMeanAndVarianceCalculator();
    91       foreach (var calc in meanSquaredErrorCalculators) {
    92         if (calc.ErrorState != OnlineCalculatorError.None) return double.MaxValue;
    93         meanCalculator.Add(calc.MeanSquaredError);
    94       }
     69      double mse;
     70      if (applyLinearScaling) {
     71        var mseCalculator = new OnlineMeanSquaredErrorCalculator();
     72        CalculateWithScaling(targetValues, boundedEstimatedValues, mseCalculator, problemData.Dataset.Rows);
     73        errorState = mseCalculator.ErrorState;
     74        mse = mseCalculator.MeanSquaredError;
     75      } else
     76        mse = OnlineMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues, out errorState);
    9577
    96       return meanCalculator.MeanErrorState == OnlineCalculatorError.None ? meanCalculator.Mean : double.MaxValue;
     78      if (errorState != OnlineCalculatorError.None) return Double.NaN;
     79      else return mse;
    9780    }
    9881
    99     private static ISymbolicExpressionTree Scale(ISymbolicExpressionTree solution, double[] alpha, double[] beta) {
    100       var clone = (ISymbolicExpressionTree)solution.Clone();
    101       int n = alpha.Length;
    102       for (int i = 0; i < n; i++) {
    103         var parent = clone.Root.GetSubtree(0);
    104         var rpb = clone.Root.GetSubtree(0).GetSubtree(i);
    105         var scaledRpb = MakeSum(
    106           MakeProduct(rpb,
    107             MakeConstant(beta[i], clone.Root.Grammar), clone.Root.Grammar),
    108             MakeConstant(alpha[i], clone.Root.Grammar), clone.Root.Grammar);
    109         parent.RemoveSubtree(i);
    110         parent.InsertSubtree(i, scaledRpb);
    111       }
    112       return clone;
    113     }
    114 
    115     private static ISymbolicExpressionTreeNode MakeSum(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b, ISymbolicExpressionTreeGrammar grammar) {
    116       var sum = grammar.Symbols.Where(s => s is Addition).First().CreateTreeNode();
    117       sum.AddSubtree(a);
    118       sum.AddSubtree(b);
    119       return sum;
    120     }
    121 
    122     private static ISymbolicExpressionTreeNode MakeProduct(ISymbolicExpressionTreeNode a, ISymbolicExpressionTreeNode b, ISymbolicExpressionTreeGrammar grammar) {
    123       var prod = grammar.Symbols.Where(s => s is Multiplication).First().CreateTreeNode();
    124       prod.AddSubtree(a);
    125       prod.AddSubtree(b);
    126       return prod;
    127     }
    128 
    129     private static ISymbolicExpressionTreeNode MakeConstant(double c, ISymbolicExpressionTreeGrammar grammar) {
    130       var node = (ConstantTreeNode)grammar.Symbols.Where(s => s is Constant).First().CreateTreeNode();
    131       node.Value = c;
    132       return node;
    133     }
    134 
    135     private static void DetermineScalingFactors(ISymbolicExpressionTree solution, ITimeSeriesPrognosisProblemData problemData, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter interpreter, IEnumerable<int> rows, out double[] alpha, out double[] beta) {
    136       string[] targetVariables = problemData.TargetVariables.ToArray();
    137       int nComponents = targetVariables.Length;
    138       alpha = new double[nComponents];
    139       beta = new double[nComponents];
    140       var oneStepPredictionsEnumerator = interpreter.GetSymbolicExpressionTreeValues(solution, problemData.Dataset, targetVariables, rows).GetEnumerator();
    141       var scalingParameterCalculators =
    142         Enumerable.Repeat(0, nComponents).Select(x => new OnlineLinearScalingParameterCalculator()).ToArray();
    143       var targetValues = problemData.Dataset.GetVectorEnumerable(targetVariables, rows);
    144       var targetValueEnumerator = targetValues.GetEnumerator();
    145 
    146       var more = oneStepPredictionsEnumerator.MoveNext() & targetValueEnumerator.MoveNext();
    147       while (more) {
    148         for (int i = 0; i < nComponents; i++) {
    149           scalingParameterCalculators[i].Add(oneStepPredictionsEnumerator.Current, targetValueEnumerator.Current);
    150           more = oneStepPredictionsEnumerator.MoveNext() & targetValueEnumerator.MoveNext();
    151         }
    152       }
    153 
    154       for (int i = 0; i < nComponents; i++) {
    155         if (scalingParameterCalculators[i].ErrorState == OnlineCalculatorError.None) {
    156           alpha[i] = scalingParameterCalculators[i].Alpha;
    157           beta[i] = scalingParameterCalculators[i].Beta;
    158         } else {
    159           alpha[i] = 0.0;
    160           beta[i] = 1.0;
    161         }
    162       }
    163     }
    16482
    16583    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, ITimeSeriesPrognosisProblemData problemData, IEnumerable<int> rows) {
     
    16886      HorizonParameter.ExecutionContext = context;
    16987
    170       double mse = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, HorizonParameter.ActualValue.Value);
     88      double mse = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue as ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, HorizonParameter.ActualValue.Value, ApplyLinearScaling);
    17189
    17290      HorizonParameter.ExecutionContext = null;
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SingleObjective/SymbolicTimeSeriesPrognosisSingleObjectiveProblem.cs

    r7843 r7989  
    5454      : base(new TimeSeriesPrognosisProblemData(), new SymbolicTimeSeriesPrognosisSingleObjectiveMeanSquaredErrorEvaluator(), new SymbolicDataAnalysisExpressionTreeCreator()) {
    5555      Parameters.Add(new FixedValueParameter<DoubleLimit>(EstimationLimitsParameterName, EstimationLimitsParameterDescription));
    56 
    5756      EstimationLimitsParameter.Hidden = true;
    5857
     
    6160      MaximumSymbolicExpressionTreeLength.Value = InitialMaximumTreeLength;
    6261
     62      var interpeter = new SymbolicTimeSeriesPrognosisExpressionTreeInterpreter();
     63      interpeter.TargetVariable = ProblemData.TargetVariable;
     64      SymbolicExpressionTreeInterpreter = interpeter;
     65
    6366      SymbolicExpressionTreeGrammarParameter.ValueChanged += (o, e) => ConfigureGrammarSymbols();
    64 
    6567      ConfigureGrammarSymbols();
    6668
     
    8890    protected override void OnProblemDataChanged() {
    8991      base.OnProblemDataChanged();
     92      var interpreter = SymbolicExpressionTreeInterpreter as ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter;
     93      if (interpreter != null) {
     94        interpreter.TargetVariable = ProblemData.TargetVariable;
     95      }
    9096      UpdateEstimationLimits();
     97
    9198    }
    9299
     
    107114      }
    108115    }
    109 
    110116  }
    111117}
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SingleObjective/SymbolicTimeSeriesPrognosisSingleObjectiveTrainingBestSolutionAnalyzer.cs

    r7183 r7989  
    2020#endregion
    2121
    22 using System.Linq;
    2322using HeuristicLab.Common;
    2423using HeuristicLab.Core;
     
    7675
    7776    protected override ISymbolicTimeSeriesPrognosisSolution CreateSolution(ISymbolicExpressionTree bestTree, double bestQuality) {
    78       var model = new SymbolicTimeSeriesPrognosisModel((ISymbolicExpressionTree)bestTree.Clone(), SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, ProblemDataParameter.ActualValue.TargetVariables.ToArray(), EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper);
     77      var model = new SymbolicTimeSeriesPrognosisModel((ISymbolicExpressionTree)bestTree.Clone(), SymbolicDataAnalysisTreeInterpreterParameter.ActualValue as ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper);
    7978      if (ApplyLinearScaling.Value)
    8079        SymbolicTimeSeriesPrognosisModel.Scale(model, ProblemDataParameter.ActualValue);
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SingleObjective/SymbolicTimeSeriesPrognosisSingleObjectiveValidationBestSolutionAnalyzer.cs

    r7183 r7989  
    2020#endregion
    2121
    22 using System.Linq;
    2322using HeuristicLab.Common;
    2423using HeuristicLab.Core;
     
    6564
    6665    protected override ISymbolicTimeSeriesPrognosisSolution CreateSolution(ISymbolicExpressionTree bestTree, double bestQuality) {
    67       var model = new SymbolicTimeSeriesPrognosisModel((ISymbolicExpressionTree)bestTree.Clone(), SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, ProblemDataParameter.ActualValue.TargetVariables.ToArray(), EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper);
     66      var model = new SymbolicTimeSeriesPrognosisModel((ISymbolicExpressionTree)bestTree.Clone(), SymbolicDataAnalysisTreeInterpreterParameter.ActualValue as ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper);
    6867      if (ApplyLinearScaling.Value)
    6968        SymbolicTimeSeriesPrognosisModel.Scale(model, ProblemDataParameter.ActualValue);
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SymbolicTimeSeriesPrognosisExpressionTreeInterpreter.cs

    r7949 r7989  
    2222using System;
    2323using System.Collections.Generic;
     24using System.Linq;
    2425using HeuristicLab.Common;
    2526using HeuristicLab.Core;
     
    2829using HeuristicLab.Parameters;
    2930using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    30 using System.Linq;
    3131
    3232namespace HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis {
    3333  [StorableClass]
    3434  [Item("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.")]
    35   public sealed class SymbolicTimeSeriesPrognosisInterpreter : ParameterizedNamedItem, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
    36     private const string CheckExpressionsWithIntervalArithmeticParameterName = "CheckExpressionsWithIntervalArithmetic";
    37     #region private classes
    38     private class InterpreterState {
    39       private double[] argumentStack;
    40       private int argumentStackPointer;
    41       private Instruction[] code;
    42       private int pc;
    43       public int ProgramCounter {
    44         get { return pc; }
    45         set { pc = value; }
    46       }
    47       internal InterpreterState(Instruction[] code, int argumentStackSize) {
    48         this.code = code;
    49         this.pc = 0;
    50         if (argumentStackSize > 0) {
    51           this.argumentStack = new double[argumentStackSize];
     35  public sealed class SymbolicTimeSeriesPrognosisExpressionTreeInterpreter : SymbolicDataAnalysisExpressionTreeInterpreter, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter {
     36    private const string TargetVariableParameterName = "TargetVariable";
     37
     38    public IFixedValueParameter<StringValue> TargetVariableParameter {
     39      get { return (IFixedValueParameter<StringValue>)Parameters[TargetVariableParameterName]; }
     40    }
     41
     42    public string TargetVariable {
     43      get { return TargetVariableParameter.Value.Value; }
     44      set { TargetVariableParameter.Value.Value = value; }
     45    }
     46
     47    [ThreadStatic]
     48    private static double[] targetVariableCache;
     49    [ThreadStatic]
     50    private static List<int> invalidateCacheIndexes;
     51
     52    [StorableConstructor]
     53    private SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(bool deserializing) : base(deserializing) { }
     54    private SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(SymbolicTimeSeriesPrognosisExpressionTreeInterpreter original, Cloner cloner) : base(original, cloner) { }
     55    public override IDeepCloneable Clone(Cloner cloner) {
     56      return new SymbolicTimeSeriesPrognosisExpressionTreeInterpreter(this, cloner);
     57    }
     58
     59    public SymbolicTimeSeriesPrognosisExpressionTreeInterpreter()
     60      : base("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
     61      Parameters.Add(new FixedValueParameter<StringValue>(TargetVariableParameterName));
     62      TargetVariableParameter.Hidden = true;
     63    }
     64
     65    // for each row, yields several (= horizon) future predictions
     66    public IEnumerable<IEnumerable<double>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon) {
     67      return GetSymbolicExpressionTreeValues(tree, dataset, rows, rows.Select(row => horizon));
     68    }
     69
     70    public IEnumerable<IEnumerable<double>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, IEnumerable<int> horizons) {
     71      if (CheckExpressionsWithIntervalArithmetic.Value)
     72        throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
     73      if (targetVariableCache == null || targetVariableCache.GetLength(0) < dataset.Rows)
     74        targetVariableCache = dataset.GetDoubleValues(TargetVariable).ToArray();
     75      if (invalidateCacheIndexes == null)
     76        invalidateCacheIndexes = new List<int>(10);
     77
     78      string targetVariable = TargetVariable;
     79      EvaluatedSolutions.Value++; // increment the evaluated solutions counter
     80      var state = PrepareInterpreterState(tree, dataset, targetVariableCache);
     81      var rowsEnumerator = rows.GetEnumerator();
     82      var horizonsEnumerator = horizons.GetEnumerator();
     83
     84      // produce an n-step forecast for all rows
     85      while (rowsEnumerator.MoveNext() & horizonsEnumerator.MoveNext()) {
     86        int row = rowsEnumerator.Current;
     87        int horizon = horizonsEnumerator.Current;
     88
     89        double[] vProgs = new double[horizon];
     90        for (int i = 0; i < horizon; i++) {
     91          int localRow = i + row; // create a local variable for the ref parameter
     92          vProgs[i] = Evaluate(dataset, ref localRow, state);
     93          targetVariableCache[localRow] = vProgs[i];
     94          invalidateCacheIndexes.Add(localRow);
     95          state.Reset();
    5296        }
    53         this.argumentStackPointer = 0;
     97
     98        yield return vProgs;
     99
     100        int j = 0;
     101        foreach (var targetValue in dataset.GetDoubleValues(TargetVariable, invalidateCacheIndexes)) {
     102          targetVariableCache[invalidateCacheIndexes[j]] = targetValue;
     103          j++;
     104        }
     105        invalidateCacheIndexes.Clear();
    54106      }
    55107
    56       internal void Reset() {
    57         this.pc = 0;
    58         this.argumentStackPointer = 0;
    59       }
    60 
    61       internal Instruction NextInstruction() {
    62         return code[pc++];
    63       }
    64       private void Push(double val) {
    65         argumentStack[argumentStackPointer++] = val;
    66       }
    67       private double Pop() {
    68         return argumentStack[--argumentStackPointer];
    69       }
    70 
    71       internal void CreateStackFrame(double[] argValues) {
    72         // push in reverse order to make indexing easier
    73         for (int i = argValues.Length - 1; i >= 0; i--) {
    74           argumentStack[argumentStackPointer++] = argValues[i];
    75         }
    76         Push(argValues.Length);
    77       }
    78 
    79       internal void RemoveStackFrame() {
    80         int size = (int)Pop();
    81         argumentStackPointer -= size;
    82       }
    83 
    84       internal double GetStackFrameValue(ushort index) {
    85         // layout of stack:
    86         // [0]   <- argumentStackPointer
    87         // [StackFrameSize = N + 1]
    88         // [Arg0] <- argumentStackPointer - 2 - 0
    89         // [Arg1] <- argumentStackPointer - 2 - 1
    90         // [...]
    91         // [ArgN] <- argumentStackPointer - 2 - N
    92         // <Begin of stack frame>
    93         return argumentStack[argumentStackPointer - index - 2];
    94       }
    95     }
    96     private class OpCodes {
    97       public const byte Add = 1;
    98       public const byte Sub = 2;
    99       public const byte Mul = 3;
    100       public const byte Div = 4;
    101 
    102       public const byte Sin = 5;
    103       public const byte Cos = 6;
    104       public const byte Tan = 7;
    105 
    106       public const byte Log = 8;
    107       public const byte Exp = 9;
    108 
    109       public const byte IfThenElse = 10;
    110 
    111       public const byte GT = 11;
    112       public const byte LT = 12;
    113 
    114       public const byte AND = 13;
    115       public const byte OR = 14;
    116       public const byte NOT = 15;
    117 
    118 
    119       public const byte Average = 16;
    120 
    121       public const byte Call = 17;
    122 
    123       public const byte Variable = 18;
    124       public const byte LagVariable = 19;
    125       public const byte Constant = 20;
    126       public const byte Arg = 21;
    127 
    128       public const byte Power = 22;
    129       public const byte Root = 23;
    130       public const byte TimeLag = 24;
    131       public const byte Integral = 25;
    132       public const byte Derivative = 26;
    133 
    134       public const byte VariableCondition = 27;
    135     }
    136     #endregion
    137 
    138     private Dictionary<Type, byte> symbolToOpcode = new Dictionary<Type, byte>() {
    139       { typeof(Addition), OpCodes.Add },
    140       { typeof(Subtraction), OpCodes.Sub },
    141       { typeof(Multiplication), OpCodes.Mul },
    142       { typeof(Division), OpCodes.Div },
    143       { typeof(Sine), OpCodes.Sin },
    144       { typeof(Cosine), OpCodes.Cos },
    145       { typeof(Tangent), OpCodes.Tan },
    146       { typeof(Logarithm), OpCodes.Log },
    147       { typeof(Exponential), OpCodes.Exp },
    148       { typeof(IfThenElse), OpCodes.IfThenElse },
    149       { typeof(GreaterThan), OpCodes.GT },
    150       { typeof(LessThan), OpCodes.LT },
    151       { typeof(And), OpCodes.AND },
    152       { typeof(Or), OpCodes.OR },
    153       { typeof(Not), OpCodes.NOT},
    154       { typeof(Average), OpCodes.Average},
    155       { typeof(InvokeFunction), OpCodes.Call },
    156       { typeof(HeuristicLab.Problems.DataAnalysis.Symbolic.Variable), OpCodes.Variable },
    157       { typeof(LaggedVariable), OpCodes.LagVariable },
    158       { typeof(Constant), OpCodes.Constant },
    159       { typeof(Argument), OpCodes.Arg },
    160       { typeof(Power),OpCodes.Power},
    161       { typeof(Root),OpCodes.Root},
    162       { typeof(TimeLag), OpCodes.TimeLag},
    163       { typeof(Integral), OpCodes.Integral},
    164       { typeof(Derivative), OpCodes.Derivative},
    165       { typeof(VariableCondition),OpCodes.VariableCondition}
    166     };
    167 
    168     public override bool CanChangeName {
    169       get { return false; }
    170     }
    171     public override bool CanChangeDescription {
    172       get { return false; }
     108      if (rowsEnumerator.MoveNext() || horizonsEnumerator.MoveNext())
     109        throw new ArgumentException("Number of elements in rows and horizon enumerations doesn't match.");
    173110    }
    174111
    175     #region parameter properties
    176     public IValueParameter<BoolValue> CheckExpressionsWithIntervalArithmeticParameter {
    177       get { return (IValueParameter<BoolValue>)Parameters[CheckExpressionsWithIntervalArithmeticParameterName]; }
    178     }
    179     #endregion
    180 
    181     #region properties
    182     public BoolValue CheckExpressionsWithIntervalArithmetic {
    183       get { return CheckExpressionsWithIntervalArithmeticParameter.Value; }
    184       set { CheckExpressionsWithIntervalArithmeticParameter.Value = value; }
    185     }
    186 
    187     [Storable]
    188     private readonly string[] targetVariables;
    189     #endregion
    190 
    191 
    192     [StorableConstructor]
    193     private SymbolicTimeSeriesPrognosisInterpreter(bool deserializing) : base(deserializing) { }
    194     private SymbolicTimeSeriesPrognosisInterpreter(SymbolicTimeSeriesPrognosisInterpreter original, Cloner cloner)
    195       : base(original, cloner) {
    196       this.targetVariables = original.targetVariables;
    197     }
    198     public override IDeepCloneable Clone(Cloner cloner) {
    199       return new SymbolicTimeSeriesPrognosisInterpreter(this, cloner);
    200     }
    201 
    202     public SymbolicTimeSeriesPrognosisInterpreter(string[] targetVariables)
    203       : base("SymbolicTimeSeriesPrognosisInterpreter", "Interpreter for symbolic expression trees including automatically defined functions.") {
    204       Parameters.Add(new ValueParameter<BoolValue>(CheckExpressionsWithIntervalArithmeticParameterName, "Switch that determines if the interpreter checks the validity of expressions with interval arithmetic before evaluating the expression.", new BoolValue(false)));
    205       this.targetVariables = targetVariables;
    206     }
    207 
    208     public IEnumerable<double> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows) {
    209       throw new NotSupportedException();
    210     }
    211 
    212     // yields, for each row and for each target variable, one prognosis (= enumerable of future values)
    213     public IEnumerable<IEnumerable<IEnumerable<double>>> GetSymbolicExpressionTreeValues(ISymbolicExpressionTree tree, Dataset dataset, IEnumerable<int> rows, int horizon) {
    214       if (CheckExpressionsWithIntervalArithmetic.Value)
    215         throw new NotSupportedException("Interval arithmetic is not yet supported in the symbolic data analysis interpreter.");
    216       var compiler = new SymbolicExpressionTreeCompiler();
    217       Instruction[] code = compiler.Compile(tree, MapSymbolToOpCode);
     112    private InterpreterState PrepareInterpreterState(ISymbolicExpressionTree tree, Dataset dataset, double[] targetVariableCache) {
     113      Instruction[] code = SymbolicExpressionTreeCompiler.Compile(tree, OpCodes.MapSymbolToOpCode);
    218114      int necessaryArgStackSize = 0;
    219       for (int i = 0; i < code.Length; i++) {
    220         Instruction instr = code[i];
     115      foreach (Instruction instr in code) {
    221116        if (instr.opCode == OpCodes.Variable) {
    222           var variableTreeNode = instr.dynamicNode as VariableTreeNode;
    223           instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
    224           code[i] = instr;
     117          var variableTreeNode = (VariableTreeNode)instr.dynamicNode;
     118          if (variableTreeNode.VariableName == TargetVariable)
     119            instr.iArg0 = targetVariableCache;
     120          else
     121            instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableTreeNode.VariableName);
    225122        } else if (instr.opCode == OpCodes.LagVariable) {
    226           var laggedVariableTreeNode = instr.dynamicNode as LaggedVariableTreeNode;
     123          var laggedVariableTreeNode = (LaggedVariableTreeNode)instr.dynamicNode;
    227124          instr.iArg0 = dataset.GetReadOnlyDoubleValues(laggedVariableTreeNode.VariableName);
    228           code[i] = instr;
    229125        } else if (instr.opCode == OpCodes.VariableCondition) {
    230           var variableConditionTreeNode = instr.dynamicNode as VariableConditionTreeNode;
     126          var variableConditionTreeNode = (VariableConditionTreeNode)instr.dynamicNode;
    231127          instr.iArg0 = dataset.GetReadOnlyDoubleValues(variableConditionTreeNode.VariableName);
    232128        } else if (instr.opCode == OpCodes.Call) {
     
    234130        }
    235131      }
    236       var state = new InterpreterState(code, necessaryArgStackSize);
    237132
    238       int nComponents = tree.Root.GetSubtree(0).SubtreeCount;
    239       // produce a n-step forecast for each target variable for all rows
    240       var cachedPrognosedValues = new Dictionary<string, double[]>();
    241       foreach (var targetVariable in targetVariables)
    242         cachedPrognosedValues[targetVariable] = new double[horizon];
    243       foreach (var rowEnum in rows) {
    244         int row = rowEnum;
    245         List<double[]> vProgs = new List<double[]>();
    246         foreach (var horizonRow in Enumerable.Range(row, horizon)) {
    247           int localRow = horizonRow; // create a local variable for the ref parameter
    248           var vPrognosis = from i in Enumerable.Range(0, nComponents)
    249                            select Evaluate(dataset, ref localRow, row - 1, state, cachedPrognosedValues);
    250 
    251           var vPrognosisArr = vPrognosis.ToArray();
    252           vProgs.Add(vPrognosisArr);
    253           // set cachedValues for prognosis of future values
    254           for (int i = 0; i < vPrognosisArr.Length; i++)
    255             cachedPrognosedValues[targetVariables[i]][horizonRow - row] = vPrognosisArr[i];
    256 
    257           state.Reset();
    258         }
    259 
    260         yield return from component in Enumerable.Range(0, nComponents)
    261                      select from prognosisStep in Enumerable.Range(0, vProgs.Count)
    262                             select vProgs[prognosisStep][component];
    263       }
    264     }
    265 
    266     private double Evaluate(Dataset dataset, ref int row, int lastObservedRow, InterpreterState state, Dictionary<string, double[]> cachedPrognosedValues) {
    267       Instruction currentInstr = state.NextInstruction();
    268       switch (currentInstr.opCode) {
    269         case OpCodes.Add: {
    270             double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    271             for (int i = 1; i < currentInstr.nArguments; i++) {
    272               s += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    273             }
    274             return s;
    275           }
    276         case OpCodes.Sub: {
    277             double s = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    278             for (int i = 1; i < currentInstr.nArguments; i++) {
    279               s -= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    280             }
    281             if (currentInstr.nArguments == 1) s = -s;
    282             return s;
    283           }
    284         case OpCodes.Mul: {
    285             double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    286             for (int i = 1; i < currentInstr.nArguments; i++) {
    287               p *= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    288             }
    289             return p;
    290           }
    291         case OpCodes.Div: {
    292             double p = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    293             for (int i = 1; i < currentInstr.nArguments; i++) {
    294               p /= Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    295             }
    296             if (currentInstr.nArguments == 1) p = 1.0 / p;
    297             return p;
    298           }
    299         case OpCodes.Average: {
    300             double sum = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    301             for (int i = 1; i < currentInstr.nArguments; i++) {
    302               sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    303             }
    304             return sum / currentInstr.nArguments;
    305           }
    306         case OpCodes.Cos: {
    307             return Math.Cos(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    308           }
    309         case OpCodes.Sin: {
    310             return Math.Sin(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    311           }
    312         case OpCodes.Tan: {
    313             return Math.Tan(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    314           }
    315         case OpCodes.Power: {
    316             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    317             double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    318             return Math.Pow(x, y);
    319           }
    320         case OpCodes.Root: {
    321             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    322             double y = Math.Round(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    323             return Math.Pow(x, 1 / y);
    324           }
    325         case OpCodes.Exp: {
    326             return Math.Exp(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    327           }
    328         case OpCodes.Log: {
    329             return Math.Log(Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues));
    330           }
    331         case OpCodes.IfThenElse: {
    332             double condition = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    333             double result;
    334             if (condition > 0.0) {
    335               result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); SkipInstructions(state);
    336             } else {
    337               SkipInstructions(state); result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    338             }
    339             return result;
    340           }
    341         case OpCodes.AND: {
    342             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    343             for (int i = 1; i < currentInstr.nArguments; i++) {
    344               if (result > 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    345               else {
    346                 SkipInstructions(state);
    347               }
    348             }
    349             return result > 0.0 ? 1.0 : -1.0;
    350           }
    351         case OpCodes.OR: {
    352             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    353             for (int i = 1; i < currentInstr.nArguments; i++) {
    354               if (result <= 0.0) result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    355               else {
    356                 SkipInstructions(state);
    357               }
    358             }
    359             return result > 0.0 ? 1.0 : -1.0;
    360           }
    361         case OpCodes.NOT: {
    362             return Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues) > 0.0 ? -1.0 : 1.0;
    363           }
    364         case OpCodes.GT: {
    365             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    366             double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    367             if (x > y) return 1.0;
    368             else return -1.0;
    369           }
    370         case OpCodes.LT: {
    371             double x = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    372             double y = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    373             if (x < y) return 1.0;
    374             else return -1.0;
    375           }
    376         case OpCodes.TimeLag: {
    377             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    378             row += timeLagTreeNode.Lag;
    379             double result = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    380             row -= timeLagTreeNode.Lag;
    381             return result;
    382           }
    383         case OpCodes.Integral: {
    384             int savedPc = state.ProgramCounter;
    385             var timeLagTreeNode = (LaggedTreeNode)currentInstr.dynamicNode;
    386             double sum = 0.0;
    387             for (int i = 0; i < Math.Abs(timeLagTreeNode.Lag); i++) {
    388               row += Math.Sign(timeLagTreeNode.Lag);
    389               sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    390               state.ProgramCounter = savedPc;
    391             }
    392             row -= timeLagTreeNode.Lag;
    393             sum += Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    394             return sum;
    395           }
    396 
    397         //mkommend: derivative calculation taken from:
    398         //http://www.holoborodko.com/pavel/numerical-methods/numerical-derivative/smooth-low-noise-differentiators/
    399         //one-sided smooth differentiator, N = 4
    400         // y' = 1/(8h) * (f_i + 2 f_(i-1) - 2 f_(i-3) - f_(i-4))
    401         case OpCodes.Derivative: {
    402             int savedPc = state.ProgramCounter;
    403             double f_0 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    404             state.ProgramCounter = savedPc;
    405             double f_1 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row -= 2;
    406             state.ProgramCounter = savedPc;
    407             double f_3 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues); row--;
    408             state.ProgramCounter = savedPc;
    409             double f_4 = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    410             row += 4;
    411 
    412             return (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
    413           }
    414         case OpCodes.Call: {
    415             // evaluate sub-trees
    416             double[] argValues = new double[currentInstr.nArguments];
    417             for (int i = 0; i < currentInstr.nArguments; i++) {
    418               argValues[i] = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    419             }
    420             // push the argument values onto the stack
    421             state.CreateStackFrame(argValues);
    422 
    423             // save the pc
    424             int savedPc = state.ProgramCounter;
    425             // set pc to start of function 
    426             state.ProgramCounter = (ushort)currentInstr.iArg0;
    427             // evaluate the function
    428             double v = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    429 
    430             // delete the stack frame
    431             state.RemoveStackFrame();
    432 
    433             // restore the pc => evaluation will continue at point after my subtrees 
    434             state.ProgramCounter = savedPc;
    435             return v;
    436           }
    437         case OpCodes.Arg: {
    438             return state.GetStackFrameValue((ushort)currentInstr.iArg0);
    439           }
    440         case OpCodes.Variable: {
    441             if (row < 0 || row >= dataset.Rows)
    442               return double.NaN;
    443             var variableTreeNode = (VariableTreeNode)currentInstr.dynamicNode;
    444             if (row <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[row] * variableTreeNode.Weight;
    445             else return cachedPrognosedValues[variableTreeNode.VariableName][row - lastObservedRow - 1] * variableTreeNode.Weight;
    446           }
    447         case OpCodes.LagVariable: {
    448             var laggedVariableTreeNode = (LaggedVariableTreeNode)currentInstr.dynamicNode;
    449             int actualRow = row + laggedVariableTreeNode.Lag;
    450             if (actualRow < 0 || actualRow >= dataset.Rows)
    451               return double.NaN;
    452             if (actualRow <= lastObservedRow) return ((IList<double>)currentInstr.iArg0)[actualRow] * laggedVariableTreeNode.Weight;
    453             else return cachedPrognosedValues[laggedVariableTreeNode.VariableName][actualRow - lastObservedRow - 1] * laggedVariableTreeNode.Weight;
    454           }
    455         case OpCodes.Constant: {
    456             var constTreeNode = currentInstr.dynamicNode as ConstantTreeNode;
    457             return constTreeNode.Value;
    458           }
    459 
    460         //mkommend: this symbol uses the logistic function f(x) = 1 / (1 + e^(-alpha * x)), with alpha = Slope and x = value - Threshold,
    461         //to determine the relative weights of the true and false branches; see http://en.wikipedia.org/wiki/Logistic_function
    462         case OpCodes.VariableCondition: {
    463             if (row < 0 || row >= dataset.Rows)
    464               return double.NaN;
    465             var variableConditionTreeNode = (VariableConditionTreeNode)currentInstr.dynamicNode;
    466             double variableValue;
    467             if (row <= lastObservedRow)
    468               variableValue = ((IList<double>)currentInstr.iArg0)[row];
    469             else
    470               variableValue = cachedPrognosedValues[variableConditionTreeNode.VariableName][row - lastObservedRow - 1];
    471 
    472             double x = variableValue - variableConditionTreeNode.Threshold;
    473             double p = 1 / (1 + Math.Exp(-variableConditionTreeNode.Slope * x));
    474 
    475             double trueBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    476             double falseBranch = Evaluate(dataset, ref row, lastObservedRow, state, cachedPrognosedValues);
    477 
    478             return trueBranch * p + falseBranch * (1 - p);
    479           }
    480         default: throw new NotSupportedException();
    481       }
    482     }
    483 
    484     private byte MapSymbolToOpCode(ISymbolicExpressionTreeNode treeNode) {
    485       if (symbolToOpcode.ContainsKey(treeNode.Symbol.GetType()))
    486         return symbolToOpcode[treeNode.Symbol.GetType()];
    487       else
    488         throw new NotSupportedException("Symbol: " + treeNode.Symbol);
    489     }
    490 
    491     // skips a whole branch
    492     private void SkipInstructions(InterpreterState state) {
    493       int i = 1;
    494       while (i > 0) {
    495         i += state.NextInstruction().nArguments;
    496         i--;
    497       }
     133      return new InterpreterState(code, necessaryArgStackSize);
    498134    }
    499135  }
  • branches/HeuristicLab.TimeSeries/HeuristicLab.Problems.DataAnalysis.Symbolic.TimeSeriesPrognosis/3.4/SymbolicTimeSeriesPrognosisModel.cs

    r7183 r7989  
    2020#endregion
    2121
    22 using System;
    2322using System.Collections.Generic;
    2423using System.Drawing;
     
    6463    #endregion
    6564
    66     [Storable]
    67     private string[] targetVariables;
    68 
    69 
    7065    [StorableConstructor]
    7166    protected SymbolicTimeSeriesPrognosisModel(bool deserializing) : base(deserializing) { }
     
    7469      this.symbolicExpressionTree = cloner.Clone(original.symbolicExpressionTree);
    7570      this.interpreter = cloner.Clone(original.interpreter);
    76       this.targetVariables = new string[original.targetVariables.Length];
    77       Array.Copy(original.targetVariables, this.targetVariables, this.targetVariables.Length);
    7871      this.lowerEstimationLimit = original.lowerEstimationLimit;
    7972      this.upperEstimationLimit = original.upperEstimationLimit;
    8073    }
    81     public SymbolicTimeSeriesPrognosisModel(ISymbolicExpressionTree tree, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter interpreter, IEnumerable<string> targetVariables, double lowerLimit = double.MinValue, double upperLimit = double.MaxValue)
     74    public SymbolicTimeSeriesPrognosisModel(ISymbolicExpressionTree tree, ISymbolicTimeSeriesPrognosisExpressionTreeInterpreter interpreter, double lowerLimit = double.MinValue, double upperLimit = double.MaxValue)
    8275      : base() {
    8376      this.name = ItemName;
    8477      this.description = ItemDescription;
    8578      this.symbolicExpressionTree = tree;
    86       this.interpreter = interpreter; this.targetVariables = targetVariables.ToArray();
     79      this.interpreter = interpreter;
    8780      this.lowerEstimationLimit = lowerLimit;
    8881      this.upperEstimationLimit = upperLimit;
     
    9386    }
    9487
    95     public IEnumerable<IEnumerable<IEnumerable<double>>> GetPrognosedValues(Dataset dataset, IEnumerable<int> rows, int horizon) {
    96       var enumerator =
    97         Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, targetVariables, rows, horizon).
    98           GetEnumerator();
    99       foreach (var r in rows) {
    100         var l = new List<double[]>();
    101         for (int h = 0; h < horizon; h++) {
    102           double[] components = new double[targetVariables.Length];
    103           for (int c = 0; c < components.Length; c++) {
    104             enumerator.MoveNext();
    105             components[c] = Math.Min(upperEstimationLimit, Math.Max(lowerEstimationLimit, enumerator.Current));
    106           }
    107           l.Add(components);
    108         }
    109         yield return l;
    110       }
     88    public IEnumerable<IEnumerable<double>> GetPrognosedValues(Dataset dataset, IEnumerable<int> rows, int horizon) {
     89      var estimatedValues = Interpreter.GetSymbolicExpressionTreeValues(SymbolicExpressionTree, dataset, rows, horizon);
     90      return estimatedValues.Select(predictionPerRow => predictionPerRow.LimitToRange(lowerEstimationLimit, upperEstimationLimit));
    11191    }
    11292
     
    120100    public static void Scale(SymbolicTimeSeriesPrognosisModel model, ITimeSeriesPrognosisProblemData problemData) {
    121101      var dataset = problemData.Dataset;
    122       var targetVariables = problemData.TargetVariables.ToArray();
     102      var targetVariable = problemData.TargetVariable;
    123103      var rows = problemData.TrainingIndizes;
    124       var estimatedValuesEnumerator = model.Interpreter.GetSymbolicExpressionTreeValues(model.SymbolicExpressionTree, dataset,
    125                                                                               targetVariables,
    126                                                                               rows).GetEnumerator();
    127       var scalingParameterCalculators =
    128         problemData.TargetVariables.Select(v => new OnlineLinearScalingParameterCalculator()).ToArray();
    129       var targetValuesEnumerator = problemData.Dataset.GetVectorEnumerable(targetVariables, rows).GetEnumerator();
     104      var estimatedValuesEnumerator = model.Interpreter.GetSymbolicExpressionTreeValues(model.SymbolicExpressionTree, dataset, rows);
     105      var targetValuesEnumerator = problemData.Dataset.GetDoubleValues(targetVariable, rows);
    130106
    131       var more = targetValuesEnumerator.MoveNext() & estimatedValuesEnumerator.MoveNext();
    132       // foreach row
    133       while (more) {
    134         // foreach component
    135         for (int i = 0; i < targetVariables.Length; i++) {
    136           scalingParameterCalculators[i].Add(estimatedValuesEnumerator.Current, targetValuesEnumerator.Current);
    137           more = estimatedValuesEnumerator.MoveNext() & targetValuesEnumerator.MoveNext();
     107      double alpha, beta;
     108      OnlineCalculatorError error;
     109      OnlineLinearScalingParameterCalculator.Calculate(estimatedValuesEnumerator, targetValuesEnumerator, out alpha, out beta, out error);
     110      if (error != OnlineCalculatorError.None) return;
     111
     112      ConstantTreeNode alphaTreeNode = null;
     113      ConstantTreeNode betaTreeNode = null;
     114      // check if model has been scaled previously by analyzing the structure of the tree
     115      var startNode = model.SymbolicExpressionTree.Root.GetSubtree(0);
     116      if (startNode.GetSubtree(0).Symbol is Addition) {
     117        var addNode = startNode.GetSubtree(0);
     118        if (addNode.SubtreeCount == 2 && addNode.GetSubtree(0).Symbol is Multiplication && addNode.GetSubtree(1).Symbol is Constant) {
     119          alphaTreeNode = addNode.GetSubtree(1) as ConstantTreeNode;
     120          var mulNode = addNode.GetSubtree(0);
     121          if (mulNode.SubtreeCount == 2 && mulNode.GetSubtree(1).Symbol is Constant) {
     122            betaTreeNode = mulNode.GetSubtree(1) as ConstantTreeNode;
     123          }
    138124        }
    139125      }
    140 
    141       for (int i = 0; i < targetVariables.Count(); i++) {
    142         if (scalingParameterCalculators[i].ErrorState != OnlineCalculatorError.None) continue;
    143         double alpha = scalingParameterCalculators[i].Alpha;
    144         double beta = scalingParameterCalculators[i].Beta;
    145         ConstantTreeNode alphaTreeNode = null;
    146         ConstantTreeNode betaTreeNode = null;
    147         // check if model has been scaled previously by analyzing the structure of the tree
    148         var startNode = model.SymbolicExpressionTree.Root.GetSubtree(0);
    149         if (startNode.GetSubtree(i).Symbol is Addition) {
    150           var addNode = startNode.GetSubtree(i);
    151           if (addNode.SubtreeCount == 2 && addNode.GetSubtree(0).Symbol is Multiplication &&
    152               addNode.GetSubtree(1).Symbol is Constant) {
    153             alphaTreeNode = addNode.GetSubtree(1) as ConstantTreeNode;
    154             var mulNode = addNode.GetSubtree(0);
    155             if (mulNode.SubtreeCount == 2 && mulNode.GetSubtree(1).Symbol is Constant) {
    156               betaTreeNode = mulNode.GetSubtree(1) as ConstantTreeNode;
    157             }
    158           }
    159         }
    160         // if tree structure matches the structure necessary for linear scaling then reuse the existing tree nodes
    161         if (alphaTreeNode != null && betaTreeNode != null) {
    162           betaTreeNode.Value *= beta;
    163           alphaTreeNode.Value *= beta;
    164           alphaTreeNode.Value += alpha;
    165         } else {
    166           var mainBranch = startNode.GetSubtree(i);
    167           startNode.RemoveSubtree(i);
    168           var scaledMainBranch = MakeSum(MakeProduct(mainBranch, beta), alpha);
    169           startNode.InsertSubtree(i, scaledMainBranch);
    170         }
    171       } // foreach
     126      // if tree structure matches the structure necessary for linear scaling then reuse the existing tree nodes
     127      if (alphaTreeNode != null && betaTreeNode != null) {
     128        betaTreeNode.Value *= beta;
     129        alphaTreeNode.Value *= beta;
     130        alphaTreeNode.Value += alpha;
     131      } else {
     132        var mainBranch = startNode.GetSubtree(0);
     133        startNode.RemoveSubtree(0);
     134        var scaledMainBranch = MakeSum(MakeProduct(mainBranch, beta), alpha);
     135        startNode.AddSubtree(scaledMainBranch);
     136      }
    172137    }
    173138
Note: See TracChangeset for help on using the changeset viewer.