Timestamp: 11/15/12 16:47:25 (12 years ago)
Author: mkommend
Message: #1763: merged changes from trunk into the tree simplifier branch.

File: 1 edited

  • branches/HeuristicLab.TreeSimplifier/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

Legend: unchanged lines have no prefix, added lines are prefixed with "+", removed lines with "-".

--- r8053
+++ r8915

#endregion

+using System;
using System.Collections.Generic;
using System.Linq;
+using AutoDiff;
using HeuristicLab.Common;
using HeuristicLab.Core;
…
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
+    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";

    private const string EvaluatedTreesResultName = "EvaluatedTrees";
…
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
+    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
+    }

    public IntValue ConstantOptimizationIterations {
…
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
+    }
+    public bool UpdateConstantsInTree {
+      get { return UpdateConstantsInTreeParameter.Value.Value; }
+      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

…
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
+      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));

      Parameters.Add(new LookupParameter<IntValue>(EvaluatedTreesResultName));
…
    }

+    [StorableHook(HookType.AfterDeserialization)]
+    private void AfterDeserialization() {
+      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
+    }
+
    public override IOperation Apply() {
      AddResults();
-      int seed = RandomParameter.ActualValue.Next();
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
…
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
-          constantOptimizationRows, ConstantOptimizationImprovement.Value, ConstantOptimizationIterations.Value, 0.001,
-          EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower,
+          constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
+          EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree,
          EvaluatedTreesParameter.ActualValue, EvaluatedTreeNodesParameter.ActualValue);
        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          var evaluationRows = GenerateRowsToEvaluate();
-          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows);
+          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
        }
      } else {
        var evaluationRows = GenerateRowsToEvaluate();
-        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows);
+        quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
      }
      QualityParameter.ActualValue = new DoubleValue(quality);
-      EvaluatedTreesParameter.ActualValue.Value += 1;
-      EvaluatedTreeNodesParameter.ActualValue.Value += solution.Length;
+      lock (locker) {
+        EvaluatedTreesParameter.ActualValue.Value += 1;
+        EvaluatedTreeNodesParameter.ActualValue.Value += solution.Length;
+      }

      if (Successor != null)
…
    }

+    private object locker = new object();
    private void AddResults() {
-      if (EvaluatedTreesParameter.ActualValue == null) {
-        var scope = ExecutionContext.Scope;
-        while (scope.Parent != null)
-          scope = scope.Parent;
-        scope.Variables.Add(new Core.Variable(EvaluatedTreesResultName, new IntValue()));
-      }
-      if (EvaluatedTreeNodesParameter.ActualValue == null) {
-        var scope = ExecutionContext.Scope;
-        while (scope.Parent != null)
-          scope = scope.Parent;
-        scope.Variables.Add(new Core.Variable(EvaluatedTreeNodesResultName, new IntValue()));
+      lock (locker) {
+        if (EvaluatedTreesParameter.ActualValue == null) {
+          var scope = ExecutionContext.Scope;
+          while (scope.Parent != null)
+            scope = scope.Parent;
+          scope.Variables.Add(new Core.Variable(EvaluatedTreesResultName, new IntValue()));
+        }
+        if (EvaluatedTreeNodesParameter.ActualValue == null) {
+          var scope = ExecutionContext.Scope;
+          while (scope.Parent != null)
+            scope = scope.Parent;
+          scope.Variables.Add(new Core.Variable(EvaluatedTreeNodesResultName, new IntValue()));
+        }
      }
    }
…
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
-
-      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows);
+      ApplyLinearScalingParameter.ExecutionContext = context;
+
+      double r2 = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value);

      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
+      ApplyLinearScalingParameter.ExecutionContext = context;

      return r2;
    }

+    #region derivations of functions
+    // create function factory for arctangent
+    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
+      eval: Math.Atan,
+      diff: x => 1 / (1 + x * x));
+    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
+      eval: Math.Sin,
+      diff: Math.Cos);
+    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
+       eval: Math.Cos,
+       diff: x => -Math.Sin(x));
+    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
+      eval: Math.Tan,
+      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
+    private static readonly Func<Term, UnaryFunc> square = UnaryFunc.Factory(
+       eval: x => x * x,
+       diff: x => 2 * x);
+    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
+      eval: alglib.errorfunction,
+      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
+    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
+      eval: alglib.normaldistribution,
+      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
+    #endregion
+
+
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
-      IEnumerable<int> rows, double improvement, int iterations, double differentialStep, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, IntValue evaluatedTrees = null, IntValue evaluatedTreeNodes = null) {
+      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true, IntValue evaluatedTrees = null, IntValue evaluatedTreeNodes = null) {
+
+      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
+      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
+      List<string> variableNames = new List<string>();
+
+      AutoDiff.Term func;
+      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
+        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
+      if (variableNames.Count == 0) return 0.0;
+
+      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
+
      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
-      double[] c = new double[terminalNodes.Count];
-      int treeLength = tree.Length;
-
-      //extract inital constants
-      for (int i = 0; i < terminalNodes.Count; i++) {
-        ConstantTreeNode constantTreeNode = terminalNodes[i] as ConstantTreeNode;
-        if (constantTreeNode != null) c[i] = constantTreeNode.Value;
-        VariableTreeNode variableTreeNode = terminalNodes[i] as VariableTreeNode;
-        if (variableTreeNode != null) c[i] = variableTreeNode.Weight;
-      }
-
-      double epsg = 0;
-      double epsf = improvement;
-      double epsx = 0;
-      int maxits = iterations;
-      double diffstep = differentialStep;
-
-      alglib.minlmstate state;
-      alglib.minlmreport report;
-
-      alglib.minlmcreatev(1, c, diffstep, out state);
-      alglib.minlmsetcond(state, epsg, epsf, epsx, maxits);
-      alglib.minlmoptimize(state, CreateCallBack(interpreter, tree, problemData, rows, upperEstimationLimit, lowerEstimationLimit, treeLength, evaluatedTrees, evaluatedTreeNodes), null, terminalNodes);
-      alglib.minlmresults(state, out c, out report);
-
-      for (int i = 0; i < c.Length; i++) {
-        ConstantTreeNode constantTreeNode = terminalNodes[i] as ConstantTreeNode;
-        if (constantTreeNode != null) constantTreeNode.Value = c[i];
-        VariableTreeNode variableTreeNode = terminalNodes[i] as VariableTreeNode;
-        if (variableTreeNode != null) variableTreeNode.Weight = c[i];
-      }
-
-      return (state.fi[0] - 1) * -1;
-    }
-
-    private static alglib.ndimensional_fvec CreateCallBack(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, double upperEstimationLimit, double lowerEstimationLimit, int treeLength, IntValue evaluatedTrees = null, IntValue evaluatedTreeNodes = null) {
-      return (double[] arg, double[] fi, object obj) => {
-        // update constants of tree
-        List<SymbolicExpressionTreeTerminalNode> terminalNodes = (List<SymbolicExpressionTreeTerminalNode>)obj;
-        for (int i = 0; i < terminalNodes.Count; i++) {
-          ConstantTreeNode constantTreeNode = terminalNodes[i] as ConstantTreeNode;
-          if (constantTreeNode != null) constantTreeNode.Value = arg[i];
-          VariableTreeNode variableTreeNode = terminalNodes[i] as VariableTreeNode;
-          if (variableTreeNode != null) variableTreeNode.Weight = arg[i];
-        }
-
-        double quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows);
-
-        fi[0] = 1 - quality;
-        if (evaluatedTrees != null) evaluatedTrees.Value++;
-        if (evaluatedTreeNodes != null) evaluatedTreeNodes.Value += treeLength;
+      double[] c = new double[variables.Count];
+
+      {
+        c[0] = 0.0;
+        c[1] = 1.0;
+        //extract inital constants
+        int i = 2;
+        foreach (var node in terminalNodes) {
+          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
+          VariableTreeNode variableTreeNode = node as VariableTreeNode;
+          if (constantTreeNode != null)
+            c[i++] = constantTreeNode.Value;
+          else if (variableTreeNode != null)
+            c[i++] = variableTreeNode.Weight;
+        }
+      }
+
+      alglib.lsfitstate state;
+      alglib.lsfitreport rep;
+      int info;
+
+      Dataset ds = problemData.Dataset;
+      double[,] x = new double[rows.Count(), variableNames.Count];
+      int row = 0;
+      foreach (var r in rows) {
+        for (int col = 0; col < variableNames.Count; col++) {
+          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
+        }
+        row++;
+      }
+      double[] y = ds.GetDoubleValues(problemData.TargetVariable, rows).ToArray();
+      int n = x.GetLength(0);
+      int m = x.GetLength(1);
+      int k = c.Length;
+
+      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
+      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
+
+      try {
+        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
+        alglib.lsfitsetcond(state, 0, 0, maxIterations);
+        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
+        alglib.lsfitresults(state, out info, out c, out rep);
+
+      }
+      catch (ArithmeticException) {
+        return 0.0;
+      }
+      catch (alglib.alglibexception) {
+        return 0.0;
+      }
+      var newTree = tree;
+      if (!updateConstantsInTree) newTree = (ISymbolicExpressionTree)tree.Clone();
+      {
+        // only when no error occurred
+        // set constants in tree
+        int i = 2;
+        foreach (var node in newTree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
+          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
+          VariableTreeNode variableTreeNode = node as VariableTreeNode;
+          if (constantTreeNode != null)
+            constantTreeNode.Value = c[i++];
+          else if (variableTreeNode != null)
+            variableTreeNode.Weight = c[i++];
+        }
+
+      }
+      return SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, newTree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
+    }
+
+    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
+      return (double[] c, double[] x, ref double func, object o) => {
+        func = compiledFunc.Evaluate(c, x);
      };
    }

+    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
+      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
+        var tupel = compiledFunc.Differentiate(c, x);
+        func = tupel.Item2;
+        Array.Copy(tupel.Item1, grad, grad.Length);
+      };
+    }
+
+    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
+      if (node.Symbol is Constant) {
+        var var = new AutoDiff.Variable();
+        variables.Add(var);
+        term = var;
+        return true;
+      }
+      if (node.Symbol is Variable) {
+        var varNode = node as VariableTreeNode;
+        var par = new AutoDiff.Variable();
+        parameters.Add(par);
+        variableNames.Add(varNode.VariableName);
+        var w = new AutoDiff.Variable();
+        variables.Add(w);
+        term = AutoDiff.TermBuilder.Product(w, par);
+        return true;
+      }
+      if (node.Symbol is Addition) {
+        List<AutoDiff.Term> terms = new List<Term>();
+        foreach (var subTree in node.Subtrees) {
+          AutoDiff.Term t;
+          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
+            term = null;
+            return false;
+          }
+          terms.Add(t);
+        }
+        term = AutoDiff.TermBuilder.Sum(terms);
+        return true;
+      }
+      if (node.Symbol is Subtraction) {
+        List<AutoDiff.Term> terms = new List<Term>();
+        for (int i = 0; i < node.SubtreeCount; i++) {
+          AutoDiff.Term t;
+          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
+            term = null;
+            return false;
+          }
+          if (i > 0) t = -t;
+          terms.Add(t);
+        }
+        term = AutoDiff.TermBuilder.Sum(terms);
+        return true;
+      }
+      if (node.Symbol is Multiplication) {
+        AutoDiff.Term a, b;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
+          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
+          term = null;
+          return false;
+        } else {
+          List<AutoDiff.Term> factors = new List<Term>();
+          foreach (var subTree in node.Subtrees.Skip(2)) {
+            AutoDiff.Term f;
+            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
+              term = null;
+              return false;
+            }
+            factors.Add(f);
+          }
+          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
+          return true;
+        }
+      }
+      if (node.Symbol is Division) {
+        // only works for at least two subtrees
+        AutoDiff.Term a, b;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
+          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
+          term = null;
+          return false;
+        } else {
+          List<AutoDiff.Term> factors = new List<Term>();
+          foreach (var subTree in node.Subtrees.Skip(2)) {
+            AutoDiff.Term f;
+            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
+              term = null;
+              return false;
+            }
+            factors.Add(1.0 / f);
+          }
+          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
+          return true;
+        }
+      }
+      if (node.Symbol is Logarithm) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = AutoDiff.TermBuilder.Log(t);
+          return true;
+        }
+      }
+      if (node.Symbol is Exponential) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = AutoDiff.TermBuilder.Exp(t);
+          return true;
+        }
+      } if (node.Symbol is Sine) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = sin(t);
+          return true;
+        }
+      } if (node.Symbol is Cosine) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = cos(t);
+          return true;
+        }
+      } if (node.Symbol is Tangent) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = tan(t);
+          return true;
+        }
+      }
+      if (node.Symbol is Square) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = square(t);
+          return true;
+        }
+      } if (node.Symbol is Erf) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = erf(t);
+          return true;
+        }
+      } if (node.Symbol is Norm) {
+        AutoDiff.Term t;
+        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
+          term = null;
+          return false;
+        } else {
+          term = norm(t);
+          return true;
+        }
+      }
+      if (node.Symbol is StartSymbol) {
+        var alpha = new AutoDiff.Variable();
+        var beta = new AutoDiff.Variable();
+        variables.Add(beta);
+        variables.Add(alpha);
+        AutoDiff.Term branchTerm;
+        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
+          term = branchTerm * alpha + beta;
+          return true;
+        } else {
+          term = null;
+          return false;
+        }
+      }
+      term = null;
+      return false;
+    }
+
+    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
+      var containsUnknownSymbol = (
+        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
+        where
+         !(n.Symbol is Variable) &&
+         !(n.Symbol is Constant) &&
+         !(n.Symbol is Addition) &&
+         !(n.Symbol is Subtraction) &&
+         !(n.Symbol is Multiplication) &&
+         !(n.Symbol is Division) &&
+         !(n.Symbol is Logarithm) &&
+         !(n.Symbol is Exponential) &&
+         !(n.Symbol is Sine) &&
+         !(n.Symbol is Cosine) &&
+         !(n.Symbol is Tangent) &&
+         !(n.Symbol is Square) &&
+         !(n.Symbol is Erf) &&
+         !(n.Symbol is Norm) &&
+         !(n.Symbol is StartSymbol)
+        select n).
+      Any();
+      return !containsUnknownSymbol;
+    }
  }
}
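The changeset introduces an UpdateConstantsInTree option so a tree can be scored with optimized constants without overwriting the original individual: when the option is false, OptimizeConstants writes the fitted values into a clone and only the returned quality reflects them. A minimal usage sketch, assuming the evaluator exposes a parameterless constructor as is usual for HeuristicLab operators; the variable name is illustrative:

// Hedged sketch: when UpdateConstantsInTree is false, OptimizeConstants clones the
// tree before writing the fitted constants back, so the evaluated individual keeps
// its original constant values while the quality is computed on the optimized clone.
var evaluator = new SymbolicRegressionConstantOptimizationEvaluator();
evaluator.UpdateConstantsInTree = false;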
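The new "#region derivations of functions" pairs each unary function with its analytic derivative through AutoDiff's UnaryFunc.Factory, which is what allows AutoDiff to differentiate terms that contain these functions. A hedged sketch of how a further function could be registered the same way (a square root, which this changeset does not add; SqrtSupportSketch is an illustrative name):

using System;
using AutoDiff;

public static class SqrtSupportSketch {
  // eval is the function itself, diff its derivative: d/dx sqrt(x) = 1 / (2 * sqrt(x))
  public static readonly Func<Term, UnaryFunc> sqrt = UnaryFunc.Factory(
    eval: Math.Sqrt,
    diff: x => 1.0 / (2.0 * Math.Sqrt(x)));
}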
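The rewritten OptimizeConstants replaces the earlier minlm-based search with numerical difference steps by an analytic-gradient least-squares fit: TryTransformToAutoDiff turns the tree into an AutoDiff term over the constants (adding the linear-scaling offset beta and factor alpha as the first two variables, which is why c[0] is seeded with 0.0 and c[1] with 1.0 and the tree's own constants start at index 2), the term is compiled, and its value and gradient are handed to ALGLIB's lsfit routines. A minimal self-contained sketch of that pattern, assuming the AutoDiff and ALGLIB builds bundled with HeuristicLab and using only the calls that appear above; LsFitSketch and Fit are illustrative names, and the model here is simply y ≈ c0 + c1·x:

using System;
using AutoDiff;

public static class LsFitSketch {
  // Fit y ≈ c0 + c1 * x for a single input column: build the model as an AutoDiff term
  // over the coefficient variables, compile it, and pass value/gradient callbacks to
  // ALGLIB's nonlinear least-squares fitting, mirroring OptimizeConstants/CreatePFunc/CreatePGrad.
  public static double[] Fit(double[,] x, double[] y, int maxIterations) {
    var c0 = new Variable();  // offset coefficient (plays the role of beta)
    var c1 = new Variable();  // scale coefficient (plays the role of alpha)
    var x0 = new Variable();  // the data column, treated as an AutoDiff "parameter"
    Term model = c1 * x0 + c0;
    IParametricCompiledTerm compiled = TermUtils.Compile(model, new[] { c0, c1 }, new[] { x0 });

    double[] c = new double[] { 0.0, 1.0 };  // start values, as in OptimizeConstants
    int n = x.GetLength(0), m = x.GetLength(1), k = c.Length;

    // value callback: evaluate the compiled term for one data row
    alglib.ndimensional_pfunc pfunc = (double[] coeff, double[] row, ref double f, object o) => {
      f = compiled.Evaluate(coeff, row);
    };
    // gradient callback: Differentiate returns the gradient and the value for one row
    alglib.ndimensional_pgrad pgrad = (double[] coeff, double[] row, ref double f, double[] grad, object o) => {
      var valueAndGradient = compiled.Differentiate(coeff, row);
      f = valueAndGradient.Item2;
      Array.Copy(valueAndGradient.Item1, grad, grad.Length);
    };

    alglib.lsfitstate state;
    alglib.lsfitreport rep;
    int info;
    alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
    alglib.lsfitsetcond(state, 0, 0, maxIterations);
    alglib.lsfitfit(state, pfunc, pgrad, null, null);
    alglib.lsfitresults(state, out info, out c, out rep);
    return c;  // fitted [offset, scale]
  }
}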