Changeset 16463


Ignore:
Timestamp:
12/28/18 17:57:13 (4 months ago)
Author:
mkommend
Message:

#2974: Adapted tree to autodiff converter.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs

    r16461 r16463  
    133133    public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms, Dictionary<DataForVariable, AutoDiff.Variable> parameters,
    134134      out ParametricFunction func,
    135       out ParametricFunctionGradient func_grad
     135      out ParametricFunctionGradient func_grad,
     136      out double[] initialConstants
    136137  ) {
    137 
    138138      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    139139      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, parameters);
    140140      AutoDiff.Term term;
    141141      try {
    142         term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
     142
    143143
    144144        if (addLinearScalingTerms) {
     
    146146          var alpha = new AutoDiff.Variable();
    147147          var beta = new AutoDiff.Variable();
    148           transformator.variables.Insert(0, alpha);
    149           transformator.variables.Insert(0, beta);
    150 
     148          transformator.variables.Add(beta);
     149          transformator.variables.Add(alpha);
     150          term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
    151151          term = term * alpha + beta;
    152         }
    153        
    154         var compiledTerm = term.Compile(transformator.variables.ToArray(), parameters.Select(kvp => kvp.Value).ToArray());
    155 
    156 
     152        } else {
     153          term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
     154        }
     155
     156        var compiledTerm = term.Compile(transformator.variables.ToArray(), parameters.Values.ToArray());
    157157        func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
    158158        func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
     159        initialConstants = transformator.initialConstants.ToArray();
     160
    159161        return true;
    160162      } catch (ConversionException) {
    161163        func = null;
    162164        func_grad = null;
     165        initialConstants = null;
    163166      }
    164167      return false;
Note: See TracChangeset for help on using the changeset viewer.