Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs @ 13670

Last change on this file since 13670 was 13670, checked in by mkommend, 8 years ago

#2584: Added parameter in constant optimization that determines whether variable weights should be modified.

File size: 24.1 KB
RevLine 
[6256]1#region License Information
2/* HeuristicLab
[12012]3 * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[6256]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
[8704]22using System;
[6256]23using System.Collections.Generic;
24using System.Linq;
[8704]25using AutoDiff;
[6256]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Single-objective symbolic regression evaluator that, with probability <c>ConstantOptimizationProbability</c>,
  /// fits the constants (and optionally the variable weights) of the evaluated tree by nonlinear least squares
  /// (alglib lsfit, gradients computed with AutoDiff) and reports the Pearson R² of the tree as quality.
  /// </summary>
  [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
  [StorableClass]
  public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";

    // Shared by the constructor and AfterDeserialization so both code paths register identical parameters.
    private const string UpdateConstantsInTreeParameterDescription = "Determines if the constants in the tree should be overwritten by the optimized constants.";
    private const string UpdateVariableWeightsParameterDescription = "Determines if the variable weights in the tree should be  optimized.";

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
    }

    /// <summary>Maximum number of lsfit iterations passed to <see cref="OptimizeConstants"/>.</summary>
    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    /// <summary>Relative improvement threshold for the constant optimization.</summary>
    /// <remarks>NOTE(review): this value is not read anywhere in this class; lsfit is always configured with epsf = epsx = 0.</remarks>
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    /// <summary>Probability that the constants of an evaluated tree are optimized at all.</summary>
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    /// <summary>Share of the evaluation rows used for fitting the constants.</summary>
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    /// <summary>When false, the tree keeps its original constants after the quality has been computed.</summary>
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    /// <summary>When true, variable weights are fitted together with the constants.</summary>
    public bool UpdateVariableWeights {
      get { return UpdateVariableWeightsParameter.Value.Value; }
      set { UpdateVariableWeightsParameter.Value.Value = value; }
    }

    public override bool Maximization {
      get { return true; }
    }

    [StorableConstructor]
    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
    protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public SymbolicRegressionConstantOptimizationEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, UpdateConstantsInTreeParameterDescription, new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, UpdateVariableWeightsParameterDescription, new BoolValue(true)));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionConstantOptimizationEvaluator(this, cloner);
    }

    // Adds parameters that were introduced after older instances of this evaluator were persisted.
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, UpdateConstantsInTreeParameterDescription, new BoolValue(true)));
      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, UpdateVariableWeightsParameterDescription, new BoolValue(true)));
    }

    /// <summary>
    /// Evaluates the current tree. With probability <see cref="ConstantOptimizationProbability"/> the constants are
    /// optimized first. The quality of the fit is reused only when the fit used the same share of rows as the
    /// regular evaluation; otherwise the Pearson R² is recomputed on the regular evaluation rows.
    /// </summary>
    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      var interpreter = SymbolicDataAnalysisTreeInterpreterParameter.ActualValue;
      var problemData = ProblemDataParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      bool applyLinearScaling = ApplyLinearScalingParameter.ActualValue.Value;

      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(interpreter, solution, problemData, constantOptimizationRows, applyLinearScaling, ConstantOptimizationIterations.Value,
          updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: estimationLimits.Lower, upperEstimationLimit: estimationLimits.Upper, updateConstantsInTree: UpdateConstantsInTree);

        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, solution, estimationLimits.Lower, estimationLimits.Upper, problemData, evaluationRows, applyLinearScaling);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    /// <summary>
    /// Computes the Pearson R² of <paramref name="tree"/> on <paramref name="rows"/> without optimizing constants.
    /// </summary>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
        // because Evaluate() is used to get the quality of evolved models on
        // different partitions of the dataset (e.g., best validation model)
        return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);
      } finally {
        // Always detach the borrowed context, also when the evaluation throws.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }

    #region derivations of functions
    // Each factory pairs a function with its analytic first derivative for AutoDiff.
    // arctan is currently not used by TryTransformToAutoDiff.
    private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
      eval: Math.Atan,
      diff: x => 1 / (1 + x * x));
    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
      eval: Math.Sin,
      diff: Math.Cos);
    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
      eval: Math.Cos,
      diff: x => -Math.Sin(x));
    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
      eval: Math.Tan,
      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
      eval: alglib.errorfunction,
      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    // alglib.normaldistribution is the standard normal CDF Phi(x); its derivative is the standard normal
    // density phi(x) = exp(-x^2 / 2) / sqrt(2 * pi). This form also stays finite for large |x|.
    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
      eval: alglib.normaldistribution,
      diff: x => Math.Exp(-0.5 * x * x) / Math.Sqrt(2 * Math.PI));
    #endregion


    /// <summary>
    /// Fits the constants (and, if <paramref name="updateVariableWeights"/> is set, the variable weights) of
    /// <paramref name="tree"/> to the target variable of <paramref name="problemData"/> on <paramref name="rows"/>.
    /// </summary>
    /// <param name="interpreter">Interpreter used to compute the Pearson R² before and after the fit.</param>
    /// <param name="tree">Tree whose constants are optimized; its nodes are modified in place.</param>
    /// <param name="problemData">Provides the dataset and the target variable.</param>
    /// <param name="rows">Row indices used for the fit and the quality computation.</param>
    /// <param name="applyLinearScaling">Passed through to the Pearson R² computation.</param>
    /// <param name="maxIterations">Maximum number of lsfit iterations.</param>
    /// <param name="updateVariableWeights">When true, variable weights are fitted as well.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit for the R² computation.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit for the R² computation.</param>
    /// <param name="updateConstantsInTree">When false, the original constants are restored after evaluation.</param>
    /// <returns>
    /// The Pearson R² after the fit; the original R² if the fit threw an arithmetic/alglib exception, produced NaN,
    /// or worsened the quality by more than 0.001 (in which case the original constants are restored);
    /// 0.0 for trees that contain no variables.
    /// </returns>
    /// <exception cref="NotSupportedException">The tree contains a symbol that cannot be converted to an AutoDiff term.</exception>
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, bool updateVariableWeights = true, double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue, bool updateConstantsInTree = true) {
      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
      List<string> variableNames = new List<string>();

      AutoDiff.Term func;
      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out func))
        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
      if (variableNames.Count == 0) return 0.0;

      AutoDiff.IParametricCompiledTerm compiledFunc = func.Compile(variables.ToArray(), parameters.ToArray());

      // Materialize the rows once: they are used for x, y and several quality computations, and x/y must
      // be built from exactly the same row sequence even if the caller passes a lazily evaluated enumerable.
      int[] rowIndices = rows.ToArray();

      double[] c = ExtractInitialConstants(tree, variables.Count, updateVariableWeights);
      double[] originalConstants = (double[])c.Clone();
      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      // One column per variable occurrence in the tree (same order as the AutoDiff parameters).
      IDataset ds = problemData.Dataset;
      double[,] x = new double[rowIndices.Length, variableNames.Count];
      for (int row = 0; row < rowIndices.Length; row++) {
        for (int col = 0; col < variableNames.Count; col++) {
          x[row, col] = ds.GetDoubleValue(variableNames[col], rowIndices[row]);
        }
      }
      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rowIndices).ToArray();
      int n = x.GetLength(0);
      int m = x.GetLength(1);
      int k = c.Length;

      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);

      alglib.lsfitstate state;
      alglib.lsfitreport rep;
      int info;
      try {
        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
        // epsf = epsx = 0: only maxIterations bounds the fit (per alglib, all-zero selects an automatic criterion).
        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
        // alglib.lsfitsetgradientcheck(state, 0.001); // enable to verify the analytic gradient (reports info == -7 on mismatch)
        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
        alglib.lsfitresults(state, out info, out c, out rep);
      } catch (ArithmeticException) {
        return originalQuality;
      } catch (alglib.alglibexception) {
        return originalQuality;
      }

      // info == -7 => constant optimization failed due to wrong gradient
      // c[0] and c[1] are the offset/scale of the StartSymbol and have no counterpart in the tree.
      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray(), updateVariableWeights);
      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rowIndices, applyLinearScaling);

      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
        UpdateConstants(tree, originalConstants.Skip(2).ToArray(), updateVariableWeights);
        return originalQuality;
      }
      return quality;
    }

    // Builds the lsfit starting point: offset 0 and scale 1 for the terms added around the tree by the
    // StartSymbol, followed by the constants (and variable weights if requested) in prefix order -- the
    // same order in which TryTransformToAutoDiff registers the corresponding AutoDiff variables.
    private static double[] ExtractInitialConstants(ISymbolicExpressionTree tree, int count, bool updateVariableWeights) {
      double[] c = new double[count];
      c[0] = 0.0;
      c[1] = 1.0;
      int i = 2;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          c[i++] = constantTreeNode.Value;
        else if (updateVariableWeights && variableTreeNode != null)
          c[i++] = variableTreeNode.Weight;
      }
      return c;
    }

    // Writes constants (and variable weights if requested) back into the tree in prefix order.
    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
      int i = 0;
      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
        VariableTreeNode variableTreeNode = node as VariableTreeNode;
        if (constantTreeNode != null)
          constantTreeNode.Value = constants[i++];
        else if (updateVariableWeights && variableTreeNode != null)
          variableTreeNode.Weight = constants[i++];
      }
    }

    // Adapts the compiled term to alglib's function callback: c = fitted coefficients, x = one data point.
    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, object o) => {
        func = compiledFunc.Evaluate(c, x);
      };
    }

    // Adapts the compiled term to alglib's gradient callback (gradient w.r.t. the coefficients c).
    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
        var tupel = compiledFunc.Differentiate(c, x);
        func = tupel.Item2;
        Array.Copy(tupel.Item1, grad, grad.Length);
      };
    }

    /// <summary>
    /// Recursively converts <paramref name="node"/> into an AutoDiff term. Constants (and variable weights if
    /// <paramref name="updateVariableWeights"/>) become entries of <paramref name="variables"/>; every variable
    /// occurrence becomes an entry of <paramref name="parameters"/> with its name in <paramref name="variableNames"/>.
    /// Registration happens in prefix order, which ExtractInitialConstants/UpdateConstants rely on.
    /// </summary>
    /// <returns>False if the subtree contains an unsupported symbol; <paramref name="term"/> is then null.</returns>
    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, out AutoDiff.Term term) {
      if (node.Symbol is Constant) {
        var var = new AutoDiff.Variable();
        variables.Add(var);
        term = var;
        return true;
      }
      if (node.Symbol is Variable) {
        var varNode = node as VariableTreeNode;
        var par = new AutoDiff.Variable();
        parameters.Add(par);
        variableNames.Add(varNode.VariableName);

        if (updateVariableWeights) {
          var w = new AutoDiff.Variable();
          variables.Add(w);
          term = AutoDiff.TermBuilder.Product(w, par);
        } else {
          term = par;
        }
        return true;
      }

      List<AutoDiff.Term> terms;
      if (node.Symbol is Addition) {
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        term = terms.Count == 1 ? terms[0] : AutoDiff.TermBuilder.Sum(terms);
        return true;
      }
      if (node.Symbol is Subtraction) {
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        if (terms.Count == 1) {
          // A single-argument subtraction is a negation.
          term = -terms[0];
        } else {
          for (int i = 1; i < terms.Count; i++) terms[i] = -terms[i];
          term = AutoDiff.TermBuilder.Sum(terms);
        }
        return true;
      }
      if (node.Symbol is Multiplication) {
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        term = terms.Count == 1
          ? terms[0]
          : AutoDiff.TermBuilder.Product(terms[0], terms[1], terms.Skip(2).ToArray());
        return true;
      }
      if (node.Symbol is Division) {
        if (!TryTransformSubtrees(node, variables, parameters, variableNames, updateVariableWeights, out terms)) {
          term = null;
          return false;
        }
        // x1 / x2 / ... / xn; a single-argument division is the reciprocal 1 / x1.
        term = terms.Count == 1
          ? 1.0 / terms[0]
          : AutoDiff.TermBuilder.Product(terms[0], 1.0 / terms[1], terms.Skip(2).Select(f => 1.0 / f).ToArray());
        return true;
      }
      if (node.Symbol is Logarithm)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => AutoDiff.TermBuilder.Log(t), out term);
      if (node.Symbol is Exponential)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => AutoDiff.TermBuilder.Exp(t), out term);
      if (node.Symbol is Square)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => AutoDiff.TermBuilder.Power(t, 2.0), out term);
      if (node.Symbol is SquareRoot)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => AutoDiff.TermBuilder.Power(t, 0.5), out term);
      if (node.Symbol is Sine)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => sin(t), out term);
      if (node.Symbol is Cosine)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => cos(t), out term);
      if (node.Symbol is Tangent)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => tan(t), out term);
      if (node.Symbol is Erf)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => erf(t), out term);
      if (node.Symbol is Norm)
        return TryTransformUnary(node, variables, parameters, variableNames, updateVariableWeights, t => norm(t), out term);
      if (node.Symbol is StartSymbol) {
        var alpha = new AutoDiff.Variable();
        var beta = new AutoDiff.Variable();
        // beta (offset) must be registered before alpha (scale): the initial constants are c[0] = 0, c[1] = 1.
        variables.Add(beta);
        variables.Add(alpha);
        AutoDiff.Term branchTerm;
        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out branchTerm)) {
          term = branchTerm * alpha + beta;
          return true;
        } else {
          term = null;
          return false;
        }
      }
      term = null;
      return false;
    }

    // Transforms all subtrees of node in order; fails as soon as one subtree cannot be transformed.
    private static bool TryTransformSubtrees(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, out List<AutoDiff.Term> terms) {
      terms = new List<AutoDiff.Term>();
      foreach (var subTree in node.Subtrees) {
        AutoDiff.Term t;
        if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, updateVariableWeights, out t)) {
          terms = null;
          return false;
        }
        terms.Add(t);
      }
      return true;
    }

    // Transforms the first subtree of node and applies the unary function to the result.
    private static bool TryTransformUnary(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, bool updateVariableWeights, Func<AutoDiff.Term, AutoDiff.Term> function, out AutoDiff.Term term) {
      AutoDiff.Term argument;
      if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, updateVariableWeights, out argument)) {
        term = null;
        return false;
      }
      term = function(argument);
      return true;
    }

    /// <summary>
    /// Returns true if every symbol below the root of <paramref name="tree"/> can be handled by the constant
    /// optimization. Must stay in sync with the symbols supported by TryTransformToAutoDiff.
    /// </summary>
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return tree.Root.GetSubtree(0).IterateNodesPrefix().All(n =>
        n.Symbol is Variable ||
        n.Symbol is Constant ||
        n.Symbol is Addition ||
        n.Symbol is Subtraction ||
        n.Symbol is Multiplication ||
        n.Symbol is Division ||
        n.Symbol is Logarithm ||
        n.Symbol is Exponential ||
        n.Symbol is SquareRoot ||
        n.Symbol is Square ||
        n.Symbol is Sine ||
        n.Symbol is Cosine ||
        n.Symbol is Tangent ||
        n.Symbol is Erf ||
        n.Symbol is Norm ||
        n.Symbol is StartSymbol);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.