Changeset 16457


Ignore:
Timestamp:
12/28/18 10:18:39 (8 weeks ago)
Author:
mkommend
Message:

#2974: Extracted linear scaling terms in auto diff converter.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2974_Constants_Optimization/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs

    r16360 r16457  
    9898
    9999      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    100       var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms);
     100      var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
    101101      AutoDiff.Term term;
    102102      try {
    103103        term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
     104
     105        if (addLinearScalingTerms) {
     106          // scaling variables α, β are given at the beginning of the parameter vector
     107          var alpha = new AutoDiff.Variable();
     108          var beta = new AutoDiff.Variable();
     109          transformator.variables.Insert(0, alpha);
     110          transformator.variables.Insert(0, beta);
     111
     112          term = term * alpha + beta;
     113        }
     114
    104115        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
    105116        var compiledTerm = term.Compile(transformator.variables.ToArray(),
    106117          parameterEntries.Select(kvp => kvp.Value).ToArray());
     118
    107119        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
    108120        initialConstants = transformator.initialConstants.ToArray();
     
    111123        return true;
    112124      } catch (ConversionException) {
     125        parameters = null;
     126        initialConstants = null;
    113127        func = null;
    114128        func_grad = null;
    115         parameters = null;
    116         initialConstants = null;
    117129      }
    118130      return false;
     
    122134    private readonly
    123135    List<double> initialConstants;
    124     private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
     136    private Dictionary<DataForVariable, AutoDiff.Variable> parameters;
    125137    private readonly List<AutoDiff.Variable> variables;
    126138    private readonly bool makeVariableWeightsVariable;
    127     private readonly bool addLinearScalingTerms;
    128 
    129     private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) {
     139
     140    private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
    130141      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    131       this.addLinearScalingTerms = addLinearScalingTerms;
    132142      this.initialConstants = new List<double>();
    133143      this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
     
    249259      if (node.Symbol is CubeRoot) {
    250260        return AutoDiff.TermBuilder.Power(
    251           ConvertToAutoDiff(node.GetSubtree(0)), 1.0/3.0);
     261          ConvertToAutoDiff(node.GetSubtree(0)), 1.0 / 3.0);
    252262      }
    253263      if (node.Symbol is Sine) {
     
    272282      }
    273283      if (node.Symbol is StartSymbol) {
    274         if (addLinearScalingTerms) {
    275           // scaling variables α, β are given at the beginning of the parameter vector
    276           var alpha = new AutoDiff.Variable();
    277           var beta = new AutoDiff.Variable();
    278           variables.Add(beta);
    279           variables.Add(alpha);
    280           var t = ConvertToAutoDiff(node.GetSubtree(0));
    281           return t * alpha + beta;
    282         } else return ConvertToAutoDiff(node.GetSubtree(0));
     284        return ConvertToAutoDiff(node.GetSubtree(0));
    283285      }
    284286      throw new ConversionException();
Note: See TracChangeset for help on using the changeset viewer.