Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/08/17 19:51:31 (7 years ago)
Author:
gkronber
Message:

#2789: worked on nonlinear regression with constraints

File:
1 moved

Legend:

Unmodified
Added
Removed
  • branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/TreeToDiffSharpConverter.cs

    r15312 r15313  
    2424using System.Linq;
    2525using System.Runtime.Serialization;
    26 using AutoDiff;
    2726using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     27using HeuristicLab.Problems.DataAnalysis.Symbolic;
     28using DiffSharp.Interop.Float64;
     29using System.Linq.Expressions;
     30using System.Reflection;
    2831
    2932namespace HeuristicLab.Algorithms.DataAnalysis.Experimental {
    30   public class TreeToAutoDiffTermConverter {
    31     public delegate double ParametricFunction(double[] vars, double[] @params);
    32 
    33     public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params);
     33  public class TreeToDiffSharpConverter {
     34    public delegate double ParametricFunction(double[] vars);
     35
     36    public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars);
    3437
    3538    #region helper class
     
    5962    #endregion
    6063
    61     #region derivations of functions
    62     // create function factory for arctangent
    63     private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
    64       eval: Math.Atan,
    65       diff: x => 1 / (1 + x * x));
    66 
    67     private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
    68       eval: Math.Sin,
    69       diff: Math.Cos);
    70 
    71     private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
    72       eval: Math.Cos,
    73       diff: x => -Math.Sin(x));
    74 
    75     private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
    76       eval: Math.Tan,
    77       diff: x => 1 + Math.Tan(x) * Math.Tan(x));
    78 
    79     private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
    80       eval: alglib.errorfunction,
    81       diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
    82 
    83     private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
    84       eval: alglib.normaldistribution,
    85       diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
    86 
    87     #endregion
    88 
    89     public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
     64
     65    public static bool TryConvertToDiffSharp(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable,
    9066      out List<DataForVariable> parameters, out double[] initialConstants,
    91       out ParametricFunction func,
    92       out ParametricFunctionGradient func_grad,
    93       out ParametricFunctionGradient func_grad_for_vars) {
     67      out Func<DV, D> func) {
    9468
    9569      // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree
    96       var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable);
    97       AutoDiff.Term term;
     70      var transformator = new TreeToDiffSharpConverter(makeVariableWeightsVariable);
    9871      try {
    99         term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0));
     72
     73        // the list of variable names represents the names for dv[0] ... dv[d-1] where d is the number of input variables
     74        // the remaining entries of dv represent the parameter values
     75        transformator.ExtractParameters(tree.Root.GetSubtree(0));
     76
     77        var lambda = transformator.CreateDelegate(tree, transformator.parameters);
     78        func = lambda.Compile();
     79
    10080        var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values
    101         var compiledTerm = term.Compile(
    102           transformator.variables.ToArray(),
    103           parameterEntries.Select(kvp => kvp.Value).ToArray());
    104        
    10581        parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key));
    10682        initialConstants = transformator.initialConstants.ToArray();
    107         func = (vars, @params) => compiledTerm.Evaluate(vars, @params);
    108         func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);
    109         func_grad_for_vars = (vars, @params) => compiledTerm.Differentiate(@params,vars);
    11083        return true;
    11184      } catch (ConversionException) {
    11285        func = null;
    113         func_grad = null;
    114         func_grad_for_vars = null;
    11586        parameters = null;
    11687        initialConstants = null;
     
    11990    }
    12091
     92    public Expression<Func<DV, D>> CreateDelegate(ISymbolicExpressionTree tree, Dictionary<DataForVariable, int> parameters) {
     93      paramIdx = parameters.Count; // first non-variable parameter
     94      var dv = Expression.Parameter(typeof(DV));
     95      var expr = MakeExpr(tree.Root.GetSubtree(0), parameters, dv);
     96      var lambda = Expression.Lambda<Func<DV, D>>(expr, dv);
     97      return lambda;
     98    }
     99
    121100    // state for recursive transformation of trees
    122     private readonly
    123     List<double> initialConstants;
    124     private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters;
    125     private readonly List<AutoDiff.Variable> variables;
     101    private readonly List<double> initialConstants;
     102    private readonly Dictionary<DataForVariable, int> parameters;
    126103    private readonly bool makeVariableWeightsVariable;
    127 
    128     private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) {
     104    private int paramIdx;
     105
     106    private TreeToDiffSharpConverter(bool makeVariableWeightsVariable) {
    129107      this.makeVariableWeightsVariable = makeVariableWeightsVariable;
    130108      this.initialConstants = new List<double>();
    131       this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>();
    132       this.variables = new List<AutoDiff.Variable>();
    133     }
    134 
    135     private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) {
    136       if (node.Symbol is Constant) {
     109      this.parameters = new Dictionary<DataForVariable, int>();
     110    }
     111
     112    private void ExtractParameters(ISymbolicExpressionTreeNode node) {
     113      if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) {
    137114        initialConstants.Add(((ConstantTreeNode)node).Value);
    138         var var = new AutoDiff.Variable();
    139         variables.Add(var);
    140         return var;
    141       }
    142       if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) {
     115      } else if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable || node.Symbol is BinaryFactorVariable) {
     116        var varNode = node as VariableTreeNodeBase;
     117        var factorVarNode = node as BinaryFactorVariableTreeNode;
     118        // factor variable values are only 0 or 1 and set in x accordingly
     119        var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty;
     120        FindOrCreateParameter(parameters, varNode.VariableName, varValue);
     121
     122        if (makeVariableWeightsVariable) {
     123          initialConstants.Add(varNode.Weight);
     124        }
     125      } else if (node.Symbol is FactorVariable) {
     126        var factorVarNode = node as FactorVariableTreeNode;
     127        var products = new List<D>();
     128        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
     129          FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
     130
     131          initialConstants.Add(factorVarNode.GetValue(variableValue));
     132        }
     133      } else if (node.Symbol is LaggedVariable) {
     134        var varNode = node as LaggedVariableTreeNode;
     135        FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
     136
     137        if (makeVariableWeightsVariable) {
     138          initialConstants.Add(varNode.Weight);
     139        }
     140      } else if (node.Symbol is Addition) {
     141        foreach (var subTree in node.Subtrees) {
     142          ExtractParameters(subTree);
     143        }
     144      } else if (node.Symbol is Subtraction) {
     145        for (int i = 0; i < node.SubtreeCount; i++) {
     146          ExtractParameters(node.GetSubtree(i));
     147        }
     148      } else if (node.Symbol is Multiplication) {
     149        foreach (var subTree in node.Subtrees) {
     150          ExtractParameters(subTree);
     151        }
     152      } else if (node.Symbol is Division) {
     153        foreach (var subTree in node.Subtrees) {
     154          ExtractParameters(subTree);
     155        }
     156      } else if (node.Symbol is Logarithm) {
     157        ExtractParameters(node.GetSubtree(0));
     158      } else if (node.Symbol is Exponential) {
     159        ExtractParameters(node.GetSubtree(0));
     160      } else if (node.Symbol is Square) {
     161        ExtractParameters(node.GetSubtree(0));
     162      } else if (node.Symbol is SquareRoot) {
     163        ExtractParameters(node.GetSubtree(0));
     164      } else if (node.Symbol is Sine) {
     165        ExtractParameters(node.GetSubtree(0));
     166      } else if (node.Symbol is Cosine) {
     167        ExtractParameters(node.GetSubtree(0));
     168      } else if (node.Symbol is Tangent) {
     169        ExtractParameters(node.GetSubtree(0));
     170      } else if (node.Symbol is StartSymbol) {
     171        ExtractParameters(node.GetSubtree(0));
     172      } else throw new ConversionException();
     173    }
     174
     175    private Func<DV, D> CreateDiffSharpFunc(ISymbolicExpressionTreeNode node, Dictionary<DataForVariable, int> parameters) {
     176      this.paramIdx = parameters.Count; // first idx of non-variable parameter     
     177      var f = CreateDiffSharpFunc(node, parameters);
     178      return (DV paramValues) => f(paramValues);
     179    }
     180
     181    private static readonly MethodInfo DvIndexer = typeof(DV).GetMethod("get_Item", new[] { typeof(int) }); // reflection handles used by MakeExpr to emit calls on DV/D; this one is the DV[int] indexer getter
     182    private static readonly MethodInfo d_Add_d = typeof(D).GetMethod("op_Addition", new[] { typeof(D), typeof(D) }); // D + D
     183    private static readonly MethodInfo d_Neg = typeof(D).GetMethod("Neg", new[] { typeof(D) }); // Neg(D), used for unary subtraction
     184    private static readonly MethodInfo d_Mul_d = typeof(D).GetMethod("op_Multiply", new[] { typeof(D), typeof(D) }); // D * D
     185    private static readonly MethodInfo d_Mul_f = typeof(D).GetMethod("op_Multiply", new[] { typeof(D), typeof(double) }); // D * double
     186    private static readonly MethodInfo d_Div_d = typeof(D).GetMethod("op_Division", new[] { typeof(D), typeof(D) }); // D / D
     187    private static readonly MethodInfo f_Div_d = typeof(D).GetMethod("op_Division", new[] { typeof(double), typeof(D) }); // double / D, used for the single-argument division (1.0 / x)
     188    private static readonly MethodInfo d_Sub_d = typeof(D).GetMethod("op_Subtraction", new[] { typeof(D), typeof(D) }); // D - D
     189    private static readonly MethodInfo d_Pow_f = typeof(D).GetMethod("Pow", new[] { typeof(D), typeof(double) }); // Pow(D, double), used for Square (2.0) and SquareRoot (0.5)
     190    private static readonly MethodInfo d_Log = typeof(D).GetMethod("Log", new[] { typeof(D) }); // Log(D)
     191    private static readonly MethodInfo d_Exp = typeof(D).GetMethod("Exp", new[] { typeof(D) }); // Exp(D)
     192
     193
     194
     195    private Expression MakeExpr(ISymbolicExpressionTreeNode node, Dictionary<DataForVariable, int> parameters, ParameterExpression dv) {
     196      if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) {
     197        return Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     198      }
     199      if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable || node.Symbol is BinaryFactorVariable) {
    143200        var varNode = node as VariableTreeNodeBase;
    144201        var factorVarNode = node as BinaryFactorVariableTreeNode;
     
    148205
    149206        if (makeVariableWeightsVariable) {
    150           initialConstants.Add(varNode.Weight);
    151           var w = new AutoDiff.Variable();
    152           variables.Add(w);
    153           return AutoDiff.TermBuilder.Product(w, par);
     207          var w = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     208          var v = Expression.Call(dv, DvIndexer, Expression.Constant(par));
     209          return Expression.Call(d_Mul_d, w, v);
    154210        } else {
    155           return varNode.Weight * par;
     211          var w = Expression.Constant(varNode.Weight);
     212          var v = Expression.Call(dv, DvIndexer, Expression.Constant(par));
     213          return Expression.Call(d_Mul_f, v, w);
    156214        }
    157215      }
    158216      if (node.Symbol is FactorVariable) {
    159217        var factorVarNode = node as FactorVariableTreeNode;
    160         var products = new List<Term>();
    161         foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) {
     218        var products = new List<D>();
     219        var firstValue = factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName).First();
     220        var parForFirstValue = FindOrCreateParameter(parameters, factorVarNode.VariableName, firstValue);
     221        var weightForFirstValue = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     222        var valForFirstValue = Expression.Call(dv, DvIndexer, Expression.Constant(parForFirstValue));
     223        var res = Expression.Call(d_Mul_d, weightForFirstValue, valForFirstValue);
     224
     225        foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName).Skip(1)) {
    162226          var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue);
    163 
    164           initialConstants.Add(factorVarNode.GetValue(variableValue));
    165           var wVar = new AutoDiff.Variable();
    166           variables.Add(wVar);
    167 
    168           products.Add(AutoDiff.TermBuilder.Product(wVar, par));
    169         }
    170         return AutoDiff.TermBuilder.Sum(products);
    171       }
    172       if (node.Symbol is LaggedVariable) {
    173         var varNode = node as LaggedVariableTreeNode;
    174         var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
    175 
    176         if (makeVariableWeightsVariable) {
    177           initialConstants.Add(varNode.Weight);
    178           var w = new AutoDiff.Variable();
    179           variables.Add(w);
    180           return AutoDiff.TermBuilder.Product(w, par);
     227     
     228          var weight = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     229          var v = Expression.Call(dv, DvIndexer, Expression.Constant(par));
     230
     231          res = Expression.Call(d_Add_d, res, Expression.Call(d_Mul_d, weight, v));
     232        }
     233        return res;
     234      }
     235      // if (node.Symbol is LaggedVariable) {
     236      //   var varNode = node as LaggedVariableTreeNode;
     237      //   var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);
     238      //
     239      //   if (makeVariableWeightsVariable) {
     240      //     initialConstants.Add(varNode.Weight);
     241      //     var w = paramValues[paramIdx++];
     242      //     return w * paramValues[par];
     243      //   } else {
     244      //     return varNode.Weight * paramValues[par];
     245      //   }
     246      // }
     247      if (node.Symbol is Addition) {
     248        var f = MakeExpr(node.Subtrees.First(), parameters, dv);
     249
     250        foreach (var subTree in node.Subtrees.Skip(1)) {
     251          f = Expression.Call(d_Add_d, f, MakeExpr(subTree, parameters, dv));
     252        }
     253        return f;
     254      }
     255      if (node.Symbol is Subtraction) {
     256        if (node.SubtreeCount == 1) {
     257          return Expression.Call(d_Neg, MakeExpr(node.Subtrees.First(), parameters, dv));
    181258        } else {
    182           return varNode.Weight * par;
    183         }
    184       }
    185       if (node.Symbol is Addition) {
    186         List<AutoDiff.Term> terms = new List<Term>();
    187         foreach (var subTree in node.Subtrees) {
    188           terms.Add(ConvertToAutoDiff(subTree));
    189         }
    190         return AutoDiff.TermBuilder.Sum(terms);
    191       }
    192       if (node.Symbol is Subtraction) {
    193         List<AutoDiff.Term> terms = new List<Term>();
    194         for (int i = 0; i < node.SubtreeCount; i++) {
    195           AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i));
    196           if (i > 0) t = -t;
    197           terms.Add(t);
    198         }
    199         if (terms.Count == 1) return -terms[0];
    200         else return AutoDiff.TermBuilder.Sum(terms);
     259          var f = MakeExpr(node.Subtrees.First(), parameters, dv);
     260
     261          foreach (var subTree in node.Subtrees.Skip(1)) {
     262            f = Expression.Call(d_Sub_d, f, MakeExpr(subTree, parameters, dv));
     263          }
     264          return f;
     265        }
    201266      }
    202267      if (node.Symbol is Multiplication) {
    203         List<AutoDiff.Term> terms = new List<Term>();
    204         foreach (var subTree in node.Subtrees) {
    205           terms.Add(ConvertToAutoDiff(subTree));
    206         }
    207         if (terms.Count == 1) return terms[0];
    208         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b));
     268        var f = MakeExpr(node.Subtrees.First(), parameters, dv);
     269        foreach (var subTree in node.Subtrees.Skip(1)) {
     270          f = Expression.Call(d_Mul_d, f, MakeExpr(subTree, parameters, dv));
     271        }
     272        return f;
    209273      }
    210274      if (node.Symbol is Division) {
    211         List<AutoDiff.Term> terms = new List<Term>();
    212         foreach (var subTree in node.Subtrees) {
    213           terms.Add(ConvertToAutoDiff(subTree));
    214         }
    215         if (terms.Count == 1) return 1.0 / terms[0];
    216         else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b));
     275        if (node.SubtreeCount == 1) {
     276          return Expression.Call(f_Div_d, Expression.Constant(1.0), MakeExpr(node.Subtrees.First(), parameters, dv));
     277        } else {
     278          var f = MakeExpr(node.Subtrees.First(), parameters, dv);
     279
     280          foreach (var subTree in node.Subtrees.Skip(1)) {
     281            f = Expression.Call(d_Div_d, f, MakeExpr(subTree, parameters, dv));
     282          }
     283          return f;
     284        }
    217285      }
    218286      if (node.Symbol is Logarithm) {
    219         return AutoDiff.TermBuilder.Log(
    220           ConvertToAutoDiff(node.GetSubtree(0)));
     287        return Expression.Call(d_Log, MakeExpr(node.GetSubtree(0), parameters, dv));
    221288      }
    222289      if (node.Symbol is Exponential) {
    223         return AutoDiff.TermBuilder.Exp(
    224           ConvertToAutoDiff(node.GetSubtree(0)));
     290        return Expression.Call(d_Exp, MakeExpr(node.GetSubtree(0), parameters, dv));
    225291      }
    226292      if (node.Symbol is Square) {
    227         return AutoDiff.TermBuilder.Power(
    228           ConvertToAutoDiff(node.GetSubtree(0)), 2.0);
     293        return Expression.Call(d_Pow_f, MakeExpr(node.GetSubtree(0), parameters, dv), Expression.Constant(2.0));
    229294      }
    230295      if (node.Symbol is SquareRoot) {
    231         return AutoDiff.TermBuilder.Power(
    232           ConvertToAutoDiff(node.GetSubtree(0)), 0.5);
    233       }
    234       if (node.Symbol is Sine) {
    235         return sin(
    236           ConvertToAutoDiff(node.GetSubtree(0)));
    237       }
    238       if (node.Symbol is Cosine) {
    239         return cos(
    240           ConvertToAutoDiff(node.GetSubtree(0)));
    241       }
    242       if (node.Symbol is Tangent) {
    243         return tan(
    244           ConvertToAutoDiff(node.GetSubtree(0)));
    245       }
    246       if (node.Symbol is Erf) {
    247         return erf(
    248           ConvertToAutoDiff(node.GetSubtree(0)));
    249       }
    250       if (node.Symbol is Norm) {
    251         return norm(
    252           ConvertToAutoDiff(node.GetSubtree(0)));
    253       }
     296        return Expression.Call(d_Pow_f, MakeExpr(node.GetSubtree(0), parameters, dv), Expression.Constant(0.5));
     297      }
     298      // if (node.Symbol is Sine) {
     299      //   return AD.Sin(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues));
     300      // }
     301      // if (node.Symbol is Cosine) {
     302      //   return AD.Cos(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues));
     303      // }
     304      // if (node.Symbol is Tangent) {
     305      //   return AD.Tan(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues));
     306      // }
    254307      if (node.Symbol is StartSymbol) {
    255         var alpha = new AutoDiff.Variable();
    256         var beta = new AutoDiff.Variable();
    257         variables.Add(beta);
    258         variables.Add(alpha);
    259         return ConvertToAutoDiff(node.GetSubtree(0)) * alpha + beta;
     308        var alpha = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     309        var beta = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++));
     310
     311        return Expression.Call(d_Add_d, beta,
     312          Expression.Call(d_Mul_d, alpha, MakeExpr(node.GetSubtree(0), parameters, dv)));
    260313      }
    261314      throw new ConversionException();
     
    265318    // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination
    266319    // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available.
    267     private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,
     320    private static int FindOrCreateParameter(Dictionary<DataForVariable, int> parameters,
    268321      string varName, string varValue = "", int lag = 0) {
    269322      var data = new DataForVariable(varName, varValue, lag);
    270 
    271       AutoDiff.Variable par = null;
    272       if (!parameters.TryGetValue(data, out par)) {
    273         // not found -> create new parameter and entries in names and values lists
    274         par = new AutoDiff.Variable();
    275         parameters.Add(data, par);
    276       }
    277       return par;
     323      int idx = -1;
     324      if (parameters.TryGetValue(data, out idx)) return idx;
     325      else parameters[data] = parameters.Count;
     326      return idx;
    278327    }
    279328
     
    282331        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
    283332        where
    284           !(n.Symbol is Variable) &&
     333          !(n.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable) &&
    285334          !(n.Symbol is BinaryFactorVariable) &&
    286335          !(n.Symbol is FactorVariable) &&
    287336          !(n.Symbol is LaggedVariable) &&
    288           !(n.Symbol is Constant) &&
     337          !(n.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) &&
    289338          !(n.Symbol is Addition) &&
    290339          !(n.Symbol is Subtraction) &&
     
    298347          !(n.Symbol is Cosine) &&
    299348          !(n.Symbol is Tangent) &&
    300           !(n.Symbol is Erf) &&
    301           !(n.Symbol is Norm) &&
    302349          !(n.Symbol is StartSymbol)
    303350        select n).Any();
Note: See TracChangeset for help on using the changeset viewer.