Free cookie consent management tool by TermsFeed Policy Generator

Changeset 14756 for branches


Ignore:
Timestamp:
03/16/17 09:25:47 (8 years ago)
Author:
gkronber
Message:

#2650 improved code for handling variables in the constant optimizer by using a dictionary

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r14402 r14756  
    2222using System;
    2323using System.Collections.Generic;
    24 using System.Diagnostics.Contracts;
    2524using System.Linq;
    2625using AutoDiff;
     
    110109    [StorableHook(HookType.AfterDeserialization)]
    111110    private void AfterDeserialization() {
    112       if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
     111      if(!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
    113112        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    114       if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
     113      if(!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
    115114        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
    116115    }
     
    119118      var solution = SymbolicExpressionTreeParameter.ActualValue;
    120119      double quality;
    121       if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
     120      if(RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
    122121        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
    123122        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
    124123           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree);
    125124
    126         if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
     125        if(ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
    127126          var evaluationRows = GenerateRowsToEvaluate();
    128127          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
     
    177176
    178177
    179     public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
    180 
    181       List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
    182       List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
    183       List<string> variableNames = new List<string>();
    184       List<string> categoricalVariableValues = new List<string>();
    185       List<int> lags = new List<int>();
     178
     179    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     180      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
     181      int maxIterations, bool updateVariableWeights = true,
     182      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     183      bool updateConstantsInTree = true) {
     184
     185      // numeric constants in the tree become variables for constant opt
     186      // variables in the tree become parameters (fixed values) for constant opt
     187      // for each parameter (variable in the original tree) we store the
     188      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
     189      // A dictionary is used to find parameters
     190      var variables = new List<AutoDiff.Variable>();
     191      var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
     192      //List<string> variableNames = new List<string>();
     193      //List<string> categoricalVariableValues = new List<string>();
     194      //List<int> lags = new List<int>();
    186195
    187196      AutoDiff.Term func;
    188       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out func))
     197      if(!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func))
    189198        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    190       if (variableNames.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
    191 
    192       AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());
     199      if(parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
     200
     201      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
     202      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());
    193203
    194204      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
    195       if (updateVariableWeights)
     205      if(updateVariableWeights)
    196206        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
    197207      else
     
    207217        c[1] = 1.0;
    208218        int i = 2;
    209         foreach (var node in terminalNodes) {
     219        foreach(var node in terminalNodes) {
    210220          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    211221          VariableTreeNode variableTreeNode = node as VariableTreeNode;
    212222          BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
    213223          FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    214           if (constantTreeNode != null)
     224          if(constantTreeNode != null)
    215225            c[i++] = constantTreeNode.Value;
    216           else if (updateVariableWeights && variableTreeNode != null)
     226          else if(updateVariableWeights && variableTreeNode != null)
    217227            c[i++] = variableTreeNode.Weight;
    218           else if (updateVariableWeights && binFactorVarTreeNode != null)
     228          else if(updateVariableWeights && binFactorVarTreeNode != null)
    219229            c[i++] = binFactorVarTreeNode.Weight;
    220           else if (factorVarTreeNode != null) {
     230          else if(factorVarTreeNode != null) {
    221231            // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
    222             foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
     232            foreach(var w in factorVarTreeNode.Weights) c[i++] = w;
    223233          }
    224234        }
     
    229239      alglib.lsfitstate state;
    230240      alglib.lsfitreport rep;
    231       int info;
     241      int retVal;
    232242
    233243      IDataset ds = problemData.Dataset;
    234       double[,] x = new double[rows.Count(), variableNames.Count];
     244      double[,] x = new double[rows.Count(), parameters.Count];
    235245      int row = 0;
    236       foreach (var r in rows) {
    237         for (int col = 0; col < variableNames.Count; col++) {
    238           int lag = lags[col];
    239           if (ds.VariableHasType<double>(variableNames[col])) {
    240             x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag);
    241           } else if (ds.VariableHasType<string>(variableNames[col])) {
    242             x[row, col] = ds.GetStringValue(variableNames[col], r) == categoricalVariableValues[col] ? 1 : 0;
     246      foreach(var r in rows) {
     247        int col = 0;
     248        foreach(var kvp in parameterEntries) {
     249          var info = kvp.Key;
     250          int lag = info.lag;
     251          if(ds.VariableHasType<double>(info.variableName)) {
     252            x[row, col] = ds.GetDoubleValue(info.variableName, r + lag);
     253          } else if(ds.VariableHasType<string>(info.variableName)) {
     254            x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
    243255          } else throw new InvalidProgramException("found a variable of unknown type");
     256          col++;
    244257        }
    245258        row++;
     
    258271        //alglib.lsfitsetgradientcheck(state, 0.001);
    259272        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
    260         alglib.lsfitresults(state, out info, out c, out rep);
    261       } catch (ArithmeticException) {
     273        alglib.lsfitresults(state, out retVal, out c, out rep);
     274      } catch(ArithmeticException) {
    262275        return originalQuality;
    263       } catch (alglib.alglibexception) {
     276      } catch(alglib.alglibexception) {
    264277        return originalQuality;
    265278      }
    266279
    267       //info == -7  => constant optimization failed due to wrong gradient
    268       if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
     280      //retVal == -7  => constant optimization failed due to wrong gradient
     281      if(retVal != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
    269282      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
    270283
    271       if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
    272       if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
     284      if(!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
     285      if(originalQuality - quality > 0.001 || double.IsNaN(quality)) {
    273286        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
    274287        return originalQuality;
     
    279292    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
    280293      int i = 0;
    281       foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
     294      foreach(var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    282295        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    283296        VariableTreeNode variableTreeNode = node as VariableTreeNode;
    284297        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
    285298        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    286         if (constantTreeNode != null)
     299        if(constantTreeNode != null)
    287300          constantTreeNode.Value = constants[i++];
    288         else if (updateVariableWeights && variableTreeNode != null)
     301        else if(updateVariableWeights && variableTreeNode != null)
    289302          variableTreeNode.Weight = constants[i++];
    290         else if (updateVariableWeights && binFactorVarTreeNode != null)
     303        else if(updateVariableWeights && binFactorVarTreeNode != null)
    291304          binFactorVarTreeNode.Weight = constants[i++];
    292         else if (factorVarTreeNode != null) {
    293           for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     305        else if(factorVarTreeNode != null) {
     306          for(int j = 0; j < factorVarTreeNode.Weights.Length; j++)
    294307            factorVarTreeNode.Weights[j] = constants[i++];
    295308        }
     
    311324    }
    312325
    313     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
    314       List<string> variableNames, List<int> lags, List<string> categoricalVariableValues, bool updateVariableWeights, out AutoDiff.Term term) {
    315       if (node.Symbol is Constant) {
     326    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node,
     327      List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
     328      bool updateVariableWeights, out AutoDiff.Term term) {
     329      if(node.Symbol is Constant) {
    316330        var var = new AutoDiff.Variable();
    317331        variables.Add(var);
     
    319333        return true;
    320334      }
    321       if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
     335      if(node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
    322336        var varNode = node as VariableTreeNodeBase;
    323337        var factorVarNode = node as BinaryFactorVariableTreeNode;
    324338        // factor variable values are only 0 or 1 and set in x accordingly
    325339        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
    326         var par = FindOrCreateParameter(varNode.VariableName, varValue, parameters, variableNames, categoricalVariableValues);
    327         lags.Add(0);
    328 
    329         if (updateVariableWeights) {
     340        var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue);
     341
     342        if(updateVariableWeights) {
    330343          var w = new AutoDiff.Variable();
    331344          variables.Add(w);
     
    336349        return true;
    337350      }
    338       if (node.Symbol is FactorVariable) {
     351      if(node.Symbol is FactorVariable) {
    339352        var factorVarNode = node as FactorVariableTreeNode;
    340353        var products = new List<Term>();
    341         foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
    342           var par = FindOrCreateParameter(factorVarNode.VariableName, variableValue, parameters, variableNames, categoricalVariableValues);
    343           lags.Add(0);
     354        foreach(var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
     355          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    344356
    345357          var wVar = new AutoDiff.Variable();
     
    351363        return true;
    352364      }
    353       if (node.Symbol is LaggedVariable) {
     365      if(node.Symbol is LaggedVariable) {
    354366        var varNode = node as LaggedVariableTreeNode;
    355         var par = new AutoDiff.Variable();
    356         parameters.Add(par);
    357         variableNames.Add(varNode.VariableName);
    358         lags.Add(varNode.Lag);
    359 
    360         if (updateVariableWeights) {
     367        var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
     368
     369        if(updateVariableWeights) {
    361370          var w = new AutoDiff.Variable();
    362371          variables.Add(w);
     
    367376        return true;
    368377      }
    369       if (node.Symbol is Addition) {
     378      if(node.Symbol is Addition) {
    370379        List<AutoDiff.Term> terms = new List<Term>();
    371         foreach (var subTree in node.Subtrees) {
     380        foreach(var subTree in node.Subtrees) {
    372381          AutoDiff.Term t;
    373           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     382          if(!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    374383            term = null;
    375384            return false;
     
    380389        return true;
    381390      }
    382       if (node.Symbol is Subtraction) {
     391      if(node.Symbol is Subtraction) {
    383392        List<AutoDiff.Term> terms = new List<Term>();
    384         for (int i = 0; i < node.SubtreeCount; i++) {
     393        for(int i = 0; i < node.SubtreeCount; i++) {
    385394          AutoDiff.Term t;
    386           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     395          if(!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, updateVariableWeights, out t)) {
    387396            term = null;
    388397            return false;
    389398          }
    390           if (i > 0) t = -t;
     399          if(i > 0) t = -t;
    391400          terms.Add(t);
    392401        }
    393         if (terms.Count == 1) term = -terms[0];
     402        if(terms.Count == 1) term = -terms[0];
    394403        else term = AutoDiff.TermBuilder.Sum(terms);
    395404        return true;
    396405      }
    397       if (node.Symbol is Multiplication) {
     406      if(node.Symbol is Multiplication) {
    398407        List<AutoDiff.Term> terms = new List<Term>();
    399         foreach (var subTree in node.Subtrees) {
     408        foreach(var subTree in node.Subtrees) {
    400409          AutoDiff.Term t;
    401           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     410          if(!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    402411            term = null;
    403412            return false;
     
    405414          terms.Add(t);
    406415        }
    407         if (terms.Count == 1) term = terms[0];
     416        if(terms.Count == 1) term = terms[0];
    408417        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
    409418        return true;
    410419
    411420      }
    412       if (node.Symbol is Division) {
     421      if(node.Symbol is Division) {
    413422        List<AutoDiff.Term> terms = new List<Term>();
    414         foreach (var subTree in node.Subtrees) {
     423        foreach(var subTree in node.Subtrees) {
    415424          AutoDiff.Term t;
    416           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     425          if(!TryTransformToAutoDiff(subTree, variables, parameters, updateVariableWeights, out t)) {
    417426            term = null;
    418427            return false;
     
    420429          terms.Add(t);
    421430        }
    422         if (terms.Count == 1) term = 1.0 / terms[0];
     431        if(terms.Count == 1) term = 1.0 / terms[0];
    423432        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    424433        return true;
    425434      }
    426       if (node.Symbol is Logarithm) {
    427         AutoDiff.Term t;
    428         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     435      if(node.Symbol is Logarithm) {
     436        AutoDiff.Term t;
     437        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    429438          term = null;
    430439          return false;
     
    434443        }
    435444      }
    436       if (node.Symbol is Exponential) {
    437         AutoDiff.Term t;
    438         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     445      if(node.Symbol is Exponential) {
     446        AutoDiff.Term t;
     447        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    439448          term = null;
    440449          return false;
     
    444453        }
    445454      }
    446       if (node.Symbol is Square) {
    447         AutoDiff.Term t;
    448         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     455      if(node.Symbol is Square) {
     456        AutoDiff.Term t;
     457        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    449458          term = null;
    450459          return false;
     
    454463        }
    455464      }
    456       if (node.Symbol is SquareRoot) {
    457         AutoDiff.Term t;
    458         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     465      if(node.Symbol is SquareRoot) {
     466        AutoDiff.Term t;
     467        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    459468          term = null;
    460469          return false;
     
    464473        }
    465474      }
    466       if (node.Symbol is Sine) {
    467         AutoDiff.Term t;
    468         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     475      if(node.Symbol is Sine) {
     476        AutoDiff.Term t;
     477        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    469478          term = null;
    470479          return false;
     
    474483        }
    475484      }
    476       if (node.Symbol is Cosine) {
    477         AutoDiff.Term t;
    478         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     485      if(node.Symbol is Cosine) {
     486        AutoDiff.Term t;
     487        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    479488          term = null;
    480489          return false;
     
    484493        }
    485494      }
    486       if (node.Symbol is Tangent) {
    487         AutoDiff.Term t;
    488         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     495      if(node.Symbol is Tangent) {
     496        AutoDiff.Term t;
     497        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    489498          term = null;
    490499          return false;
     
    494503        }
    495504      }
    496       if (node.Symbol is Erf) {
    497         AutoDiff.Term t;
    498         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     505      if(node.Symbol is Erf) {
     506        AutoDiff.Term t;
     507        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    499508          term = null;
    500509          return false;
     
    504513        }
    505514      }
    506       if (node.Symbol is Norm) {
    507         AutoDiff.Term t;
    508         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
     515      if(node.Symbol is Norm) {
     516        AutoDiff.Term t;
     517        if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out t)) {
    509518          term = null;
    510519          return false;
     
    514523        }
    515524      }
    516       if (node.Symbol is StartSymbol) {
     525      if(node.Symbol is StartSymbol) {
    517526        var alpha = new AutoDiff.Variable();
    518527        var beta = new AutoDiff.Variable();
     
    520529        variables.Add(alpha);
    521530        AutoDiff.Term branchTerm;
    522         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out branchTerm)) {
     531        if(TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) {
    523532          term = branchTerm * alpha + beta;
    524533          return true;
     
    534543    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    535544    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    536     private static Term FindOrCreateParameter(string varName, string varValue,
    537       List<AutoDiff.Variable> parameters, List<string> variableNames, List<string> variableValues) {
    538       Contract.Assert(variableNames.Count == variableValues.Count);
    539       int idx = -1;
    540       for (int i = 0; i < variableNames.Count; i++) {
    541         if (variableNames[i] == varName && variableValues[i] == varValue) {
    542           idx = i;
    543           break;
    544         }
    545       }
     545    private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
     546      string varName, string varValue = "", int lag = 0) {
     547      var data = new DataForVariable(varName, varValue, lag);
    546548
    547549      AutoDiff.Variable par = null;
    548       if (idx == -1) {
     550      if(!parameters.TryGetValue(data, out par)) {
    549551        // not found -> create new parameter and entries in names and values lists
    550552        par = new AutoDiff.Variable();
    551         parameters.Add(par);
    552         variableNames.Add(varName);
    553         variableValues.Add(varValue);
    554       } else {
    555         par = parameters[idx];
     553        parameters.Add(data, par);
    556554      }
    557555      return par;
     
    585583      return !containsUnknownSymbol;
    586584    }
     585
     586
     587    #region helper class
     588    private class DataForVariable {
     589      public readonly string variableName;
     590      public readonly string variableValue; // for factor vars
     591      public readonly int lag;
     592
     593      public DataForVariable(string varName, string varValue, int lag) {
     594        this.variableName = varName;
     595        this.variableValue = varValue;
     596        this.lag = lag;
     597      }
     598
     599      public override bool Equals(object obj) {
     600        var other = obj as DataForVariable;
     601        if(other == null) return false;
     602        return other.variableName.Equals(this.variableName) &&
     603               other.variableValue.Equals(this.variableValue) &&
     604               other.lag == this.lag;
     605      }
     606
     607      public override int GetHashCode() {
     608        return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag;
     609      }
     610    }
     611    #endregion
    587612  }
    588613}
Note: See TracChangeset for help on using the changeset viewer.