Changeset 14950
- Timestamp:
- 05/09/17 20:08:11 (8 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/TreeToAutoDiffTermConverter.cs
r14851 r14950 23 23 using System.Collections.Generic; 24 24 using System.Linq; 25 using System.Runtime.Serialization; 25 26 using AutoDiff; 26 27 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; … … 29 30 public class TreeToAutoDiffTermConverter { 30 31 public delegate double ParametricFunction(double[] vars, double[] @params); 32 31 33 public delegate Tuple<double[], double> ParametricFunctionGradient(double[] vars, double[] @params); 32 34 … … 62 64 eval: Math.Atan, 63 65 diff: x => 1 / (1 + x * x)); 66 64 67 private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory( 65 68 eval: Math.Sin, 66 69 diff: Math.Cos); 70 67 71 private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory( 68 eval: Math.Cos, 69 diff: x => -Math.Sin(x)); 72 eval: Math.Cos, 73 diff: x => -Math.Sin(x)); 74 70 75 private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory( 71 76 eval: Math.Tan, 72 77 diff: x => 1 + Math.Tan(x) * Math.Tan(x)); 78 73 79 private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory( 74 80 eval: alglib.errorfunction, 75 81 diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI)); 82 76 83 private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory( 77 84 eval: alglib.normaldistribution, … … 88 95 var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable); 89 96 AutoDiff.Term term; 90 var success = transformator.TryConvertToAutoDiff(tree.Root.GetSubtree(0), out term);91 if (success) {97 try { 98 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)); 92 99 var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values 93 var compiledTerm = term.Compile(transformator.variables.ToArray(), parameterEntries.Select(kvp => kvp.Value).ToArray()); 100 var compiledTerm = term.Compile(transformator.variables.ToArray(), 101 parameterEntries.Select(kvp => kvp.Value).ToArray()); 94 102 parameters = new List<DataForVariable>(parameterEntries.Select(kvp => kvp.Key)); 95 103 initialConstants = transformator.initialConstants.ToArray(); 96 104 func = (vars, @params) => compiledTerm.Evaluate(vars, @params); 97 105 func_grad = (vars, @params) => compiledTerm.Differentiate(vars, @params); 98 } else { 106 return true; 107 } catch (ConversionException) { 99 108 func = null; 100 109 func_grad = null; … … 102 111 initialConstants = null; 103 112 } 104 return success;113 return false; 105 114 } 106 115 107 116 // state for recursive transformation of trees 108 private readonly List<double> initialConstants; 117 private readonly 118 List<double> initialConstants; 109 119 private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters; 110 120 private readonly List<AutoDiff.Variable> variables; … … 118 128 } 119 129 120 private bool TryConvertToAutoDiff(ISymbolicExpressionTreeNode node, out AutoDiff.Term term) {130 private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) { 121 131 if (node.Symbol is Constant) { 122 132 initialConstants.Add(((ConstantTreeNode)node).Value); 123 133 var var = new AutoDiff.Variable(); 124 134 variables.Add(var); 125 term = var; 126 return true; 135 return var; 127 136 } 128 137 if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) { … … 137 146 var w = new AutoDiff.Variable(); 138 147 variables.Add(w); 139 term =AutoDiff.TermBuilder.Product(w, par);148 return AutoDiff.TermBuilder.Product(w, par); 140 149 } else { 141 term = varNode.Weight * par; 142 } 143 return true; 150 return varNode.Weight * par; 151 } 144 152 } 145 153 if (node.Symbol is FactorVariable) { … … 155 163 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 156 164 } 157 term = AutoDiff.TermBuilder.Sum(products); 158 return true; 165 return AutoDiff.TermBuilder.Sum(products); 159 166 } 160 167 if (node.Symbol is LaggedVariable) { … … 166 173 var w = new AutoDiff.Variable(); 167 174 variables.Add(w); 168 term =AutoDiff.TermBuilder.Product(w, par);175 return AutoDiff.TermBuilder.Product(w, par); 169 176 } else { 170 term = varNode.Weight * par; 171 } 172 return true; 177 return varNode.Weight * par; 178 } 173 179 } 174 180 if (node.Symbol is Addition) { 175 181 List<AutoDiff.Term> terms = new List<Term>(); 176 182 foreach (var subTree in node.Subtrees) { 177 AutoDiff.Term t; 178 if (!TryConvertToAutoDiff(subTree, out t)) { 179 term = null; 180 return false; 181 } 182 terms.Add(t); 183 } 184 term = AutoDiff.TermBuilder.Sum(terms); 185 return true; 183 terms.Add(ConvertToAutoDiff(subTree)); 184 } 185 return AutoDiff.TermBuilder.Sum(terms); 186 186 } 187 187 if (node.Symbol is Subtraction) { 188 188 List<AutoDiff.Term> terms = new List<Term>(); 189 189 for (int i = 0; i < node.SubtreeCount; i++) { 190 AutoDiff.Term t; 191 if (!TryConvertToAutoDiff(node.GetSubtree(i), out t)) { 192 term = null; 193 return false; 194 } 190 AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i)); 195 191 if (i > 0) t = -t; 196 192 terms.Add(t); 197 193 } 198 if (terms.Count == 1) term = -terms[0]; 199 else term = AutoDiff.TermBuilder.Sum(terms); 200 return true; 194 if (terms.Count == 1) return -terms[0]; 195 else return AutoDiff.TermBuilder.Sum(terms); 201 196 } 202 197 if (node.Symbol is Multiplication) { 203 198 List<AutoDiff.Term> terms = new List<Term>(); 204 199 foreach (var subTree in node.Subtrees) { 205 AutoDiff.Term t; 206 if (!TryConvertToAutoDiff(subTree, out t)) { 207 term = null; 208 return false; 209 } 210 terms.Add(t); 211 } 212 if (terms.Count == 1) term = terms[0]; 213 else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, b)); 214 return true; 215 200 terms.Add(ConvertToAutoDiff(subTree)); 201 } 202 if (terms.Count == 1) return terms[0]; 203 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b)); 216 204 } 217 205 if (node.Symbol is Division) { 218 206 List<AutoDiff.Term> terms = new List<Term>(); 219 207 foreach (var subTree in node.Subtrees) { 220 AutoDiff.Term t; 221 if (!TryConvertToAutoDiff(subTree, out t)) { 222 term = null; 223 return false; 224 } 225 terms.Add(t); 226 } 227 if (terms.Count == 1) term = 1.0 / terms[0]; 228 else term = terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b)); 229 return true; 208 terms.Add(ConvertToAutoDiff(subTree)); 209 } 210 if (terms.Count == 1) return 1.0 / terms[0]; 211 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b)); 230 212 } 231 213 if (node.Symbol is Logarithm) { 232 AutoDiff.Term t; 233 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 234 term = null; 235 return false; 236 } else { 237 term = AutoDiff.TermBuilder.Log(t); 238 return true; 239 } 214 return AutoDiff.TermBuilder.Log( 215 ConvertToAutoDiff(node.GetSubtree(0))); 240 216 } 241 217 if (node.Symbol is Exponential) { 242 AutoDiff.Term t; 243 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 244 term = null; 245 return false; 246 } else { 247 term = AutoDiff.TermBuilder.Exp(t); 248 return true; 249 } 218 return AutoDiff.TermBuilder.Exp( 219 ConvertToAutoDiff(node.GetSubtree(0))); 250 220 } 251 221 if (node.Symbol is Square) { 252 AutoDiff.Term t; 253 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 254 term = null; 255 return false; 256 } else { 257 term = AutoDiff.TermBuilder.Power(t, 2.0); 258 return true; 259 } 222 return AutoDiff.TermBuilder.Power( 223 ConvertToAutoDiff(node.GetSubtree(0)), 2.0); 260 224 } 261 225 if (node.Symbol is SquareRoot) { 262 AutoDiff.Term t; 263 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 264 term = null; 265 return false; 266 } else { 267 term = AutoDiff.TermBuilder.Power(t, 0.5); 268 return true; 269 } 226 return AutoDiff.TermBuilder.Power( 227 ConvertToAutoDiff(node.GetSubtree(0)), 0.5); 270 228 } 271 229 if (node.Symbol is Sine) { 272 AutoDiff.Term t; 273 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 274 term = null; 275 return false; 276 } else { 277 term = sin(t); 278 return true; 279 } 230 return sin( 231 ConvertToAutoDiff(node.GetSubtree(0))); 280 232 } 281 233 if (node.Symbol is Cosine) { 282 AutoDiff.Term t; 283 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 284 term = null; 285 return false; 286 } else { 287 term = cos(t); 288 return true; 289 } 234 return cos( 235 ConvertToAutoDiff(node.GetSubtree(0))); 290 236 } 291 237 if (node.Symbol is Tangent) { 292 AutoDiff.Term t; 293 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 294 term = null; 295 return false; 296 } else { 297 term = tan(t); 298 return true; 299 } 238 return tan( 239 ConvertToAutoDiff(node.GetSubtree(0))); 300 240 } 301 241 if (node.Symbol is Erf) { 302 AutoDiff.Term t; 303 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 304 term = null; 305 return false; 306 } else { 307 term = erf(t); 308 return true; 309 } 242 return erf( 243 ConvertToAutoDiff(node.GetSubtree(0))); 310 244 } 311 245 if (node.Symbol is Norm) { 312 AutoDiff.Term t; 313 if (!TryConvertToAutoDiff(node.GetSubtree(0), out t)) { 314 term = null; 315 return false; 316 } else { 317 term = norm(t); 318 return true; 319 } 246 return norm( 247 ConvertToAutoDiff(node.GetSubtree(0))); 320 248 } 321 249 if (node.Symbol is StartSymbol) { … … 324 252 variables.Add(beta); 325 253 variables.Add(alpha); 326 AutoDiff.Term branchTerm; 327 if (TryConvertToAutoDiff(node.GetSubtree(0), out branchTerm)) { 328 term = branchTerm * alpha + beta; 329 return true; 330 } else { 331 term = null; 332 return false; 333 } 334 } 335 term = null; 336 return false; 254 return ConvertToAutoDiff(node.GetSubtree(0)) * alpha + beta; 255 } 256 throw new ConversionException(); 337 257 } 338 258 … … 357 277 from n in tree.Root.GetSubtree(0).IterateNodesPrefix() 358 278 where 359 !(n.Symbol is Variable) &&360 !(n.Symbol is BinaryFactorVariable) &&361 !(n.Symbol is FactorVariable) &&362 !(n.Symbol is LaggedVariable) &&363 !(n.Symbol is Constant) &&364 !(n.Symbol is Addition) &&365 !(n.Symbol is Subtraction) &&366 !(n.Symbol is Multiplication) &&367 !(n.Symbol is Division) &&368 !(n.Symbol is Logarithm) &&369 !(n.Symbol is Exponential) &&370 !(n.Symbol is SquareRoot) &&371 !(n.Symbol is Square) &&372 !(n.Symbol is Sine) &&373 !(n.Symbol is Cosine) &&374 !(n.Symbol is Tangent) &&375 !(n.Symbol is Erf) &&376 !(n.Symbol is Norm) &&377 !(n.Symbol is StartSymbol)279 !(n.Symbol is Variable) && 280 !(n.Symbol is BinaryFactorVariable) && 281 !(n.Symbol is FactorVariable) && 282 !(n.Symbol is LaggedVariable) && 283 !(n.Symbol is Constant) && 284 !(n.Symbol is Addition) && 285 !(n.Symbol is Subtraction) && 286 !(n.Symbol is Multiplication) && 287 !(n.Symbol is Division) && 288 !(n.Symbol is Logarithm) && 289 !(n.Symbol is Exponential) && 290 !(n.Symbol is SquareRoot) && 291 !(n.Symbol is Square) && 292 !(n.Symbol is Sine) && 293 !(n.Symbol is Cosine) && 294 !(n.Symbol is Tangent) && 295 !(n.Symbol is Erf) && 296 !(n.Symbol is Norm) && 297 !(n.Symbol is StartSymbol) 378 298 select n).Any(); 379 299 return !containsUnknownSymbol; 380 300 } 301 #region exception class 302 [Serializable] 303 public class ConversionException : Exception { 304 305 public ConversionException() { 306 } 307 308 public ConversionException(string message) : base(message) { 309 } 310 311 public ConversionException(string message, Exception inner) : base(message, inner) { 312 } 313 314 protected ConversionException( 315 SerializationInfo info, 316 StreamingContext context) : base(info, context) { 317 } 318 } 319 #endregion 381 320 } 382 321 }
Note: See TracChangeset
for help on using the changeset viewer.