Trac repository browser export (page chrome; not part of the source file below)

source: branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 14402

Last change on this file since 14402 was 14402, checked in by gkronber, 8 years ago

#2650: fixed a bug in constants optimizer in relation to lagged variables

File size: 28.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics.Contracts;
using System.Linq;
using AutoDiff;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Parameters;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Evaluator that first optimizes the numeric constants (and optionally variable weights)
  /// of a symbolic regression tree via nonlinear least-squares fitting (alglib, with gradients
  /// supplied by AutoDiff) and then reports the Pearson R² of the resulting model.
  /// </summary>
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";

    #region parameter properties
    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
    }
    #endregion

    #region properties
    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }
    public bool UpdateVariableWeights {
      get { return UpdateVariableWeightsParameter.Value.Value; }
      set { UpdateVariableWeightsParameter.Value.Value = value; }
    }
    #endregion

    // R² is maximized.
    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    // Backwards compatibility: add parameters that did not exist in older persisted versions.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
    }

    /// <summary>
    /// With probability ConstantOptimizationProbability, optimizes the tree's constants on a
    /// (possibly reduced) row sample and takes the resulting R² as quality; otherwise (or when
    /// the optimization sample differs from the evaluation sample) the quality is re-computed
    /// with the plain Pearson R² evaluator on the regular evaluation rows.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree);

        // If constants were fitted on a different sample than the regular evaluation sample,
        // re-evaluate on the regular rows so qualities remain comparable across individuals.
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;

      // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
      // because Evaluate() is used to get the quality of evolved models on
      // different partitions of the dataset (e.g., best validation model)
      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);

      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
      ApplyLinearScalingParameter.ExecutionContext = null;

      return r2;
    }

    #region derivations of functions
    // create function factory for arctangent
    // (made static for consistency with the other factories below; note that no Arctangent
    // symbol is handled in TryTransformToAutoDiff/CanOptimizeConstants in this file)
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
       eval: Math.Cos,
       diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    #endregion

    /// <summary>
    /// Fits the constants (and optionally variable weights) of <paramref name="tree"/> with
    /// alglib's Levenberg-Marquardt least-squares fitting, using AutoDiff for exact gradients.
    /// Returns the Pearson R² of the tree after optimization; falls back to the original
    /// quality (and restores the original constants) if fitting fails or degrades quality.
    /// </summary>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {

      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();
      List<string> categoricalVariableValues = new List<string>();
      List<int> lags = new List<int>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0

      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());

      List<SymbolicExpressionTreeTerminalNode> terminalNodes = null; // gkronber only used for extraction of initial constants
      if (updateVariableWeights)
        terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
      else
        terminalNodes = new List<SymbolicExpressionTreeTerminalNode>
          (tree.Root.IterateNodesPrefix()
          .OfType<SymbolicExpressionTreeTerminalNode>()
          .Where(node => node is ConstantTreeNode || node is FactorVariableTreeNode));

      // extract initial constants;
      // c[0] and c[1] are the linear-scaling offset (beta) and slope (alpha) added for the
      // StartSymbol in TryTransformToAutoDiff, initialized to the identity scaling 0/1.
      double[] c = new double[variables.Count];
      {
        c[0] = 0.0;
        c[1] = 1.0;
        int i = 2;
        foreach (var node in terminalNodes) {
          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
          VariableTreeNode variableTreeNode = node as VariableTreeNode;
          BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
          FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
          if (constantTreeNode != null)
            c[i++] = constantTreeNode.Value;
          else if (updateVariableWeights && variableTreeNode != null)
            c[i++] = variableTreeNode.Weight;
          else if (updateVariableWeights && binFactorVarTreeNode != null)
            c[i++] = binFactorVarTreeNode.Weight;
          else if (factorVarTreeNode != null) {
            // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
            foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
          }
        }
      }
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;

      // build the input matrix: one column per (variable, lag) resp. (variable, category) combination
      IDataset ds = problemData.Dataset;
      double[,] x = new double[rows.Count(), variableNames.Count];
      int row = 0;
      foreach (var r in rows) {
        for (int col = 0; col < variableNames.Count; col++) {
          int lag = lags[col];
          if (ds.VariableHasType<double>(variableNames[col])) {
            x[row, col] = ds.GetDoubleValue(variableNames[col], r + lag);
          } else if (ds.VariableHasType<string>(variableNames[col])) {
            // categorical variables enter as 0/1 indicators (only double variables carry lags here)
            x[row, col] = ds.GetStringValue(variableNames[col], r) == categoricalVariableValues[col] ? 1 : 0;
          } else throw new InvalidProgramException("found a variable of unknown type");
        }
        row++;
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        //alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      } catch (ArithmeticException) {
        return originalQuality;
      } catch (alglib.alglibexception) {
        return originalQuality;
      }

      //info == -7  => constant optimization failed due to wrong gradient
      // Skip(2) drops the linear-scaling slots c[0]/c[1], which are not stored in the tree.
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
      // revert if the fit made things noticeably worse or produced NaN
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
        return originalQuality;
      }
      return quality;
    }

    // Writes the given constants back into the tree's terminal nodes, in the same prefix order
    // used when the initial constants were extracted in OptimizeConstants.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (updateVariableWeights && variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
        else if (updateVariableWeights && binFactorVarTreeNode != null)
          binFactorVarTreeNode.Weight = constants[i++];
        else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        }
      }
    }

    // Adapts the compiled AutoDiff term to alglib's function callback signature.
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    // Adapts the compiled AutoDiff term to alglib's gradient callback signature.
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tupel = compiledFunc.Differentiate(c, x);
        func = tupel.Item2;
        Array.Copy(tupel.Item1, grad, grad.Length);
      };
    }

    /// <summary>
    /// Recursively translates the symbolic expression tree into an AutoDiff term.
    /// Constants and (optionally) variable weights become AutoDiff variables (the values to fit);
    /// dataset variables become AutoDiff parameters. Returns false if the tree contains a symbol
    /// that has no AutoDiff translation.
    /// </summary>
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
      List<string> variableNames, List<int> lags, List<string> categoricalVariableValues, bool updateVariableWeights, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
        var varNode = node as VariableTreeNodeBase;
        var factorVarNode = node as BinaryFactorVariableTreeNode;
        // factor variable values are only 0 or 1 and set in x accordingly
        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
        var par = FindOrCreateParameter(varNode.VariableName, varValue, parameters, variableNames, categoricalVariableValues);
        lags.Add(0);

        if (updateVariableWeights) {
          var w = new AutoDiff.Variable();
          variables.Add(w);
          term = AutoDiff.TermBuilder.Product(w, par);
        } else {
          term = varNode.Weight * par;
        }
        return true;
      }
      if (node.Symbol is FactorVariable) {
        var factorVarNode = node as FactorVariableTreeNode;
        var products = new List<Term>();
        // one fitted weight per category value; exactly one indicator is 1 per row
        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
          var par = FindOrCreateParameter(factorVarNode.VariableName, variableValue, parameters, variableNames, categoricalVariableValues);
          lags.Add(0);

          var wVar = new AutoDiff.Variable();
          variables.Add(wVar);

          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
        }
        term = AutoDiff.TermBuilder.Sum(products);
        return true;
      }
      if (node.Symbol is LaggedVariable) {
        var varNode = node as LaggedVariableTreeNode;
        // lagged variables always get a fresh parameter (no sharing) because the lag is part of the column
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        lags.Add(varNode.Lag);

        if (updateVariableWeights) {
          var w = new AutoDiff.Variable();
          variables.Add(w);
          term = AutoDiff.TermBuilder.Product(w, par);
        } else {
          term = varNode.Weight * par;
        }
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> terms = new List<Term>();
        for (int i = 0; i < node.SubtreeCount; i++) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
            term = null;
            return false;
          }
          if (i > 0) t = -t;
          terms.Add(t);
        }
        // unary subtraction is negation
        if (terms.Count == 1) term = -terms[0];
        else term = AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Multiplication) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        if (terms.Count == 1) term = terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
        return true;

      }
      if (node.Symbol is Division) {
        List<AutoDiff.Term> terms = new List<Term>();
        foreach (var subTree in node.Subtrees) {
          AutoDiff.Term t;
          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
            term = null;
            return false;
          }
          terms.Add(t);
        }
        // unary division is the reciprocal
        if (terms.Count == 1) term = 1.0 / terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
        return true;
      }
      if (node.Symbol is Logarithm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Log(t);
          return true;
        }
      }
      if (node.Symbol is Exponential) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Exp(t);
          return true;
        }
      }
      if (node.Symbol is Square) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Power(t, 2.0);
          return true;
        }
      }
      if (node.Symbol is SquareRoot) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = AutoDiff.TermBuilder.Power(t, 0.5);
          return true;
        }
      }
      if (node.Symbol is Sine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = sin(t);
          return true;
        }
      }
      if (node.Symbol is Cosine) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = cos(t);
          return true;
        }
      }
      if (node.Symbol is Tangent) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = tan(t);
          return true;
        }
      }
      if (node.Symbol is Erf) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = erf(t);
          return true;
        }
      }
      if (node.Symbol is Norm) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out t)) {
          term = null;
          return false;
        } else {
          term = norm(t);
          return true;
        }
      }
      if (node.Symbol is StartSymbol) {
        // linear scaling: the first two AutoDiff variables are beta (offset) and alpha (slope)
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, lags, categoricalVariableValues, updateVariableWeights, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }

    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available
    private static Term FindOrCreateParameter(string varName, string varValue,
      List<AutoDiff.Variable> parameters, List<string> variableNames, List<string> variableValues) {
      Contract.Assert(variableNames.Count == variableValues.Count);
      int idx = -1;
      for (int i = 0; i < variableNames.Count; i++) {
        if (variableNames[i] == varName && variableValues[i] == varValue) {
          idx = i;
          break;
        }
      }

      AutoDiff.Variable par = null;
      if (idx == -1) {
        // not found -> create new parameter and entries in names and values lists
        par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varName);
        variableValues.Add(varValue);
      } else {
        par = parameters[idx];
      }
      return par;
    }

    /// <summary>
    /// Returns true iff the tree consists only of symbols that TryTransformToAutoDiff can handle.
    /// Keep this list in sync with the symbol branches of TryTransformToAutoDiff.
    /// </summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
         !(n.Symbol is Variable) &&
         !(n.Symbol is BinaryFactorVariable) &&
         !(n.Symbol is FactorVariable) &&
         !(n.Symbol is LaggedVariable) &&
         !(n.Symbol is Constant) &&
         !(n.Symbol is Addition) &&
         !(n.Symbol is Subtraction) &&
         !(n.Symbol is Multiplication) &&
         !(n.Symbol is Division) &&
         !(n.Symbol is Logarithm) &&
         !(n.Symbol is Exponential) &&
         !(n.Symbol is SquareRoot) &&
         !(n.Symbol is Square) &&
         !(n.Symbol is Sine) &&
         !(n.Symbol is Cosine) &&
         !(n.Symbol is Tangent) &&
         !(n.Symbol is Erf) &&
         !(n.Symbol is Norm) &&
         !(n.Symbol is StartSymbol)
        select n).
      Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.