Changeset 15313 for branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/TreeToDiffSharpConverter.cs
- Timestamp: 08/08/17 19:51:31 (7 years ago)
- File: 1 moved
Legend:
- Unmodified
- Added
- Removed
branches/MathNetNumerics-Exploration-2789/HeuristicLab.Algorithms.DataAnalysis.Experimental/TreeToDiffSharpConverter.cs
r15312 r15313 24 24 using System.Linq; 25 25 using System.Runtime.Serialization; 26 using AutoDiff;27 26 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 27 using HeuristicLab.Problems.DataAnalysis.Symbolic; 28 using DiffSharp.Interop.Float64; 29 using System.Linq.Expressions; 30 using System.Reflection; 28 31 29 32 namespace HeuristicLab.Algorithms.DataAnalysis.Experimental { 30 public class TreeTo AutoDiffTermConverter {31 public delegate double ParametricFunction(double[] vars , double[] @params);32 33 public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars , double[] @params);33 public class TreeToDiffSharpConverter { 34 public delegate double ParametricFunction(double[] vars); 35 36 public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars); 34 37 35 38 #region helper class … … 59 62 #endregion 60 63 61 #region derivations of functions 62 // create function factory for arctangent 63 private static readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory( 64 eval: Math.Atan, 65 diff: x => 1 / (1 + x * x)); 66 67 private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory( 68 eval: Math.Sin, 69 diff: Math.Cos); 70 71 private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory( 72 eval: Math.Cos, 73 diff: x => -Math.Sin(x)); 74 75 private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory( 76 eval: Math.Tan, 77 diff: x => 1 + Math.Tan(x) * Math.Tan(x)); 78 79 private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory( 80 eval: alglib.errorfunction, 81 diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI)); 82 83 private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory( 84 eval: alglib.normaldistribution, 85 diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI)); 86 87 #endregion 88 89 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, 64 65 public static 
bool TryConvertToDiffSharp(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, 90 66 out List<DataForVariable> parameters, out double[] initialConstants, 91 out ParametricFunction func, 92 out ParametricFunctionGradient func_grad, 93 out ParametricFunctionGradient func_grad_for_vars) { 67 out Func<DV, D> func) { 94 68 95 69 // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree 96 var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable); 97 AutoDiff.Term term; 70 var transformator = new TreeToDiffSharpConverter(makeVariableWeightsVariable); 98 71 try { 99 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)); 72 73 // the list of variable names represents the names for dv[0] ... dv[d-1] where d is the number of input variables 74 // the remaining entries of d represent the parameter values 75 transformator.ExtractParameters(tree.Root.GetSubtree(0)); 76 77 var lambda = transformator.CreateDelegate(tree, transformator.parameters); 78 func = lambda.Compile(); 79 100 80 var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values 101 var compiledTerm = term.Compile(102 transformator.variables.ToArray(),103 parameterEntries.Select(kvp => kvp.Value).ToArray());104 105 81 parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key)); 106 82 initialConstants = transformator.initialConstants.ToArray(); 107 func = (vars, @params) => compiledTerm.Evaluate(vars, @params);108 func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params);109 func_grad_for_vars = (vars, @params) => compiledTerm.Differentiate(@params,vars);110 83 return true; 111 84 } catch (ConversionException) { 112 85 func = null; 113 func_grad = null;114 func_grad_for_vars = null;115 86 parameters = null; 116 87 initialConstants = null; … … 119 90 } 120 91 92 public Expression<Func<DV, D>> 
CreateDelegate(ISymbolicExpressionTree tree, Dictionary<DataForVariable, int> parameters) { 93 paramIdx = parameters.Count; // first non-variable parameter 94 var dv = Expression.Parameter(typeof(DV)); 95 var expr = MakeExpr(tree.Root.GetSubtree(0), parameters, dv); 96 var lambda = Expression.Lambda<Func<DV, D>>(expr, dv); 97 return lambda; 98 } 99 121 100 // state for recursive transformation of trees 122 private readonly 123 List<double> initialConstants; 124 private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters; 125 private readonly List<AutoDiff.Variable> variables; 101 private readonly List<double> initialConstants; 102 private readonly Dictionary<DataForVariable, int> parameters; 126 103 private readonly bool makeVariableWeightsVariable; 127 128 private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable) { 104 private int paramIdx; 105 106 private TreeToDiffSharpConverter(bool makeVariableWeightsVariable) { 129 107 this.makeVariableWeightsVariable = makeVariableWeightsVariable; 130 108 this.initialConstants = new List<double>(); 131 this.parameters = new Dictionary<DataForVariable, AutoDiff.Variable>(); 132 this.variables = new List<AutoDiff.Variable>(); 133 } 134 135 private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) { 136 if (node.Symbol is Constant) { 109 this.parameters = new Dictionary<DataForVariable, int>(); 110 } 111 112 private void ExtractParameters(ISymbolicExpressionTreeNode node) { 113 if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) { 137 114 initialConstants.Add(((ConstantTreeNode)node).Value); 138 var var = new AutoDiff.Variable(); 139 variables.Add(var); 140 return var; 141 } 142 if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) { 115 } else if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable || node.Symbol is BinaryFactorVariable) { 116 var varNode = node as VariableTreeNodeBase; 117 var factorVarNode = node as 
BinaryFactorVariableTreeNode; 118 // factor variable values are only 0 or 1 and set in x accordingly 119 var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty; 120 FindOrCreateParameter(parameters, varNode.VariableName, varValue); 121 122 if (makeVariableWeightsVariable) { 123 initialConstants.Add(varNode.Weight); 124 } 125 } else if (node.Symbol is FactorVariable) { 126 var factorVarNode = node as FactorVariableTreeNode; 127 var products = new List<D>(); 128 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 129 FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 130 131 initialConstants.Add(factorVarNode.GetValue(variableValue)); 132 } 133 } else if (node.Symbol is LaggedVariable) { 134 var varNode = node as LaggedVariableTreeNode; 135 FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 136 137 if (makeVariableWeightsVariable) { 138 initialConstants.Add(varNode.Weight); 139 } 140 } else if (node.Symbol is Addition) { 141 foreach (var subTree in node.Subtrees) { 142 ExtractParameters(subTree); 143 } 144 } else if (node.Symbol is Subtraction) { 145 for (int i = 0; i < node.SubtreeCount; i++) { 146 ExtractParameters(node.GetSubtree(i)); 147 } 148 } else if (node.Symbol is Multiplication) { 149 foreach (var subTree in node.Subtrees) { 150 ExtractParameters(subTree); 151 } 152 } else if (node.Symbol is Division) { 153 foreach (var subTree in node.Subtrees) { 154 ExtractParameters(subTree); 155 } 156 } else if (node.Symbol is Logarithm) { 157 ExtractParameters(node.GetSubtree(0)); 158 } else if (node.Symbol is Exponential) { 159 ExtractParameters(node.GetSubtree(0)); 160 } else if (node.Symbol is Square) { 161 ExtractParameters(node.GetSubtree(0)); 162 } else if (node.Symbol is SquareRoot) { 163 ExtractParameters(node.GetSubtree(0)); 164 } else if (node.Symbol is Sine) { 165 ExtractParameters(node.GetSubtree(0)); 166 } else if 
(node.Symbol is Cosine) { 167 ExtractParameters(node.GetSubtree(0)); 168 } else if (node.Symbol is Tangent) { 169 ExtractParameters(node.GetSubtree(0)); 170 } else if (node.Symbol is StartSymbol) { 171 ExtractParameters(node.GetSubtree(0)); 172 } else throw new ConversionException(); 173 } 174 175 private Func<DV, D> CreateDiffSharpFunc(ISymbolicExpressionTreeNode node, Dictionary<DataForVariable, int> parameters) { 176 this.paramIdx = parameters.Count; // first idx of non-variable parameter 177 var f = CreateDiffSharpFunc(node, parameters); 178 return (DV paramValues) => f(paramValues); 179 } 180 181 private static readonly MethodInfo DvIndexer = typeof(DV).GetMethod("get_Item", new[] { typeof(int) }); 182 private static readonly MethodInfo d_Add_d = typeof(D).GetMethod("op_Addition", new[] { typeof(D), typeof(D) }); 183 private static readonly MethodInfo d_Neg = typeof(D).GetMethod("Neg", new[] { typeof(D) }); 184 private static readonly MethodInfo d_Mul_d = typeof(D).GetMethod("op_Multiply", new[] { typeof(D), typeof(D) }); 185 private static readonly MethodInfo d_Mul_f = typeof(D).GetMethod("op_Multiply", new[] { typeof(D), typeof(double) }); 186 private static readonly MethodInfo d_Div_d = typeof(D).GetMethod("op_Division", new[] { typeof(D), typeof(D) }); 187 private static readonly MethodInfo f_Div_d = typeof(D).GetMethod("op_Division", new[] { typeof(double), typeof(D) }); 188 private static readonly MethodInfo d_Sub_d = typeof(D).GetMethod("op_Subtraction", new[] { typeof(D), typeof(D) }); 189 private static readonly MethodInfo d_Pow_f = typeof(D).GetMethod("Pow", new[] { typeof(D), typeof(double) }); 190 private static readonly MethodInfo d_Log = typeof(D).GetMethod("Log", new[] { typeof(D) }); 191 private static readonly MethodInfo d_Exp = typeof(D).GetMethod("Exp", new[] { typeof(D) }); 192 193 194 195 private Expression MakeExpr(ISymbolicExpressionTreeNode node, Dictionary<DataForVariable, int> parameters, ParameterExpression dv) { 196 if (node.Symbol 
is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) { 197 return Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 198 } 199 if (node.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable || node.Symbol is BinaryFactorVariable) { 143 200 var varNode = node as VariableTreeNodeBase; 144 201 var factorVarNode = node as BinaryFactorVariableTreeNode; … … 148 205 149 206 if (makeVariableWeightsVariable) { 150 initialConstants.Add(varNode.Weight); 151 var w = new AutoDiff.Variable(); 152 variables.Add(w); 153 return AutoDiff.TermBuilder.Product(w, par); 207 var w = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 208 var v = Expression.Call(dv, DvIndexer, Expression.Constant(par)); 209 return Expression.Call(d_Mul_d, w, v); 154 210 } else { 155 return varNode.Weight * par; 211 var w = Expression.Constant(varNode.Weight); 212 var v = Expression.Call(dv, DvIndexer, Expression.Constant(par)); 213 return Expression.Call(d_Mul_f, v, w); 156 214 } 157 215 } 158 216 if (node.Symbol is FactorVariable) { 159 217 var factorVarNode = node as FactorVariableTreeNode; 160 var products = new List<Term>(); 161 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName)) { 218 var products = new List<D>(); 219 var firstValue = factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName).First(); 220 var parForFirstValue = FindOrCreateParameter(parameters, factorVarNode.VariableName, firstValue); 221 var weightForFirstValue = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 222 var valForFirstValue = Expression.Call(dv, DvIndexer, Expression.Constant(parForFirstValue)); 223 var res = Expression.Call(d_Mul_d, weightForFirstValue, valForFirstValue); 224 225 foreach (var variableValue in factorVarNode.Symbol.GetVariableValues(factorVarNode.VariableName).Skip(1)) { 162 226 var par = FindOrCreateParameter(parameters, factorVarNode.VariableName, variableValue); 163 164 
initialConstants.Add(factorVarNode.GetValue(variableValue)); 165 var wVar = new AutoDiff.Variable(); 166 variables.Add(wVar); 167 168 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 169 } 170 return AutoDiff.TermBuilder.Sum(products); 171 } 172 if (node.Symbol is LaggedVariable) { 173 var varNode = node as LaggedVariableTreeNode; 174 var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 175 176 if (makeVariableWeightsVariable) { 177 initialConstants.Add(varNode.Weight); 178 var w = new AutoDiff.Variable(); 179 variables.Add(w); 180 return AutoDiff.TermBuilder.Product(w, par); 227 228 var weight = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 229 var v = Expression.Call(dv, DvIndexer, Expression.Constant(par)); 230 231 res = Expression.Call(d_Add_d, res, Expression.Call(d_Mul_d, weight, v)); 232 } 233 return res; 234 } 235 // if (node.Symbol is LaggedVariable) { 236 // var varNode = node as LaggedVariableTreeNode; 237 // var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 238 // 239 // if (makeVariableWeightsVariable) { 240 // initialConstants.Add(varNode.Weight); 241 // var w = paramValues[paramIdx++]; 242 // return w * paramValues[par]; 243 // } else { 244 // return varNode.Weight * paramValues[par]; 245 // } 246 // } 247 if (node.Symbol is Addition) { 248 var f = MakeExpr(node.Subtrees.First(), parameters, dv); 249 250 foreach (var subTree in node.Subtrees.Skip(1)) { 251 f = Expression.Call(d_Add_d, f, MakeExpr(subTree, parameters, dv)); 252 } 253 return f; 254 } 255 if (node.Symbol is Subtraction) { 256 if (node.SubtreeCount == 1) { 257 return Expression.Call(d_Neg, MakeExpr(node.Subtrees.First(), parameters, dv)); 181 258 } else { 182 return varNode.Weight * par; 183 } 184 } 185 if (node.Symbol is Addition) { 186 List<AutoDiff.Term> terms = new List<Term>(); 187 foreach (var subTree in node.Subtrees) { 188 terms.Add(ConvertToAutoDiff(subTree)); 189 } 190 
return AutoDiff.TermBuilder.Sum(terms); 191 } 192 if (node.Symbol is Subtraction) { 193 List<AutoDiff.Term> terms = new List<Term>(); 194 for (int i = 0; i < node.SubtreeCount; i++) { 195 AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i)); 196 if (i > 0) t = -t; 197 terms.Add(t); 198 } 199 if (terms.Count == 1) return -terms[0]; 200 else return AutoDiff.TermBuilder.Sum(terms); 259 var f = MakeExpr(node.Subtrees.First(), parameters, dv); 260 261 foreach (var subTree in node.Subtrees.Skip(1)) { 262 f = Expression.Call(d_Sub_d, f, MakeExpr(subTree, parameters, dv)); 263 } 264 return f; 265 } 201 266 } 202 267 if (node.Symbol is Multiplication) { 203 List<AutoDiff.Term> terms = new List<Term>(); 204 foreach (var subTree in node.Subtrees) { 205 terms.Add(ConvertToAutoDiff(subTree)); 206 } 207 if (terms.Count == 1) return terms[0]; 208 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b)); 268 var f = MakeExpr(node.Subtrees.First(), parameters, dv); 269 foreach (var subTree in node.Subtrees.Skip(1)) { 270 f = Expression.Call(d_Mul_d, f, MakeExpr(subTree, parameters, dv)); 271 } 272 return f; 209 273 } 210 274 if (node.Symbol is Division) { 211 List<AutoDiff.Term> terms = new List<Term>(); 212 foreach (var subTree in node.Subtrees) { 213 terms.Add(ConvertToAutoDiff(subTree)); 214 } 215 if (terms.Count == 1) return 1.0 / terms[0]; 216 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b)); 275 if (node.SubtreeCount == 1) { 276 return Expression.Call(f_Div_d, Expression.Constant(1.0), MakeExpr(node.Subtrees.First(), parameters, dv)); 277 } else { 278 var f = MakeExpr(node.Subtrees.First(), parameters, dv); 279 280 foreach (var subTree in node.Subtrees.Skip(1)) { 281 f = Expression.Call(d_Div_d, f, MakeExpr(subTree, parameters, dv)); 282 } 283 return f; 284 } 217 285 } 218 286 if (node.Symbol is Logarithm) { 219 return AutoDiff.TermBuilder.Log( 220 ConvertToAutoDiff(node.GetSubtree(0))); 287 return Expression.Call(d_Log, 
MakeExpr(node.GetSubtree(0), parameters, dv)); 221 288 } 222 289 if (node.Symbol is Exponential) { 223 return AutoDiff.TermBuilder.Exp( 224 ConvertToAutoDiff(node.GetSubtree(0))); 290 return Expression.Call(d_Exp, MakeExpr(node.GetSubtree(0), parameters, dv)); 225 291 } 226 292 if (node.Symbol is Square) { 227 return AutoDiff.TermBuilder.Power( 228 ConvertToAutoDiff(node.GetSubtree(0)), 2.0); 293 return Expression.Call(d_Pow_f, MakeExpr(node.GetSubtree(0), parameters, dv), Expression.Constant(2.0)); 229 294 } 230 295 if (node.Symbol is SquareRoot) { 231 return AutoDiff.TermBuilder.Power( 232 ConvertToAutoDiff(node.GetSubtree(0)), 0.5); 233 } 234 if (node.Symbol is Sine) { 235 return sin( 236 ConvertToAutoDiff(node.GetSubtree(0))); 237 } 238 if (node.Symbol is Cosine) { 239 return cos( 240 ConvertToAutoDiff(node.GetSubtree(0))); 241 } 242 if (node.Symbol is Tangent) { 243 return tan( 244 ConvertToAutoDiff(node.GetSubtree(0))); 245 } 246 if (node.Symbol is Erf) { 247 return erf( 248 ConvertToAutoDiff(node.GetSubtree(0))); 249 } 250 if (node.Symbol is Norm) { 251 return norm( 252 ConvertToAutoDiff(node.GetSubtree(0))); 253 } 296 return Expression.Call(d_Pow_f, MakeExpr(node.GetSubtree(0), parameters, dv), Expression.Constant(0.5)); 297 } 298 // if (node.Symbol is Sine) { 299 // return AD.Sin(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues)); 300 // } 301 // if (node.Symbol is Cosine) { 302 // return AD.Cos(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues)); 303 // } 304 // if (node.Symbol is Tangent) { 305 // return AD.Tan(CreateDiffSharpFunc(node.GetSubtree(0), parameters, paramValues)); 306 // } 254 307 if (node.Symbol is StartSymbol) { 255 var alpha = new AutoDiff.Variable();256 var beta = new AutoDiff.Variable();257 variables.Add(beta); 258 variables.Add(alpha);259 return ConvertToAutoDiff(node.GetSubtree(0)) * alpha + beta;308 var alpha = Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 309 var beta = 
Expression.Call(dv, DvIndexer, Expression.Constant(paramIdx++)); 310 311 return Expression.Call(d_Add_d, beta, 312 Expression.Call(d_Mul_d, alpha, MakeExpr(node.GetSubtree(0), parameters, dv))); 260 313 } 261 314 throw new ConversionException(); … … 265 318 // for each factor variable value we need a parameter which represents a binary indicator for that variable & value combination 266 319 // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available. 267 private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters,320 private static int FindOrCreateParameter(Dictionary<DataForVariable, int> parameters, 268 321 string varName, string varValue = "", int lag = 0) { 269 322 var data = new DataForVariable(varName, varValue, lag); 270 271 AutoDiff.Variable par = null; 272 if (!parameters.TryGetValue(data, out par)) { 273 // not found -> create new parameter and entries in names and values lists 274 par = new AutoDiff.Variable(); 275 parameters.Add(data, par); 276 } 277 return par; 323 int idx = -1; 324 if (parameters.TryGetValue(data, out idx)) return idx; 325 else parameters[data] = parameters.Count; 326 return idx; 278 327 } 279 328 … … 282 331 from n in tree.Root.GetSubtree(0).IterateNodesPrefix() 283 332 where 284 !(n.Symbol is Variable) &&333 !(n.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Variable) && 285 334 !(n.Symbol is BinaryFactorVariable) && 286 335 !(n.Symbol is FactorVariable) && 287 336 !(n.Symbol is LaggedVariable) && 288 !(n.Symbol is Constant) &&337 !(n.Symbol is HeuristicLab.Problems.DataAnalysis.Symbolic.Constant) && 289 338 !(n.Symbol is Addition) && 290 339 !(n.Symbol is Subtraction) && … … 298 347 !(n.Symbol is Cosine) && 299 348 !(n.Symbol is Tangent) && 300 !(n.Symbol is Erf) &&301 !(n.Symbol is Norm) &&302 349 !(n.Symbol is StartSymbol) 303 350 select n).Any();
Note: See TracChangeset for help on using the changeset viewer.