Free cookie consent management tool by TermsFeed Policy Generator

source: branches/symbreg-factors-2650/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 14251

Last change on this file since 14251 was 14251, checked in by gkronber, 8 years ago

#2650:

  • extended non-linear regression to work with factors
  • fixed bugs in constants optimizer and tree interpreter
  • improved simplification of factor variables
  • added support for factors to ERC view
  • added support for factors to solution comparison view
  • activated view for all factors
File size: 26.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2016 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using AutoDiff;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression evaluator that, with a configurable probability, fits the numeric
  /// constants (and optionally the variable weights) of a tree using ALGLIB's nonlinear least-squares
  /// fitting (lsfit) on an AutoDiff representation of the tree, and reports the Pearson R² of the tree.
  /// </summary>
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
    }


    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    public bool UpdateVariableWeights {
      get { return UpdateVariableWeightsParameter.Value.Value; }
      set { UpdateVariableWeightsParameter.Value.Value = value; }
    }

    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // Instances persisted before these parameters existed get them added here, with the same
      // defaults and the same Hidden flag the public constructor uses.
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
    }

    /// <summary>
    /// Evaluates the tree of the current scope. With probability <see cref="ConstantOptimizationProbability"/>
    /// the constants are optimized first on a <see cref="ConstantOptimizationRowsPercentage"/> sample of rows.
    /// The resulting Pearson R² is stored in the quality parameter.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree);

        // OptimizeConstants reports the quality on its own row sample. If that sample size differs from the
        // regular evaluation sample size, re-evaluate the (possibly updated) tree on the regular rows.
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          quality = EvaluatePearsonRSquared(solution, GenerateRowsToEvaluate());
        }
      } else {
        quality = EvaluatePearsonRSquared(solution, GenerateRowsToEvaluate());
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    // Pearson R² of the given tree on the given rows, using the interpreter, estimation limits, problem data
    // and linear-scaling setting of the current scope.
    private double EvaluatePearsonRSquared(ISymbolicExpressionTree tree, IEnumerable<int> rows) {
      return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, rows, ApplyLinearScalingParameter.ActualValue.Value);
    }

    /// <summary>
    /// Calculates the Pearson R² of <paramref name="tree"/> on <paramref name="rows"/> without optimizing constants.
    /// </summary>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
        // because Evaluate() is used to get the quality of evolved models on
        // different partitions of the dataset (e.g., best validation model)
        return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
      } finally {
        // Detach the borrowed execution context even if the calculation throws.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    #region derivations of functions
    // create function factory for arctangent (not referenced by TryTransformToAutoDiff at the moment)
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // alglib.normaldistribution is the standard normal CDF Phi(x); its derivative is the standard normal
    // density phi(x) = exp(-x^2 / 2) / sqrt(2 * pi). (The previous expression was -x * phi(x), i.e. the
    // derivative of the density, which produced wrong gradients for Norm nodes.)
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));
    #endregion


    /// <summary>
    /// Fits the constants (and, if <paramref name="updateVariableWeights"/> is set, the variable weights) of
    /// <paramref name="tree"/> to the target variable on <paramref name="rows"/> using ALGLIB lsfit.
    /// </summary>
    /// <returns>
    /// The Pearson R² after optimization. If the fit throws, or the quality drops by more than 0.001, or the
    /// result is NaN, the original R² is returned and the original constants are kept. Returns 0.0 for trees
    /// that contain no variables.
    /// </returns>
    /// <remarks>
    /// The optimized constants remain in the tree only when <paramref name="updateConstantsInTree"/> is true.
    /// </remarks>
    /// <exception cref="NotSupportedException">The tree contains a symbol that cannot be transformed.</exception>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
      // rows may be a lazily evaluated (e.g. sampled) sequence; enumerate it exactly once so that the input
      // matrix, the target vector and both quality calculations are guaranteed to use the same rows.
      int[] rowIndices = rows.ToArray();

      var variables = new List<AutoDiff.Variable>();
      var parameters = new List<AutoDiff.Variable>();
      var variableNames = new List<string>();
      var categoricalVariableValues = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0

      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());

      double[] c = ExtractInitialConstants(tree, variables.Count, updateVariableWeights);
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      IDataset ds = problemData.Dataset;
      double[,] x = CreateInputMatrix(ds, rowIndices, variableNames, categoricalVariableValues);
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rowIndices).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;
      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        // To debug the AutoDiff gradients, enable: alglib.lsfitsetgradientcheck(state, 0.001);
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      }
      catch (ArithmeticException) {
        // the tree has not been modified yet
        return originalQuality;
      }
      catch (alglib.alglibexception) {
        return originalQuality;
      }

      //info == -7  => constant optimization failed due to wrong gradient; keep the original constants then
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      // c[0] and c[1] are the offset/scale added for the start symbol; they have no counterpart in the tree.
      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
        return originalQuality;
      }
      return quality;
    }

    // Collects the starting values for lsfit in the same order in which TryTransformToAutoDiff creates the
    // AutoDiff variables: [0] = offset (beta) = 0, [1] = scale (alpha) = 1, then the tree's terminals in prefix order.
    private static double[] ExtractInitialConstants(ISymbolicExpressionTree tree, int count, bool updateVariableWeights) {
      double[] c = new double[count];
      c[0] = 0.0;
      c[1] = 1.0;
      int i = 2;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          c[i++] = constantTreeNode.Value;
        else if (updateVariableWeights && variableTreeNode != null)
          c[i++] = variableTreeNode.Weight;
        else if (updateVariableWeights && binFactorVarTreeNode != null)
          c[i++] = binFactorVarTreeNode.Weight;
        else if (factorVarTreeNode != null) {
          // gkronber: a factorVariableTreeNode holds a category-specific constant therefore we can consider factors to be the same as constants
          foreach (var w in factorVarTreeNode.Weights) c[i++] = w;
        }
        // variable weights that are not optimized are baked into the AutoDiff term as constants
      }
      return c;
    }

    // Builds the (rows x inputs) matrix for lsfit. Double-valued variables contribute their value; string
    // (factor) variables contribute a 0/1 indicator for the category recorded for that column.
    private static double[,] CreateInputMatrix(IDataset ds, int[] rows, List<string> variableNames, List<string> categoricalVariableValues) {
      double[,] x = new double[rows.Length, variableNames.Count];
      for (int row = 0; row < rows.Length; row++) {
        int r = rows[row];
        for (int col = 0; col < variableNames.Count; col++) {
          string name = variableNames[col];
          if (ds.VariableHasType<double>(name)) {
            x[row, col] = ds.GetDoubleValue(name, r);
          } else if (ds.VariableHasType<string>(name)) {
            x[row, col] = ds.GetStringValue(name, r) == categoricalVariableValues[col] ? 1 : 0;
          } else throw new InvalidProgramException("found a variable of unknown type");
        }
      }
      return x;
    }

    // Writes constants back into the tree's terminals in prefix order. This is the same order as
    // ExtractInitialConstants, without the two leading linear-scaling entries.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        BinaryFactorVariableTreeNode binFactorVarTreeNode = node as BinaryFactorVariableTreeNode;
        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (updateVariableWeights && variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
        else if (updateVariableWeights && binFactorVarTreeNode != null)
          binFactorVarTreeNode.Weight = constants[i++];
        else if (factorVarTreeNode != null) {
          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
            factorVarTreeNode.Weights[j] = constants[i++];
        }
      }
    }

    // Adapts the compiled AutoDiff term to ALGLIB's function callback (c = coefficients, x = one input row).
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    // Adapts the compiled AutoDiff term to ALGLIB's gradient callback (gradient w.r.t. the coefficients c).
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var result = compiledFunc.Differentiate(c, x);
        func = result.Item2;
        Array.Copy(result.Item1, grad, grad.Length);
      };
    }

    // Recursively converts a tree node into an AutoDiff term.
    // - variables: optimizable coefficients (constants, optionally weights, linear-scaling offset/scale), in creation order.
    // - parameters: one input per variable occurrence; variableNames/categoricalVariableValues describe each input column.
    // Returns false if the subtree contains an unsupported symbol.
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
      List<string> variableNames, List<string> categoricalVariableValues, bool updateVariableWeights, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
        var varNode = node as VariableTreeNodeBase;
        var factorVarNode = node as BinaryFactorVariableTreeNode;
        // factor variable values are only 0 or 1 and set in x accordingly
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);
        categoricalVariableValues.Add(factorVarNode != null ? factorVarNode.VariableValue : string.Empty);

        if (updateVariableWeights) {
          var w = new AutoDiff.Variable();
          variables.Add(w);
          term = AutoDiff.TermBuilder.Product(w, par);
        } else {
          // The weight is not optimized, but it is still part of the model and must be kept as a constant factor.
          term = varNode.Weight * par;
        }
        return true;
      }
      if (node.Symbol is FactorVariable) {
        var factorVarNode = node as FactorVariableTreeNode;
        var products = new List<Term>();
        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
          var par = new AutoDiff.Variable();
          parameters.Add(par);
          variableNames.Add(factorVarNode.VariableName);
          categoricalVariableValues.Add(variableValue);

          var wVar = new AutoDiff.Variable();
          variables.Add(wVar);

          products.Add(AutoDiff.TermBuilder.Product(wVar, par));
        }
        term = products.Count == 1 ? products[0] : AutoDiff.TermBuilder.Sum(products);
        return true;
      }
      if (node.Symbol is Addition) {
        List<AutoDiff.Term> terms;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        term = terms.Count == 1 ? terms[0] : AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        List<AutoDiff.Term> terms;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) {
          // unary minus
          term = -terms[0];
        } else {
          for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
          term = AutoDiff.TermBuilder.Sum(terms);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        List<AutoDiff.Term> terms;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) term = terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
        return true;
      }
      if (node.Symbol is Division) {
        List<AutoDiff.Term> terms;
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        // a single argument means reciprocal
        if (terms.Count == 1) term = 1.0 / terms[0];
        else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
        return true;
      }
      if (node.Symbol is Logarithm)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => AutoDiff.TermBuilder.Log(t), out term);
      if (node.Symbol is Exponential)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => AutoDiff.TermBuilder.Exp(t), out term);
      if (node.Symbol is Square)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => AutoDiff.TermBuilder.Power(t, 2.0), out term);
      if (node.Symbol is SquareRoot)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => AutoDiff.TermBuilder.Power(t, 0.5), out term);
      if (node.Symbol is Sine)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => sin(t), out term);
      if (node.Symbol is Cosine)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => cos(t), out term);
      if (node.Symbol is Tangent)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => tan(t), out term);
      if (node.Symbol is Erf)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => erf(t), out term);
      if (node.Symbol is Norm)
        return TryTransformUnary(node, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, t => norm(t), out term);
      if (node.Symbol is StartSymbol) {
        // linear scaling: alpha * f(x) + beta; beta is registered first so that it becomes c[0] and alpha c[1]
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }

    // Transforms all subtrees of node in order; fails (terms = null) as soon as one subtree cannot be transformed.
    private static bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
      List<string> variableNames, List<string> categoricalVariableValues, bool updateVariableWeights, out List<AutoDiff.Term> terms) {
      terms = new List<AutoDiff.Term>();
      foreach (var subTree in node.Subtrees) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out t)) {
          terms = null;
          return false;
        }
        terms.Add(t);
      }
      return true;
    }

    // Transforms the first subtree of a unary node and applies the given AutoDiff function to it.
    private static bool TryTransformUnary(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters,
      List<string> variableNames, List<string> categoricalVariableValues, bool updateVariableWeights, Func<AutoDiff.Term, AutoDiff.Term> function, out AutoDiff.Term term) {
      AutoDiff.Term argument;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, categoricalVariableValues, updateVariableWeights, out argument)) {
        term = null;
        return false;
      }
      term = function(argument);
      return true;
    }

    /// <summary>
    /// Returns true if every symbol below the root of <paramref name="tree"/> is supported by
    /// <see cref="OptimizeConstants"/>.
    /// </summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      var containsUnknownSymbol = (
        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
        where
         !(n.Symbol is Variable) &&
         !(n.Symbol is BinaryFactorVariable) &&
         !(n.Symbol is FactorVariable) &&
         !(n.Symbol is Constant) &&
         !(n.Symbol is Addition) &&
         !(n.Symbol is Subtraction) &&
         !(n.Symbol is Multiplication) &&
         !(n.Symbol is Division) &&
         !(n.Symbol is Logarithm) &&
         !(n.Symbol is Exponential) &&
         !(n.Symbol is SquareRoot) &&
         !(n.Symbol is Square) &&
         !(n.Symbol is Sine) &&
         !(n.Symbol is Cosine) &&
         !(n.Symbol is Tangent) &&
         !(n.Symbol is Erf) &&
         !(n.Symbol is Norm) &&
         !(n.Symbol is StartSymbol)
        select n).
      Any();
      return !containsUnknownSymbol;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.