Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
06/28/18 11:13:37 (7 years ago)
Author:
gkronber
Message:

#2522: merged trunk changes from r13402:15972 to branch resolving conflicts where necessary

Location:
branches/2522_RefactorPluginInfrastructure
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2522_RefactorPluginInfrastructure

  • branches/2522_RefactorPluginInfrastructure/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression

  • branches/2522_RefactorPluginInfrastructure/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4

  • branches/2522_RefactorPluginInfrastructure/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

    r13300 r15973  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2323using System.Collections.Generic;
    2424using System.Linq;
    25 using AutoDiff;
    2625using HeuristicLab.Common;
    2726using HeuristicLab.Core;
    2827using HeuristicLab.Data;
    2928using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     29using HeuristicLab.Optimization;
    3030using HeuristicLab.Parameters;
    3131using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    4040    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    4141    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
     42    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
     43
     44    private const string FunctionEvaluationsResultParameterName = "Constants Optimization Function Evaluations";
     45    private const string GradientEvaluationsResultParameterName = "Constants Optimization Gradient Evaluations";
     46    private const string CountEvaluationsParameterName = "Count Function and Gradient Evaluations";
    4247
    4348    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
     
    5661      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    5762    }
     63    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
     64      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
     65    }
     66
     67    public IResultParameter<IntValue> FunctionEvaluationsResultParameter {
     68      get { return (IResultParameter<IntValue>)Parameters[FunctionEvaluationsResultParameterName]; }
     69    }
     70    public IResultParameter<IntValue> GradientEvaluationsResultParameter {
     71      get { return (IResultParameter<IntValue>)Parameters[GradientEvaluationsResultParameterName]; }
     72    }
     73    public IFixedValueParameter<BoolValue> CountEvaluationsParameter {
     74      get { return (IFixedValueParameter<BoolValue>)Parameters[CountEvaluationsParameterName]; }
     75    }
     76
    5877
    5978    public IntValue ConstantOptimizationIterations {
     
    7291      get { return UpdateConstantsInTreeParameter.Value.Value; }
    7392      set { UpdateConstantsInTreeParameter.Value.Value = value; }
     93    }
     94
     95    public bool UpdateVariableWeights {
     96      get { return UpdateVariableWeightsParameter.Value.Value; }
     97      set { UpdateVariableWeightsParameter.Value.Value = value; }
     98    }
     99
     100    public bool CountEvaluations {
     101      get { return CountEvaluationsParameter.Value.Value; }
     102      set { CountEvaluationsParameter.Value.Value = value; }
    74103    }
    75104
     
    86115      : base() {
    87116      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
    88       Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
     117      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
    89118      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
    90119      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
    91       Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
     120      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
     121      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
     122
     123      Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
     124      Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
     125      Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
    92126    }
    93127
     
    100134      if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
    101135        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
    102     }
    103 
     136      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
     137        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
     138
     139      if (!Parameters.ContainsKey(CountEvaluationsParameterName))
     140        Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
     141
     142      if (!Parameters.ContainsKey(FunctionEvaluationsResultParameterName))
     143        Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
     144      if (!Parameters.ContainsKey(GradientEvaluationsResultParameterName))
     145        Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
     146    }
     147
     148    private static readonly object locker = new object();
    104149    public override IOperation InstrumentedApply() {
    105150      var solution = SymbolicExpressionTreeParameter.ActualValue;
     
    107152      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
    108153        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
     154        var counter = new EvaluationsCounter();
    109155        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
    110            constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
    111            EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);
     156           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree, counter: counter);
    112157
    113158        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
     
    115160          quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
    116161        }
     162
     163        if (CountEvaluations) {
     164          lock (locker) {
     165            FunctionEvaluationsResultParameter.ActualValue.Value += counter.FunctionEvaluations;
     166            GradientEvaluationsResultParameter.ActualValue.Value += counter.GradientEvaluations;
     167          }
     168        }
     169
    117170      } else {
    118171        var evaluationRows = GenerateRowsToEvaluate();
     
    128181      EstimationLimitsParameter.ExecutionContext = context;
    129182      ApplyLinearScalingParameter.ExecutionContext = context;
     183      FunctionEvaluationsResultParameter.ExecutionContext = context;
     184      GradientEvaluationsResultParameter.ExecutionContext = context;
    130185
    131186      // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
     
    137192      EstimationLimitsParameter.ExecutionContext = null;
    138193      ApplyLinearScalingParameter.ExecutionContext = null;
     194      FunctionEvaluationsResultParameter.ExecutionContext = null;
     195      GradientEvaluationsResultParameter.ExecutionContext = null;
    139196
    140197      return r2;
    141198    }
    142199
    143     #region derivations of functions
    144     // create function factory for arctangent
    145     private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    146       eval: Math.Atan,
    147       diff: x => 1 / (1 + x * x));
    148     private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    149       eval: Math.Sin,
    150       diff: Math.Cos);
    151     private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    152        eval: Math.Cos,
    153        diff: x => -Math.Sin(x));
    154     private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    155       eval: Math.Tan,
    156       diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    157     private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    158       eval: alglib.errorfunction,
    159       diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    160     private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    161       eval: alglib.normaldistribution,
    162       diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    163     #endregion
    164 
    165 
    166     // TODO: swap positions of lowerEstimationLimit and upperEstimationLimit parameters
    167     public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
    168       IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {
    169 
    170       List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
    171       List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
    172       List<string> variableNames = new List<string>();
    173 
    174       AutoDiff.Term func;
    175       if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
     200    public class EvaluationsCounter {
     201      public int FunctionEvaluations = 0;
     202      public int GradientEvaluations = 0;
     203    }
     204
     205    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     206      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
     207      int maxIterations, bool updateVariableWeights = true,
     208      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
     209      bool updateConstantsInTree = true, Action<double[], double, object> iterationCallback = null, EvaluationsCounter counter = null) {
     210
     211      // numeric constants in the tree become variables for constant opt
     212      // variables in the tree become parameters (fixed values) for constant opt
     213      // for each parameter (variable in the original tree) we store the
     214      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
     215      // A dictionary is used to find parameters
     216      double[] initialConstants;
     217      var parameters = new List<TreeToAutoDiffTermConverter.DataForVariable>();
     218
     219      TreeToAutoDiffTermConverter.ParametricFunction func;
     220      TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad;
     221      if (!TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, updateVariableWeights, applyLinearScaling, out parameters, out initialConstants, out func, out func_grad))
    176222        throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
    177       if (variableNames.Count == 0) return 0.0;
    178 
    179       AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
    180 
    181       List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
    182       double[] c = new double[variables.Count];
    183 
    184       {
     223      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
     224      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
     225
     226      //extract inital constants
     227      double[] c;
     228      if (applyLinearScaling) {
     229        c = new double[initialConstants.Length + 2];
    185230        c[0] = 0.0;
    186231        c[1] = 1.0;
    187         //extract inital constants
    188         int i = 2;
    189         foreach (var node in terminalNodes) {
    190           ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    191           VariableTreeNode variableTreeNode = node as VariableTreeNode;
    192           if (constantTreeNode != null)
    193             c[i++] = constantTreeNode.Value;
    194           else if (variableTreeNode != null)
    195             c[i++] = variableTreeNode.Weight;
    196         }
    197       }
    198       double[] originalConstants = (double[])c.Clone();
     232        Array.Copy(initialConstants, 0, c, 2, initialConstants.Length);
     233      } else {
     234        c = (double[])initialConstants.Clone();
     235      }
     236
    199237      double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
     238
     239      if (counter == null) counter = new EvaluationsCounter();
     240      var rowEvaluationsCounter = new EvaluationsCounter();
    200241
    201242      alglib.lsfitstate state;
    202243      alglib.lsfitreport rep;
    203       int info;
     244      int retVal;
    204245
    205246      IDataset ds = problemData.Dataset;
    206       double[,] x = new double[rows.Count(), variableNames.Count];
     247      double[,] x = new double[rows.Count(), parameters.Count];
    207248      int row = 0;
    208249      foreach (var r in rows) {
    209         for (int col = 0; col < variableNames.Count; col++) {
    210           x[row, col] = ds.GetDoubleValue(variableNames[col], r);
     250        int col = 0;
     251        foreach (var info in parameterEntries) {
     252          if (ds.VariableHasType<double>(info.variableName)) {
     253            x[row, col] = ds.GetDoubleValue(info.variableName, r + info.lag);
     254          } else if (ds.VariableHasType<string>(info.variableName)) {
     255            x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
     256          } else throw new InvalidProgramException("found a variable of unknown type");
     257          col++;
    211258        }
    212259        row++;
     
    217264      int k = c.Length;
    218265
    219       alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
    220       alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
     266      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
     267      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
     268      alglib.ndimensional_rep xrep = (p, f, obj) => iterationCallback(p, f, obj);
    221269
    222270      try {
    223271        alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
    224272        alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
     273        alglib.lsfitsetxrep(state, iterationCallback != null);
    225274        //alglib.lsfitsetgradientcheck(state, 0.001);
    226         alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
    227         alglib.lsfitresults(state, out info, out c, out rep);
    228       }
    229       catch (ArithmeticException) {
     275        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, xrep, rowEvaluationsCounter);
     276        alglib.lsfitresults(state, out retVal, out c, out rep);
     277      } catch (ArithmeticException) {
    230278        return originalQuality;
    231       }
    232       catch (alglib.alglibexception) {
     279      } catch (alglib.alglibexception) {
    233280        return originalQuality;
    234281      }
    235282
    236       //info == -7  => constant optimization failed due to wrong gradient
    237       if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
     283      counter.FunctionEvaluations += rowEvaluationsCounter.FunctionEvaluations / n;
     284      counter.GradientEvaluations += rowEvaluationsCounter.GradientEvaluations / n;
     285
     286      //retVal == -7  => constant optimization failed due to wrong gradient
     287      if (retVal != -7) {
     288        if (applyLinearScaling) {
     289          var tmp = new double[c.Length - 2];
     290          Array.Copy(c, 2, tmp, 0, tmp.Length);
     291          UpdateConstants(tree, tmp, updateVariableWeights);
     292        } else UpdateConstants(tree, c, updateVariableWeights);
     293      }
    238294      var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
    239295
    240       if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
     296      if (!updateConstantsInTree) UpdateConstants(tree, initialConstants, updateVariableWeights);
     297
    241298      if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
    242         UpdateConstants(tree, originalConstants.Skip(2).ToArray());
     299        UpdateConstants(tree, initialConstants, updateVariableWeights);
    243300        return originalQuality;
    244301      }
     
    246303    }
    247304
    248     private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
     305    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
    249306      int i = 0;
    250307      foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
    251308        ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
    252         VariableTreeNode variableTreeNode = node as VariableTreeNode;
     309        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
     310        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
    253311        if (constantTreeNode != null)
    254312          constantTreeNode.Value = constants[i++];
    255         else if (variableTreeNode != null)
    256           variableTreeNode.Weight = constants[i++];
    257       }
    258     }
    259 
    260     private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
    261       return (double[] c, double[] x, ref double func, object o) => {
    262         func = compiledFunc.Evaluate(c, x);
     313        else if (updateVariableWeights && variableTreeNodeBase != null)
     314          variableTreeNodeBase.Weight = constants[i++];
     315        else if (factorVarTreeNode != null) {
     316          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
     317            factorVarTreeNode.Weights[j] = constants[i++];
     318        }
     319      }
     320    }
     321
     322    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermConverter.ParametricFunction func) {
     323      return (double[] c, double[] x, ref double fx, object o) => {
     324        fx = func(c, x);
     325        var counter = (EvaluationsCounter)o;
     326        counter.FunctionEvaluations++;
    263327      };
    264328    }
    265329
    266     private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
    267       return (double[] c, double[] x, ref double func, double[] grad, object o) => {
    268         var tupel = compiledFunc.Differentiate(c, x);
    269         func = tupel.Item2;
    270         Array.Copy(tupel.Item1, grad, grad.Length);
     330    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad) {
     331      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
     332        var tuple = func_grad(c, x);
     333        fx = tuple.Item2;
     334        Array.Copy(tuple.Item1, grad, grad.Length);
     335        var counter = (EvaluationsCounter)o;
     336        counter.GradientEvaluations++;
    271337      };
    272338    }
    273 
    274     private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
    275       if (node.Symbol is Constant) {
    276         var var = new AutoDiff.Variable();
    277         variables.Add(var);
    278         term = var;
    279         return true;
    280       }
    281       if (node.Symbol is Variable) {
    282         var varNode = node as VariableTreeNode;
    283         var par = new AutoDiff.Variable();
    284         parameters.Add(par);
    285         variableNames.Add(varNode.VariableName);
    286         var w = new AutoDiff.Variable();
    287         variables.Add(w);
    288         term = AutoDiff.TermBuilder.Product(w, par);
    289         return true;
    290       }
    291       if (node.Symbol is Addition) {
    292         List<AutoDiff.Term> terms = new List<Term>();
    293         foreach (var subTree in node.Subtrees) {
    294           AutoDiff.Term t;
    295           if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
    296             term = null;
    297             return false;
    298           }
    299           terms.Add(t);
    300         }
    301         term = AutoDiff.TermBuilder.Sum(terms);
    302         return true;
    303       }
    304       if (node.Symbol is Subtraction) {
    305         List<AutoDiff.Term> terms = new List<Term>();
    306         for (int i = 0; i < node.SubtreeCount; i++) {
    307           AutoDiff.Term t;
    308           if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
    309             term = null;
    310             return false;
    311           }
    312           if (i > 0) t = -t;
    313           terms.Add(t);
    314         }
    315         term = AutoDiff.TermBuilder.Sum(terms);
    316         return true;
    317       }
    318       if (node.Symbol is Multiplication) {
    319         AutoDiff.Term a, b;
    320         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
    321           !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
    322           term = null;
    323           return false;
    324         } else {
    325           List<AutoDiff.Term> factors = new List<Term>();
    326           foreach (var subTree in node.Subtrees.Skip(2)) {
    327             AutoDiff.Term f;
    328             if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
    329               term = null;
    330               return false;
    331             }
    332             factors.Add(f);
    333           }
    334           term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
    335           return true;
    336         }
    337       }
    338       if (node.Symbol is Division) {
    339         // only works for at least two subtrees
    340         AutoDiff.Term a, b;
    341         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
    342           !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
    343           term = null;
    344           return false;
    345         } else {
    346           List<AutoDiff.Term> factors = new List<Term>();
    347           foreach (var subTree in node.Subtrees.Skip(2)) {
    348             AutoDiff.Term f;
    349             if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
    350               term = null;
    351               return false;
    352             }
    353             factors.Add(1.0 / f);
    354           }
    355           term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
    356           return true;
    357         }
    358       }
    359       if (node.Symbol is Logarithm) {
    360         AutoDiff.Term t;
    361         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    362           term = null;
    363           return false;
    364         } else {
    365           term = AutoDiff.TermBuilder.Log(t);
    366           return true;
    367         }
    368       }
    369       if (node.Symbol is Exponential) {
    370         AutoDiff.Term t;
    371         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    372           term = null;
    373           return false;
    374         } else {
    375           term = AutoDiff.TermBuilder.Exp(t);
    376           return true;
    377         }
    378       }
    379       if (node.Symbol is Square) {
    380         AutoDiff.Term t;
    381         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    382           term = null;
    383           return false;
    384         } else {
    385           term = AutoDiff.TermBuilder.Power(t, 2.0);
    386           return true;
    387         }
    388       } if (node.Symbol is SquareRoot) {
    389         AutoDiff.Term t;
    390         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    391           term = null;
    392           return false;
    393         } else {
    394           term = AutoDiff.TermBuilder.Power(t, 0.5);
    395           return true;
    396         }
    397       } if (node.Symbol is Sine) {
    398         AutoDiff.Term t;
    399         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    400           term = null;
    401           return false;
    402         } else {
    403           term = sin(t);
    404           return true;
    405         }
    406       } if (node.Symbol is Cosine) {
    407         AutoDiff.Term t;
    408         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    409           term = null;
    410           return false;
    411         } else {
    412           term = cos(t);
    413           return true;
    414         }
    415       } if (node.Symbol is Tangent) {
    416         AutoDiff.Term t;
    417         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    418           term = null;
    419           return false;
    420         } else {
    421           term = tan(t);
    422           return true;
    423         }
    424       } if (node.Symbol is Erf) {
    425         AutoDiff.Term t;
    426         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    427           term = null;
    428           return false;
    429         } else {
    430           term = erf(t);
    431           return true;
    432         }
    433       } if (node.Symbol is Norm) {
    434         AutoDiff.Term t;
    435         if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
    436           term = null;
    437           return false;
    438         } else {
    439           term = norm(t);
    440           return true;
    441         }
    442       }
    443       if (node.Symbol is StartSymbol) {
    444         var alpha = new AutoDiff.Variable();
    445         var beta = new AutoDiff.Variable();
    446         variables.Add(beta);
    447         variables.Add(alpha);
    448         AutoDiff.Term branchTerm;
    449         if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
    450           term = branchTerm * alpha + beta;
    451           return true;
    452         } else {
    453           term = null;
    454           return false;
    455         }
    456       }
    457       term = null;
    458       return false;
    459     }
    460 
    461339    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
    462       var containsUnknownSymbol = (
    463         from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
    464         where
    465          !(n.Symbol is Variable) &&
    466          !(n.Symbol is Constant) &&
    467          !(n.Symbol is Addition) &&
    468          !(n.Symbol is Subtraction) &&
    469          !(n.Symbol is Multiplication) &&
    470          !(n.Symbol is Division) &&
    471          !(n.Symbol is Logarithm) &&
    472          !(n.Symbol is Exponential) &&
    473          !(n.Symbol is SquareRoot) &&
    474          !(n.Symbol is Square) &&
    475          !(n.Symbol is Sine) &&
    476          !(n.Symbol is Cosine) &&
    477          !(n.Symbol is Tangent) &&
    478          !(n.Symbol is Erf) &&
    479          !(n.Symbol is Norm) &&
    480          !(n.Symbol is StartSymbol)
    481         select n).
    482       Any();
    483       return !containsUnknownSymbol;
     340      return TreeToAutoDiffTermConverter.IsCompatible(tree);
    484341    }
    485342  }
Note: See TracChangeset for help on using the changeset viewer.