
Timestamp: 07/19/20 19:07:40
Author: fbaching
Message:
#1837: merged changes from trunk

  • apply changes from the Attic release to all SlidingWindow-specific code files (replace StorableClass with StorableType)
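For reference, the StorableClass → StorableType migration applied in this merge follows a fixed pattern, visible in the diff below: the HeuristicLab.Persistence using directive is replaced by HEAL.Attic, the [StorableClass] attribute becomes a [StorableType] attribute carrying a persistent GUID, and the bool-based storable constructor takes a StorableConstructorFlag instead. A minimal sketch of the pattern (MyStorable, its member, and the GUID are placeholders; a reference to the HEAL.Attic package is assumed):

using HEAL.Attic; // replaces: using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;

// [StorableType] pins a GUID to the type so serialized data can still be
// matched to the type after renames; [StorableClass] identified types by name.
[StorableType("00000000-0000-0000-0000-000000000000")] // placeholder; generate a fresh GUID per type
public class MyStorable {
  [Storable]
  public double Value { get; set; }

  [StorableConstructor]
  protected MyStorable(StorableConstructorFlag _) { } // was: protected MyStorable(bool deserializing) { }
  public MyStorable() { }
}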
File: 1 edited

Legend: unchanged lines are prefixed with a space, removed lines with "-", added lines with "+"; "…" marks elided unchanged lines.
  • branches/1837_Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs

r10291 → r17687

 #region License Information
 /* HeuristicLab
- * Copyright (C) 2002-2013 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
  *
  * This file is part of HeuristicLab.
…
 using System.Collections.Generic;
 using System.Linq;
-using AutoDiff;
+using HEAL.Attic;
 using HeuristicLab.Common;
 using HeuristicLab.Core;
 using HeuristicLab.Data;
 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
+using HeuristicLab.Optimization;
 using HeuristicLab.Parameters;
-using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
 
 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
   [Item("Constant Optimization Evaluator", "Calculates Pearson R² of a symbolic regression solution and optimizes the constant used.")]
-  [StorableClass]
+  [StorableType("24B68851-036D-4446-BD6F-3823E9028FF4")]
   public class SymbolicRegressionConstantOptimizationEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
     private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
…
     private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
     private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
+    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
+
+    private const string FunctionEvaluationsResultParameterName = "Constants Optimization Function Evaluations";
+    private const string GradientEvaluationsResultParameterName = "Constants Optimization Gradient Evaluations";
+    private const string CountEvaluationsParameterName = "Count Function and Gradient Evaluations";
 
     public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
…
       get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
     }
+    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
+    }
+
+    public IResultParameter<IntValue> FunctionEvaluationsResultParameter {
+      get { return (IResultParameter<IntValue>)Parameters[FunctionEvaluationsResultParameterName]; }
+    }
+    public IResultParameter<IntValue> GradientEvaluationsResultParameter {
+      get { return (IResultParameter<IntValue>)Parameters[GradientEvaluationsResultParameterName]; }
+    }
+    public IFixedValueParameter<BoolValue> CountEvaluationsParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[CountEvaluationsParameterName]; }
+    }
+
 
     public IntValue ConstantOptimizationIterations {
…
     }
 
+    public bool UpdateVariableWeights {
+      get { return UpdateVariableWeightsParameter.Value.Value; }
+      set { UpdateVariableWeightsParameter.Value.Value = value; }
+    }
+
+    public bool CountEvaluations {
+      get { return CountEvaluationsParameter.Value.Value; }
+      set { CountEvaluationsParameter.Value.Value = value; }
+    }
+
     public override bool Maximization {
       get { return true; }
…
 
     [StorableConstructor]
-    protected SymbolicRegressionConstantOptimizationEvaluator(bool deserializing) : base(deserializing) { }
+    protected SymbolicRegressionConstantOptimizationEvaluator(StorableConstructorFlag _) : base(_) { }
     protected SymbolicRegressionConstantOptimizationEvaluator(SymbolicRegressionConstantOptimizationEvaluator original, Cloner cloner)
       : base(original, cloner) {
…
     public SymbolicRegressionConstantOptimizationEvaluator()
       : base() {
-      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
-      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
-      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
-      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
-      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
+      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
+      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0)) { Hidden = true });
+      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1)));
+      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1)));
+      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
+      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)) { Hidden = true });
+
+      Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
+      Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+      Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
     }
 
…
       if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
         Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
-    }
-
+      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be  optimized.", new BoolValue(true)));
+
+      if (!Parameters.ContainsKey(CountEvaluationsParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
+
+      if (!Parameters.ContainsKey(FunctionEvaluationsResultParameterName))
+        Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+      if (!Parameters.ContainsKey(GradientEvaluationsResultParameterName))
+        Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+    }
+
+    private static readonly object locker = new object();
     public override IOperation InstrumentedApply() {
       var solution = SymbolicExpressionTreeParameter.ActualValue;
…
       if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
         IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
+        var counter = new EvaluationsCounter();
         quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
-           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
-           EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);
+           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree, counter: counter);
 
         if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
…
           quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
         }
+
+        if (CountEvaluations) {
+          lock (locker) {
+            FunctionEvaluationsResultParameter.ActualValue.Value += counter.FunctionEvaluations;
+            GradientEvaluationsResultParameter.ActualValue.Value += counter.GradientEvaluations;
+          }
+        }
+
       } else {
         var evaluationRows = GenerateRowsToEvaluate();
…
       EstimationLimitsParameter.ExecutionContext = context;
       ApplyLinearScalingParameter.ExecutionContext = context;
+      FunctionEvaluationsResultParameter.ExecutionContext = context;
+      GradientEvaluationsResultParameter.ExecutionContext = context;
 
       // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
…
       EstimationLimitsParameter.ExecutionContext = null;
       ApplyLinearScalingParameter.ExecutionContext = null;
+      FunctionEvaluationsResultParameter.ExecutionContext = null;
+      GradientEvaluationsResultParameter.ExecutionContext = null;
 
       return r2;
     }
 
-    #region derivations of functions
-    // create function factory for arctangent
-    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
-      eval: Math.Atan,
-      diff: x => 1 / (1 + x * x));
-    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
-      eval: Math.Sin,
-      diff: Math.Cos);
-    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
-      eval: Math.Cos,
-      diff: x => -Math.Sin(x));
-    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
-      eval: Math.Tan,
-      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
-    private static readonly Func<Term, UnaryFunc> square = UnaryFunc.Factory(
-      eval: x => x * x,
-      diff: x => 2 * x);
-    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
-      eval: alglib.errorfunction,
-      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
-    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
-      eval: alglib.normaldistribution,
-      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
-    #endregion
-
-
-    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
-      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {
-
-      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
-      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
-      List<string> variableNames = new List<string>();
-
-      AutoDiff.Term func;
-      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
+    public class EvaluationsCounter {
+      public int FunctionEvaluations = 0;
+      public int GradientEvaluations = 0;
+    }
+
+    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
+      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
+      int maxIterations, bool updateVariableWeights = true,
+      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
+      bool updateConstantsInTree = true, Action<double[], double, object> iterationCallback = null, EvaluationsCounter counter = null) {
+
+      // numeric constants in the tree become variables for constant opt
+      // variables in the tree become parameters (fixed values) for constant opt
+      // for each parameter (variable in the original tree) we store the
+      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
+      // A dictionary is used to find parameters
+      double[] initialConstants;
+      var parameters = new List<TreeToAutoDiffTermConverter.DataForVariable>();
+
+      TreeToAutoDiffTermConverter.ParametricFunction func;
+      TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad;
+      if (!TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, updateVariableWeights, applyLinearScaling, out parameters, out initialConstants, out func, out func_grad))
         throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
-      if (variableNames.Count == 0) return 0.0;
-
-      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
-
-      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
-      double[] c = new double[variables.Count];
-
-      {
+      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
+      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
+
+      //extract inital constants
+      double[] c;
+      if (applyLinearScaling) {
+        c = new double[initialConstants.Length + 2];
         c[0] = 0.0;
         c[1] = 1.0;
-        //extract inital constants
-        int i = 2;
-        foreach (var node in terminalNodes) {
-          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
-          VariableTreeNode variableTreeNode = node as VariableTreeNode;
-          if (constantTreeNode != null)
-            c[i++] = constantTreeNode.Value;
-          else if (variableTreeNode != null)
-            c[i++] = variableTreeNode.Weight;
-        }
-      }
-      double[] originalConstants = (double[])c.Clone();
+        Array.Copy(initialConstants, 0, c, 2, initialConstants.Length);
+      } else {
+        c = (double[])initialConstants.Clone();
+      }
+
       double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
+
+      if (counter == null) counter = new EvaluationsCounter();
+      var rowEvaluationsCounter = new EvaluationsCounter();
 
       alglib.lsfitstate state;
       alglib.lsfitreport rep;
-      int info;
-
-      Dataset ds = problemData.Dataset;
-      double[,] x = new double[rows.Count(), variableNames.Count];
+      int retVal;
+
+      IDataset ds = problemData.Dataset;
+      double[,] x = new double[rows.Count(), parameters.Count];
       int row = 0;
       foreach (var r in rows) {
-        for (int col = 0; col < variableNames.Count; col++) {
-          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
+        int col = 0;
+        foreach (var info in parameterEntries) {
+          if (ds.VariableHasType<double>(info.variableName)) {
+            x[row, col] = ds.GetDoubleValue(info.variableName, r + info.lag);
+          } else if (ds.VariableHasType<string>(info.variableName)) {
+            x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
+          } else throw new InvalidProgramException("found a variable of unknown type");
+          col++;
         }
         row++;
…
       int k = c.Length;
 
-      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
-      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
+      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
+      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
+      alglib.ndimensional_rep xrep = (p, f, obj) => iterationCallback(p, f, obj);
 
       try {
         alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
         alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
+        alglib.lsfitsetxrep(state, iterationCallback != null);
         //alglib.lsfitsetgradientcheck(state, 0.001);
-        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
-        alglib.lsfitresults(state, out info, out c, out rep);
-      }
-      catch (ArithmeticException) {
+        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, xrep, rowEvaluationsCounter);
+        alglib.lsfitresults(state, out retVal, out c, out rep);
+      } catch (ArithmeticException) {
         return originalQuality;
-      }
-      catch (alglib.alglibexception) {
+      } catch (alglib.alglibexception) {
         return originalQuality;
       }
 
-      //info == -7  => constant optimization failed due to wrong gradient
-      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
+      counter.FunctionEvaluations += rowEvaluationsCounter.FunctionEvaluations / n;
+      counter.GradientEvaluations += rowEvaluationsCounter.GradientEvaluations / n;
+
+      //retVal == -7  => constant optimization failed due to wrong gradient
+      if (retVal != -7) {
+        if (applyLinearScaling) {
+          var tmp = new double[c.Length - 2];
+          Array.Copy(c, 2, tmp, 0, tmp.Length);
+          UpdateConstants(tree, tmp, updateVariableWeights);
+        } else UpdateConstants(tree, c, updateVariableWeights);
+      }
       var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
 
-      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
+      if (!updateConstantsInTree) UpdateConstants(tree, initialConstants, updateVariableWeights);
+
       if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
-        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
+        UpdateConstants(tree, initialConstants, updateVariableWeights);
         return originalQuality;
       }
…
     }
 
-    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
+    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
       int i = 0;
       foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
         ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
-        VariableTreeNode variableTreeNode = node as VariableTreeNode;
+        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
+        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
         if (constantTreeNode != null)
           constantTreeNode.Value = constants[i++];
-        else if (variableTreeNode != null)
-          variableTreeNode.Weight = constants[i++];
-      }
-    }
-
-    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, object o) => {
-        func = compiledFunc.Evaluate(c, x);
+        else if (updateVariableWeights && variableTreeNodeBase != null)
+          variableTreeNodeBase.Weight = constants[i++];
+        else if (factorVarTreeNode != null) {
+          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
+            factorVarTreeNode.Weights[j] = constants[i++];
+        }
+      }
+    }
+
+    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermConverter.ParametricFunction func) {
+      return (double[] c, double[] x, ref double fx, object o) => {
+        fx = func(c, x);
+        var counter = (EvaluationsCounter)o;
+        counter.FunctionEvaluations++;
       };
     }
 
-    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
-        var tupel = compiledFunc.Differentiate(c, x);
-        func = tupel.Item2;
-        Array.Copy(tupel.Item1, grad, grad.Length);
+    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad) {
+      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
+        var tuple = func_grad(c, x);
+        fx = tuple.Item2;
+        Array.Copy(tuple.Item1, grad, grad.Length);
+        var counter = (EvaluationsCounter)o;
+        counter.GradientEvaluations++;
       };
     }
-
-    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
-      if (node.Symbol is Constant) {
-        var var = new AutoDiff.Variable();
-        variables.Add(var);
-        term = var;
-        return true;
-      }
-      if (node.Symbol is Variable) {
-        var varNode = node as VariableTreeNode;
-        var par = new AutoDiff.Variable();
-        parameters.Add(par);
-        variableNames.Add(varNode.VariableName);
-        var w = new AutoDiff.Variable();
-        variables.Add(w);
-        term = AutoDiff.TermBuilder.Product(w, par);
-        return true;
-      }
-      if (node.Symbol is Addition) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        foreach (var subTree in node.Subtrees) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
-            term = null;
-            return false;
-          }
-          terms.Add(t);
-        }
-        term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Subtraction) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        for (int i = 0; i < node.SubtreeCount; i++) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
-            term = null;
-            return false;
-          }
-          if (i > 0) t = -t;
-          terms.Add(t);
-        }
-        term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Multiplication) {
-        AutoDiff.Term a, b;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
-          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
-          term = null;
-          return false;
-        } else {
-          List<AutoDiff.Term> factors = new List<Term>();
-          foreach (var subTree in node.Subtrees.Skip(2)) {
-            AutoDiff.Term f;
-            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
-              term = null;
-              return false;
-            }
-            factors.Add(f);
-          }
-          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
-          return true;
-        }
-      }
-      if (node.Symbol is Division) {
-        // only works for at least two subtrees
-        AutoDiff.Term a, b;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
-          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
-          term = null;
-          return false;
-        } else {
-          List<AutoDiff.Term> factors = new List<Term>();
-          foreach (var subTree in node.Subtrees.Skip(2)) {
-            AutoDiff.Term f;
-            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
-              term = null;
-              return false;
-            }
-            factors.Add(1.0 / f);
-          }
-          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
-          return true;
-        }
-      }
-      if (node.Symbol is Logarithm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Log(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Exponential) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Exp(t);
-          return true;
-        }
-      } if (node.Symbol is Sine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = sin(t);
-          return true;
-        }
-      } if (node.Symbol is Cosine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = cos(t);
-          return true;
-        }
-      } if (node.Symbol is Tangent) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = tan(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Square) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = square(t);
-          return true;
-        }
-      } if (node.Symbol is Erf) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = erf(t);
-          return true;
-        }
-      } if (node.Symbol is Norm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = norm(t);
-          return true;
-        }
-      }
-      if (node.Symbol is StartSymbol) {
-        var alpha = new AutoDiff.Variable();
-        var beta = new AutoDiff.Variable();
-        variables.Add(beta);
-        variables.Add(alpha);
-        AutoDiff.Term branchTerm;
-        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
-          term = branchTerm * alpha + beta;
-          return true;
-        } else {
-          term = null;
-          return false;
-        }
-      }
-      term = null;
-      return false;
-    }
-
     public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
-      var containsUnknownSymbol = (
-        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
-        where
-         !(n.Symbol is Variable) &&
-         !(n.Symbol is Constant) &&
-         !(n.Symbol is Addition) &&
-         !(n.Symbol is Subtraction) &&
-         !(n.Symbol is Multiplication) &&
-         !(n.Symbol is Division) &&
-         !(n.Symbol is Logarithm) &&
-         !(n.Symbol is Exponential) &&
-         !(n.Symbol is Sine) &&
-         !(n.Symbol is Cosine) &&
-         !(n.Symbol is Tangent) &&
-         !(n.Symbol is Square) &&
-         !(n.Symbol is Erf) &&
-         !(n.Symbol is Norm) &&
-         !(n.Symbol is StartSymbol)
-        select n).
-      Any();
-      return !containsUnknownSymbol;
+      return TreeToAutoDiffTermConverter.IsCompatible(tree);
     }
   }
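A note on the evaluation-counting mechanism introduced above: alglib's callback delegates accept an opaque obj argument, and the changed code passes rowEvaluationsCounter through it so the lambdas built in CreatePFunc and CreatePGrad can increment the counts from inside the fitting loop; because alglib invokes these callbacks once per data row, OptimizeConstants divides the raw counts by n to obtain whole-fit evaluation counts. The following self-contained sketch illustrates the same pass-through pattern; FitCallback and RunFit are illustrative stand-ins, not alglib APIs:

using System;

public class EvaluationsCounter {
  public int FunctionEvaluations = 0;
  public int GradientEvaluations = 0;
}

// Stand-in for alglib.ndimensional_pfunc: the trailing object parameter is
// handed back unchanged to every callback invocation.
public delegate void FitCallback(double[] c, double[] x, ref double fx, object obj);

public static class CounterPatternDemo {
  // Hypothetical driver that, like lsfitfit, calls the callback once per row.
  static void RunFit(FitCallback cb, double[][] rows, double[] c, object obj) {
    foreach (var x in rows) { double fx = 0.0; cb(c, x, ref fx, obj); }
  }

  public static void Main() {
    var counter = new EvaluationsCounter();
    FitCallback f = (double[] c, double[] x, ref double fx, object o) => {
      fx = c[0] * x[0];                              // dummy model evaluation
      ((EvaluationsCounter)o).FunctionEvaluations++; // same cast as in CreatePFunc
    };
    var rows = new[] { new[] { 1.0 }, new[] { 2.0 }, new[] { 3.0 } };
    RunFit(f, rows, new[] { 0.5 }, counter);
    Console.WriteLine(counter.FunctionEvaluations); // 3 (one per row)
  }
}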