Free cookie consent management tool by TermsFeed Policy Generator

source: branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 14756

Last change on this file since 14756 was 14756, checked in by gkronber, 7 years ago

#2650 improved code for handling variables in the constant optimizer by using a dictionary

File size: 28.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
34  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
35  [StorableClass]
36  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
// Keys under which this evaluator's parameters are registered in the Parameters collection.
private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
private const string UpdateVariableWeightsParameterName = "Update Variable Weights";

/// <summary>Parameter holding the number of iterations used for constant optimization.</summary>
public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
  get {
    var parameter = Parameters[ConstantOptimizationIterationsParameterName];
    return (IFixedValueParameter<IntValue>)parameter;
  }
}

/// <summary>Parameter holding the relative improvement required to continue constant optimization.</summary>
public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
  get {
    var parameter = Parameters[ConstantOptimizationImprovementParameterName];
    return (IFixedValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>Parameter holding the probability with which constants are optimized.</summary>
public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
  get {
    var parameter = Parameters[ConstantOptimizationProbabilityParameterName];
    return (IFixedValueParameter<PercentValue>)parameter;
  }
}

/// <summary>Parameter holding the percentage of rows used for constant optimization.</summary>
public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
  get {
    var parameter = Parameters[ConstantOptimizationRowsPercentageParameterName];
    return (IFixedValueParameter<PercentValue>)parameter;
  }
}

/// <summary>Parameter deciding whether optimized constants are written back into the tree.</summary>
public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
  get {
    var parameter = Parameters[UpdateConstantsInTreeParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}

/// <summary>Parameter deciding whether variable weights are optimized as well.</summary>
public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
  get {
    var parameter = Parameters[UpdateVariableWeightsParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}
62
63
/// <summary>Value object of <see cref="ConstantOptimizationIterationsParameter"/>.</summary>
public IntValue ConstantOptimizationIterations {
  get {
    IntValue iterations = ConstantOptimizationIterationsParameter.Value;
    return iterations;
  }
}

/// <summary>Value object of <see cref="ConstantOptimizationImprovementParameter"/>.</summary>
public DoubleValue ConstantOptimizationImprovement {
  get {
    DoubleValue improvement = ConstantOptimizationImprovementParameter.Value;
    return improvement;
  }
}

/// <summary>Value object of <see cref="ConstantOptimizationProbabilityParameter"/>.</summary>
public PercentValue ConstantOptimizationProbability {
  get {
    PercentValue probability = ConstantOptimizationProbabilityParameter.Value;
    return probability;
  }
}

/// <summary>Value object of <see cref="ConstantOptimizationRowsPercentageParameter"/>.</summary>
public PercentValue ConstantOptimizationRowsPercentage {
  get {
    PercentValue percentage = ConstantOptimizationRowsPercentageParameter.Value;
    return percentage;
  }
}

/// <summary>Gets or sets the flag stored in <see cref="UpdateConstantsInTreeParameter"/>.</summary>
public bool UpdateConstantsInTree {
  get {
    BoolValue flag = UpdateConstantsInTreeParameter.Value;
    return flag.Value;
  }
  set {
    BoolValue flag = UpdateConstantsInTreeParameter.Value;
    flag.Value = value;
  }
}

/// <summary>Gets or sets the flag stored in <see cref="UpdateVariableWeightsParameter"/>.</summary>
public bool UpdateVariableWeights {
  get {
    BoolValue flag = UpdateVariableWeightsParameter.Value;
    return flag.Value;
  }
  set {
    BoolValue flag = UpdateVariableWeightsParameter.Value;
    flag.Value = value;
  }
}

/// <summary>Always <c>true</c>: the quality produced by this evaluator is to be maximized.</summary>
public override bool Maximization {
  get {
    const bool higherIsBetter = true;
    return higherIsBetter;
  }
}
89
/// <summary>Constructor used by the persistence layer during deserialization.</summary>
[StorableConstructor]
protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing)
  : base(deserializing) {
}

/// <summary>Copy constructor used by <see cref="Clone"/>.</summary>
protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
  : base(original, cloner) {
}

/// <summary>
/// Creates the evaluator and registers its six parameters with their default values
/// (10 iterations, improvement 0, probability 1, rows percentage 1, both update flags true).
/// </summary>
public SymbolicRegressionConstantOptimizationEvaluator()
  : base() {
  var iterations = new FixedValueParameter<IntValue>(
    ConstantOptimizationIterationsParameterName,
    "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).",
    new IntValue(10), true);
  Parameters.Add(iterations);

  var improvement = new FixedValueParameter<DoubleValue>(
    ConstantOptimizationImprovementParameterName,
    "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).",
    new DoubleValue(0), true);
  improvement.Hidden = true;
  Parameters.Add(improvement);

  var probability = new FixedValueParameter<PercentValue>(
    ConstantOptimizationProbabilityParameterName,
    "Determines the probability that the constants are optimized",
    new PercentValue(1), true);
  Parameters.Add(probability);

  var rowsPercentage = new FixedValueParameter<PercentValue>(
    ConstantOptimizationRowsPercentageParameterName,
    "Determines the percentage of the rows which should be used for constant optimization",
    new PercentValue(1), true);
  Parameters.Add(rowsPercentage);

  var updateConstants = new FixedValueParameter<BoolValue>(
    UpdateConstantsInTreeParameterName,
    "Determines if the constants in the tree should be overwritten by the optimized constants.",
    new BoolValue(true));
  updateConstants.Hidden = true;
  Parameters.Add(updateConstants);

  var updateWeights = new FixedValueParameter<BoolValue>(
    UpdateVariableWeightsParameterName,
    "Determines if the variable weights in the tree should be  optimized.",
    new BoolValue(true));
  updateWeights.Hidden = true;
  Parameters.Add(updateWeights);
}
104
/// <summary>Creates a deep copy of this evaluator via the copy constructor.</summary>
/// <param name="cloner">Cloner handed on to the copy constructor.</param>
/// <returns>The newly created copy.</returns>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
  return copy;
}
108
/// <summary>
/// Adds the two boolean update parameters (default <c>true</c>) when a deserialized
/// instance does not contain them yet.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  bool hasUpdateConstantsFlag = Parameters.ContainsKey(UpdateConstantsInTreeParameterName);
  if(!hasUpdateConstantsFlag) {
    var updateConstants = new FixedValueParameter<BoolValue>(
      UpdateConstantsInTreeParameterName,
      "Determines if the constants in the tree should be overwritten by the optimized constants.",
      new BoolValue(true));
    Parameters.Add(updateConstants);
  }

  bool hasUpdateWeightsFlag = Parameters.ContainsKey(UpdateVariableWeightsParameterName);
  if(!hasUpdateWeightsFlag) {
    var updateWeights = new FixedValueParameter<BoolValue>(
      UpdateVariableWeightsParameterName,
      "Determines if the variable weights in the tree should be  optimized.",
      new BoolValue(true));
    Parameters.Add(updateWeights);
  }
}
116
/// <summary>
/// Evaluates the current tree. With probability <see cref="ConstantOptimizationProbability"/>
/// the constants are optimized first; the quality is (re-)calculated with the Pearson R²
/// evaluator on the regular evaluation rows whenever no optimization took place or the
/// optimization used a different row percentage than the regular evaluation.
/// </summary>
/// <returns>The operation returned by the base implementation.</returns>
public override IOperation InstrumentedApply() {
  var tree = SymbolicExpressionTreeParameter.ActualValue;
  double quality = 0.0;

  // Without optimization the tree is always scored on the regular evaluation rows.
  bool scoreOnEvaluationRows = true;

  bool optimize = RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value;
  if(optimize) {
    IEnumerable<int> optimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
    quality = OptimizeConstants(
      SymbolicDataAnalysisTreeInterpreterParameter.ActualValue,
      tree,
      ProblemDataParameter.ActualValue,
      optimizationRows,
      ApplyLinearScalingParameter.ActualValue.Value,
      ConstantOptimizationIterations.Value,
      updateVariableWeights: UpdateVariableWeights,
      lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower,
      upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper,
      updateConstantsInTree: UpdateConstantsInTree);

    // The optimizer's quality can only be reused if it was measured with the same row percentage.
    scoreOnEvaluationRows = ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value;
  }

  if(scoreOnEvaluationRows) {
    var evaluationRows = GenerateRowsToEvaluate();
    quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(
      SymbolicDataAnalysisTreeInterpreterParameter.ActualValue,
      tree,
      EstimationLimitsParameter.ActualValue.Lower,
      EstimationLimitsParameter.ActualValue.Upper,
      ProblemDataParameter.ActualValue,
      evaluationRows,
      ApplyLinearScalingParameter.ActualValue.Value);
  }

  QualityParameter.ActualValue = new DoubleValue(quality);
  return base.InstrumentedApply();
}
137
/// <summary>
/// Calculates the Pearson R² of <paramref name="tree"/> on the given rows without optimizing constants.
/// </summary>
/// <remarks>
/// The Pearson R² evaluator is used on purpose instead of the constant optimization,
/// because Evaluate() is used to get the quality of evolved models on different
/// partitions of the dataset (e.g., best validation model).
/// </remarks>
/// <param name="context">Execution context used to resolve the interpreter, estimation limits and linear scaling parameters.</param>
/// <param name="tree">The symbolic expression tree to evaluate.</param>
/// <param name="problemData">The regression problem data.</param>
/// <param name="rows">The rows on which the tree is evaluated.</param>
/// <returns>The R² value returned by the Pearson R² evaluator.</returns>
public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
  SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
  EstimationLimitsParameter.ExecutionContext = context;
  ApplyLinearScalingParameter.ExecutionContext = context;

  try {
    return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(
      SymbolicDataAnalysisTreeInterpreterParameter.ActualValue,
      tree,
      EstimationLimitsParameter.ActualValue.Lower,
      EstimationLimitsParameter.ActualValue.Upper,
      problemData,
      rows,
      ApplyLinearScalingParameter.ActualValue.Value);
  } finally {
    // Always detach the context again, also when the evaluation throws,
    // so the parameters do not keep referring to a context that is no longer valid.
    SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
    EstimationLimitsParameter.ExecutionContext = null;
    ApplyLinearScalingParameter.ExecutionContext = null;
  }
}
154
155    #region derivations of functions
// Function factories for AutoDiff: each pairs the evaluation of a unary function
// with its first derivative so that the function can be used inside differentiable terms.

// arctangent: d/dx atan(x) = 1 / (1 + x²)
// (declared static like the other factories; it does not depend on instance state)
private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
  eval: Math.Atan,
  diff: x => 1 / (1 + x * x));
// sine: d/dx sin(x) = cos(x)
private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
  eval: Math.Sin,
  diff: Math.Cos);
// cosine: d/dx cos(x) = -sin(x)
private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
  eval: Math.Cos,
  diff: x => -Math.Sin(x));
// tangent: d/dx tan(x) = 1 + tan²(x)
private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
  eval: Math.Tan,
  diff: x => 1 + Math.Tan(x) * Math.Tan(x));
// error function: d/dx erf(x) = 2 * exp(-x²) / sqrt(pi)
private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
  eval: alglib.errorfunction,
  diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
// standard normal distribution function: alglib.normaldistribution is the cumulative
// distribution function, so its derivative is the normal density exp(-x²/2) / sqrt(2*pi).
// (The previous expression, -x * exp(-x²/2) / sqrt(2*pi), was the derivative of the density
// and therefore handed a wrong gradient to the optimizer.)
private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
  eval: alglib.normaldistribution,
  diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2.0 * Math.PI));
175    #endregion
176
177
178
/// <summary>
/// Optimizes the numeric constants (and optionally the variable weights) of <paramref name="tree"/>
/// with ALGLIB's least-squares fitter, using gradients provided by AutoDiff.
/// </summary>
/// <param name="interpreter">Interpreter used to calculate the R² before and after the optimization.</param>
/// <param name="tree">Tree whose constants are optimized; it is modified in place.</param>
/// <param name="problemData">Problem data providing the dataset and the target variable.</param>
/// <param name="rows">Rows used for fitting and for the quality calculation. The sequence is enumerated exactly once.</param>
/// <param name="applyLinearScaling">Passed on to the R² calculation.</param>
/// <param name="maxIterations">Maximum number of iterations passed to the ALGLIB fitter.</param>
/// <param name="updateVariableWeights">If true, weights of variable nodes are optimized in addition to constants.</param>
/// <param name="lowerEstimationLimit">Lower estimation limit passed on to the R² calculation.</param>
/// <param name="upperEstimationLimit">Upper estimation limit passed on to the R² calculation.</param>
/// <param name="updateConstantsInTree">If false, the original constants are restored in the tree after the quality has been calculated.</param>
/// <returns>
/// The R² after optimization. The original R² is returned (and the original constants are restored)
/// when the fitter fails with an arithmetic or ALGLIB exception, or when the optimized quality is NaN
/// or more than 0.001 worse than the original one. 0.0 is returned for trees that reference no variables.
/// </returns>
/// <exception cref="NotSupportedException">The tree contains symbols that cannot be converted to an AutoDiff term.</exception>
/// <exception cref="InvalidProgramException">A referenced dataset variable is neither of type double nor string.</exception>
public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
  ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
  int maxIterations, bool updateVariableWeights = true,
  double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
  bool updateConstantsInTree = true) {

  // numeric constants in the tree become variables for constant opt
  // variables in the tree become parameters (fixed values) for constant opt
  // for each parameter (variable in the original tree) we store the
  // variable name, variable value (for factor vars) and lag as a DataForVariable object.
  // A dictionary is used to find parameters
  var variables = new List<AutoDiff.Variable>();
  var parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();

  AutoDiff.Term func;
  if(!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, updateVariableWeights, out func))
    throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
  if(parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0

  // Materialize the rows once: the sequence is needed for both quality calculations,
  // the input matrix and the target vector, and all of them must see the same rows.
  int[] rowIndices = rows.ToArray();

  var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
  AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray());

  // extract initial constants in the same (prefix) order in which the AutoDiff variables were created;
  // the first two entries are the offset (0.0) and the scale (1.0) introduced for the start symbol
  double[] c = new double[variables.Count];
  {
    c[0] = 0.0;
    c[1] = 1.0;
    int i = 2;
    foreach(var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
      ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
      VariableTreeNode variableTreeNode = node as VariableTreeNode;
      BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
      FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
      if(constantTreeNode != null)
        c[i++] = constantTreeNode.Value;
      else if(updateVariableWeights && variableTreeNode != null)
        c[i++] = variableTreeNode.Weight;
      else if(updateVariableWeights && binFactorVarTreeNode != null)
        c[i++] = binFactorVarTreeNode.Weight;
      else if(factorVarTreeNode != null) {
        // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
        foreach(var w in factorVarTreeNode.Weights) c[i++] = w;
      }
    }
  }
  double[] originalConstants = (double[])c.Clone();
  double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

  alglib.lsfitstate state;
  alglib.lsfitreport rep;
  int retVal;

  // build the input matrix: one row per data row, one column per parameter
  IDataset ds = problemData.Dataset;
  int n = rowIndices.Length;
  int m = parameterEntries.Length;
  int k = c.Length;
  double[,] x = new double[n, m];
  for(int row = 0; row < n; row++) {
    int r = rowIndices[row];
    for(int col = 0; col < m; col++) {
      var info = parameterEntries[col].Key;
      if(ds.VariableHasType<double>(info.variableName)) {
        x[row, col] = ds.GetDoubleValue(info.variableName, r + info.lag);
      } else if(ds.VariableHasType<string>(info.variableName)) {
        // string variables are encoded as a binary indicator for the stored variable value
        x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
      } else throw new InvalidProgramException("found a variable of unknown type");
    }
  }
  double[] y = ds.GetDoubleValues(problemData.TargetVariable, rowIndices).ToArray();

  alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
  alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

  try {
    alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
    alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
    alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
    alglib.lsfitresults(state, out retVal, out c, out rep);
  } catch(ArithmeticException) {
    // the tree has not been modified yet
    return originalQuality;
  } catch(alglib.alglibexception) {
    // the tree has not been modified yet
    return originalQuality;
  }

  //retVal == -7  => constant optimization failed due to wrong gradient
  if(retVal != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
  var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

  // restore the original constants (once) if the caller does not want the tree changed
  // or if the optimization made the model worse
  bool deteriorated = originalQuality - quality > 0.001 || double.IsNaN(quality);
  if(!updateConstantsInTree || deteriorated)
    UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);

  return deteriorated ? originalQuality : quality;
}
291
/// <summary>
/// Writes <paramref name="constants"/> into the terminal nodes of <paramref name="tree"/>,
/// visiting the nodes in prefix order and consuming one value per constant, per variable weight
/// (only if <paramref name="updateVariableWeights"/> is set) and per factor weight.
/// </summary>
/// <param name="tree">Tree whose terminal nodes are updated in place.</param>
/// <param name="constants">Values in the order in which the terminal nodes are visited.</param>
/// <param name="updateVariableWeights">Whether weights of variable and binary factor nodes are overwritten.</param>
private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
  int next = 0;
  var terminals = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>();
  foreach(var terminal in terminals) {
    var constantNode = terminal as ConstantTreeNode;
    if(constantNode != null) {
      constantNode.Value = constants[next++];
      continue;
    }

    if(updateVariableWeights) {
      var variableNode = terminal as VariableTreeNode;
      if(variableNode != null) {
        variableNode.Weight = constants[next++];
        continue;
      }

      var binaryFactorNode = terminal as BinaryFactorVariableTreeNode;
      if(binaryFactorNode != null) {
        binaryFactorNode.Weight = constants[next++];
        continue;
      }
    }

    var factorNode = terminal as FactorVariableTreeNode;
    if(factorNode != null) {
      // one value per weight of the factor node
      for(int w = 0; w < factorNode.Weights.Length; w++) {
        factorNode.Weights[w] = constants[next++];
      }
    }
  }
}
311
/// <summary>
/// Wraps the compiled AutoDiff term in the function delegate expected by ALGLIB:
/// the delegate evaluates the term for the given coefficients and point.
/// </summary>
/// <param name="compiledFunc">The compiled parametric term.</param>
/// <returns>A delegate writing the function value into its <c>ref</c> argument.</returns>
private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
  alglib.ndimensional_pfunc evaluate = (double[] coefficients, double[] point, ref double value, object unused) => {
    value = compiledFunc.Evaluate(coefficients, point);
  };
  return evaluate;
}
317
/// <summary>
/// Wraps the compiled AutoDiff term in the gradient delegate expected by ALGLIB:
/// the delegate differentiates the term for the given coefficients and point.
/// </summary>
/// <param name="compiledFunc">The compiled parametric term.</param>
/// <returns>A delegate writing the function value and the gradient into its output arguments.</returns>
private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
  alglib.ndimensional_pgrad differentiate = (double[] coefficients, double[] point, ref double value, double[] gradient, object unused) => {
    var result = compiledFunc.Differentiate(coefficients, point);
    // Item2 holds the function value, Item1 the gradient
    value = result.Item2;
    Array.Copy(result.Item1, gradient, gradient.Length);
  };
  return differentiate;
}
325
/// <summary>
/// Recursively converts the subtree rooted at <paramref name="node"/> into an AutoDiff term.
/// Every constant (and every variable weight if <paramref name="updateVariableWeights"/> is set,
/// and every factor weight) appends a new AutoDiff variable to <paramref name="variables"/>;
/// every referenced dataset variable is looked up or created in <paramref name="parameters"/>.
/// </summary>
/// <param name="node">Root of the subtree to convert.</param>
/// <param name="variables">Receives the AutoDiff variables in the order in which they are created.</param>
/// <param name="parameters">Receives the AutoDiff parameters for the referenced dataset variables.</param>
/// <param name="updateVariableWeights">If false, variable weights are embedded as fixed factors.</param>
/// <param name="term">The resulting term, or null if the conversion failed.</param>
/// <returns>True if the whole subtree could be converted; false if an unsupported symbol was found.</returns>
private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node,
  List<AutoDiff.Variable> variables, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
  bool updateVariableWeights, out AutoDiff.Term term) {
  term = null;
  var symbol = node.Symbol;

  // ---- terminals ----
  if(symbol is Constant) {
    var constantVariable = new AutoDiff.Variable();
    variables.Add(constantVariable);
    term = constantVariable;
    return true;
  }

  if(symbol is Variable || symbol is BinaryFactorVariable) {
    var variableNode = node as VariableTreeNodeBase;
    var binaryFactorNode = node as BinaryFactorVariableTreeNode;
    // factor variable values are only 0 or 1 and set in x accordingly
    var indicatorValue = string.Empty;
    if(binaryFactorNode != null) indicatorValue = binaryFactorNode.VariableValue;
    var parameter = FindOrCreateParameter(parameters, variableNode.VariableName, indicatorValue);

    if(!updateVariableWeights) {
      term = variableNode.Weight * parameter;
    } else {
      var weight = new AutoDiff.Variable();
      variables.Add(weight);
      term = AutoDiff.TermBuilder.Product(weight, parameter);
    }
    return true;
  }

  if(symbol is FactorVariable) {
    var factorNode = node as FactorVariableTreeNode;
    var summands = new List<Term>();
    // one weight variable and one indicator parameter per value of the factor variable
    foreach(var category in factorNode.Symbol.GetVariableValues(factorNode.VariableName)) {
      var indicator = FindOrCreateParameter(parameters, factorNode.VariableName, category);

      var categoryWeight = new AutoDiff.Variable();
      variables.Add(categoryWeight);

      summands.Add(AutoDiff.TermBuilder.Product(categoryWeight, indicator));
    }
    term = AutoDiff.TermBuilder.Sum(summands);
    return true;
  }

  if(symbol is LaggedVariable) {
    var laggedNode = node as LaggedVariableTreeNode;
    var parameter = FindOrCreateParameter(parameters, laggedNode.VariableName, string.Empty, laggedNode.Lag);

    if(!updateVariableWeights) {
      term = laggedNode.Weight * parameter;
    } else {
      var weight = new AutoDiff.Variable();
      variables.Add(weight);
      term = AutoDiff.TermBuilder.Product(weight, parameter);
    }
    return true;
  }

  // ---- arithmetic operators with an arbitrary number of operands ----
  if(symbol is Addition || symbol is Subtraction || symbol is Multiplication || symbol is Division) {
    var operands = new List<Term>();
    foreach(var child in node.Subtrees) {
      AutoDiff.Term operand;
      if(!TryTransformToAutoDiff(child, variables, parameters, updateVariableWeights, out operand)) {
        term = null;
        return false;
      }
      operands.Add(operand);
    }

    if(symbol is Addition) {
      term = AutoDiff.TermBuilder.Sum(operands);
    } else if(symbol is Subtraction) {
      // a - b - c is expressed as a + (-b) + (-c); a single operand is negated
      for(int i = 1; i < operands.Count; i++) operands[i] = -operands[i];
      if(operands.Count == 1) term = -operands[0];
      else term = AutoDiff.TermBuilder.Sum(operands);
    } else if(symbol is Multiplication) {
      if(operands.Count == 1) term = operands[0];
      else term = operands.Aggregate((a, b) => new AutoDiff.Product(a, b));
    } else {
      // division: a / b / c is expressed as a * (1/b) * (1/c); a single operand is inverted
      if(operands.Count == 1) term = 1.0 / operands[0];
      else term = operands.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
    }
    return true;
  }

  // ---- unary functions: convert the single argument first, then wrap it ----
  if(symbol is Logarithm || symbol is Exponential || symbol is Square || symbol is SquareRoot ||
     symbol is Sine || symbol is Cosine || symbol is Tangent || symbol is Erf || symbol is Norm) {
    AutoDiff.Term argument;
    if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out argument)) {
      term = null;
      return false;
    }

    if(symbol is Logarithm) term = AutoDiff.TermBuilder.Log(argument);
    else if(symbol is Exponential) term = AutoDiff.TermBuilder.Exp(argument);
    else if(symbol is Square) term = AutoDiff.TermBuilder.Power(argument, 2.0);
    else if(symbol is SquareRoot) term = AutoDiff.TermBuilder.Power(argument, 0.5);
    else if(symbol is Sine) term = sin(argument);
    else if(symbol is Cosine) term = cos(argument);
    else if(symbol is Tangent) term = tan(argument);
    else if(symbol is Erf) term = erf(argument);
    else term = norm(argument);
    return true;
  }

  // ---- start symbol: adds linear scaling variables (offset first, then scale) ----
  if(symbol is StartSymbol) {
    var alpha = new AutoDiff.Variable();
    var beta = new AutoDiff.Variable();
    variables.Add(beta);
    variables.Add(alpha);

    AutoDiff.Term branchTerm;
    if(!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, updateVariableWeights, out branchTerm)) {
      term = null;
      return false;
    }
    term = branchTerm * alpha + beta;
    return true;
  }

  // unsupported symbol
  term = null;
  return false;
}
542
/// <summary>
/// Returns the AutoDiff parameter registered for the given combination of variable name,
/// variable value and lag, creating and registering a new one if the combination is not known yet.
/// </summary>
/// <remarks>
/// For each factor variable value a parameter is needed that represents a binary indicator for
/// that variable and value combination. Each indicator is only necessary once, so a parameter
/// is only created if the combination is not yet available.
/// </remarks>
/// <param name="parameters">Dictionary of the parameters created so far; extended if necessary.</param>
/// <param name="varName">Name of the dataset variable.</param>
/// <param name="varValue">Value of the factor variable (empty for other variables).</param>
/// <param name="lag">Lag of the variable.</param>
/// <returns>The existing or newly created parameter.</returns>
private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
  string varName, string varValue = "", int lag = 0) {
  var key = new DataForVariable(varName, varValue, lag);

  AutoDiff.Variable known;
  if(parameters.TryGetValue(key, out known)) return known;

  // not found -> create and register a new parameter
  var created = new AutoDiff.Variable();
  parameters.Add(key, created);
  return created;
}
557
/// <summary>
/// Checks whether every node below the root's first subtree uses one of the symbols
/// listed in this method.
/// </summary>
/// <param name="tree">The tree to inspect.</param>
/// <returns>True if no node with any other symbol is found; otherwise false.</returns>
public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
  foreach(var n in tree.Root.GetSubtree(0).IterateNodesPrefix()) {
    bool isKnownSymbol =
      n.Symbol is Variable ||
      n.Symbol is BinaryFactorVariable ||
      n.Symbol is FactorVariable ||
      n.Symbol is LaggedVariable ||
      n.Symbol is Constant ||
      n.Symbol is Addition ||
      n.Symbol is Subtraction ||
      n.Symbol is Multiplication ||
      n.Symbol is Division ||
      n.Symbol is Logarithm ||
      n.Symbol is Exponential ||
      n.Symbol is SquareRoot ||
      n.Symbol is Square ||
      n.Symbol is Sine ||
      n.Symbol is Cosine ||
      n.Symbol is Tangent ||
      n.Symbol is Erf ||
      n.Symbol is Norm ||
      n.Symbol is StartSymbol;

    // the first unknown symbol settles the answer
    if(!isKnownSymbol) return false;
  }
  return true;
}
585
586
587    #region helper class
/// <summary>
/// Dictionary key describing a parameter of the constant optimization: the variable name,
/// the variable value (for factor variables) and the lag. Two keys are equal when all
/// three components are equal.
/// </summary>
private class DataForVariable {
  public readonly string variableName;
  public readonly string variableValue; // for factor vars
  public readonly int lag;

  /// <summary>Creates a key from its three components.</summary>
  public DataForVariable(string varName, string varValue, int lag) {
    variableName = varName;
    variableValue = varValue;
    this.lag = lag;
  }

  /// <summary>Component-wise comparison; any object that is not a DataForVariable is unequal.</summary>
  public override bool Equals(object obj) {
    var that = obj as DataForVariable;
    if(that == null) return false;
    if(!that.variableName.Equals(variableName)) return false;
    if(!that.variableValue.Equals(variableValue)) return false;
    return that.lag == lag;
  }

  /// <summary>Combines the hash codes of the three components with XOR.</summary>
  public override int GetHashCode() {
    int hash = variableName.GetHashCode();
    hash ^= variableValue.GetHashCode();
    hash ^= lag;
    return hash;
  }
}
611    #endregion
612  }
613}
Note: See TracBrowser for help on using the repository browser.