Changeset 17726 for branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorUnrollingTreeToAutoDiffTermConverter.cs
- Timestamp:
- 08/26/20 16:43:25 (4 years ago)
- File:
-
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorUnrollingTreeToAutoDiffTermConverter.cs
r17725 r17726 25 25 using System.Runtime.Serialization; 26 26 using AutoDiff; 27 using HeuristicLab.Common; 27 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 28 29 29 30 namespace HeuristicLab.Problems.DataAnalysis.Symbolic { 30 public class TreeToAutoDiffTermConverter {31 public class VectorUnrollingTreeToAutoDiffTermConverter { 31 32 public delegate double ParametricFunction(double[] vars, double[] @params); 32 33 … … 38 39 public readonly string variableValue; // for factor vars 39 40 public readonly int lag; 40 41 public DataForVariable(string varName, string varValue, int lag) { 41 public readonly int index; // for vectors 42 43 public DataForVariable(string varName, string varValue, int lag, int index) { 42 44 this.variableName = varName; 43 45 this.variableValue = varValue; 44 46 this.lag = lag; 47 this.index = index; 45 48 } 46 49 … … 50 53 return other.variableName.Equals(this.variableName) && 51 54 other.variableValue.Equals(this.variableValue) && 52 other.lag == this.lag; 55 other.lag == this.lag && 56 other.index == this.index; 53 57 } 54 58 55 59 public override int GetHashCode() { 56 return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag ;60 return variableName.GetHashCode() ^ variableValue.GetHashCode() ^ lag ^ index; 57 61 } 58 62 } … … 101 105 #endregion 102 106 103 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, bool makeVariableWeightsVariable, bool addLinearScalingTerms, 107 public static bool TryConvertToAutoDiff(ISymbolicExpressionTree tree, 108 IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace, 109 bool makeVariableWeightsVariable, bool addLinearScalingTerms, 104 110 out List<DataForVariable> parameters, out double[] initialConstants, 105 111 out ParametricFunction func, … … 107 113 108 114 // use a transformator object which holds the state (variable list, parameter list, ...) for recursive transformation of the tree 109 var transformator = new TreeToAutoDiffTermConverter(makeVariableWeightsVariable, addLinearScalingTerms); 110 AutoDiff.Term term; 115 var transformator = new VectorUnrollingTreeToAutoDiffTermConverter(evaluationTrace, 116 makeVariableWeightsVariable, addLinearScalingTerms); 117 Term term; 111 118 try { 112 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)) ;119 term = transformator.ConvertToAutoDiff(tree.Root.GetSubtree(0)).Single(); 113 120 var parameterEntries = transformator.parameters.ToArray(); // guarantee same order for keys and values 114 121 var compiledTerm = term.Compile(transformator.variables.ToArray(), … … 128 135 } 129 136 137 private readonly IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace; 130 138 // state for recursive transformation of trees 131 private readonly 132 List<double> initialConstants; 139 private readonly List<double> initialConstants; 133 140 private readonly Dictionary<DataForVariable, AutoDiff.Variable> parameters; 134 141 private readonly List<AutoDiff.Variable> variables; … … 136 143 private readonly bool addLinearScalingTerms; 137 144 138 private TreeToAutoDiffTermConverter(bool makeVariableWeightsVariable, bool addLinearScalingTerms) { 145 private VectorUnrollingTreeToAutoDiffTermConverter(IDictionary<ISymbolicExpressionTreeNode, SymbolicDataAnalysisExpressionTreeVectorInterpreter.EvaluationResult> evaluationTrace, 146 bool makeVariableWeightsVariable, bool addLinearScalingTerms) { 147 this.evaluationTrace = evaluationTrace; 139 148 this.makeVariableWeightsVariable = makeVariableWeightsVariable; 140 149 this.addLinearScalingTerms = addLinearScalingTerms; … … 144 153 } 145 154 146 private AutoDiff.Term ConvertToAutoDiff(ISymbolicExpressionTreeNode node) { 147 if (node.Symbol is Constant) { 155 private IList<AutoDiff.Term> ConvertToAutoDiff(ISymbolicExpressionTreeNode node) { 156 IList<Term> BinaryOp(Func<Term, Term, Term> binaryOp, Func<Term, Term> singleElementOp, params IList<Term>[] terms) { 157 if (terms.Length == 1) return terms[0].Select(singleElementOp).ToList(); 158 return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList()); 159 } 160 IList<Term> BinaryOp2(Func<Term, Term, Term> binaryOp, params IList<Term>[] terms) { 161 return terms.Aggregate((acc, vectorizedTerm) => acc.Zip(vectorizedTerm, binaryOp).ToList()); 162 } 163 IList<Term> UnaryOp(Func<Term, Term> unaryOp, IList<Term> term) { 164 return term.Select(unaryOp).ToList(); 165 } 166 167 var evaluationResult = evaluationTrace[node]; 168 169 if (node.Symbol is Constant) { // assume scalar constant 148 170 initialConstants.Add(((ConstantTreeNode)node).Value); 149 171 var var = new AutoDiff.Variable(); 150 172 variables.Add(var); 151 return var;173 return new Term[] { var }; 152 174 } 153 175 if (node.Symbol is Variable || node.Symbol is BinaryFactorVariable) { … … 156 178 // factor variable values are only 0 or 1 and set in x accordingly 157 179 var varValue = factorVarNode != null ? factorVarNode.VariableValue : string.Empty; 158 var par = FindOrCreateParameter(parameters, varNode.VariableName, varValue); 180 var pars = evaluationResult.IsVector 181 ? Enumerable.Range(0, evaluationResult.Vector.Count).Select(i => FindOrCreateParameter(parameters, varNode.VariableName, varValue, index: i)) 182 : FindOrCreateParameter(parameters, varNode.VariableName, varValue).ToEnumerable(); 159 183 160 184 if (makeVariableWeightsVariable) { … … 162 186 var w = new AutoDiff.Variable(); 163 187 variables.Add(w); 164 return AutoDiff.TermBuilder.Product(w, par);188 return pars.Select(par => AutoDiff.TermBuilder.Product(w, par)).ToList(); 165 189 } else { 166 return varNode.Weight * par;190 return pars.Select(par => varNode.Weight * par).ToList(); 167 191 } 168 192 } … … 179 203 products.Add(AutoDiff.TermBuilder.Product(wVar, par)); 180 204 } 181 return AutoDiff.TermBuilder.Sum(products);182 } 183 if (node.Symbol is LaggedVariable) {184 var varNode = node as LaggedVariableTreeNode;185 var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag);186 187 if (makeVariableWeightsVariable) {188 initialConstants.Add(varNode.Weight);189 var w = new AutoDiff.Variable();190 variables.Add(w);191 return AutoDiff.TermBuilder.Product(w, par);192 } else {193 return varNode.Weight * par;194 }195 }205 return new[] { AutoDiff.TermBuilder.Sum(products) }; 206 } 207 //if (node.Symbol is LaggedVariable) { 208 // var varNode = node as LaggedVariableTreeNode; 209 // var par = FindOrCreateParameter(parameters, varNode.VariableName, string.Empty, varNode.Lag); 210 211 // if (makeVariableWeightsVariable) { 212 // initialConstants.Add(varNode.Weight); 213 // var w = new AutoDiff.Variable(); 214 // variables.Add(w); 215 // return AutoDiff.TermBuilder.Product(w, par); 216 // } else { 217 // return varNode.Weight * par; 218 // } 219 //} 196 220 if (node.Symbol is Addition) { 197 List<AutoDiff.Term> terms = new List<Term>(); 198 foreach (var subTree in node.Subtrees) { 199 terms.Add(ConvertToAutoDiff(subTree)); 200 } 201 return AutoDiff.TermBuilder.Sum(terms); 221 var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray(); 222 return BinaryOp((a, b) => a + b, a => a, terms); 202 223 } 203 224 if (node.Symbol is Subtraction) { 204 List<AutoDiff.Term> terms = new List<Term>(); 205 for (int i = 0; i < node.SubtreeCount; i++) { 206 AutoDiff.Term t = ConvertToAutoDiff(node.GetSubtree(i)); 207 if (i > 0) t = -t; 208 terms.Add(t); 209 } 210 if (terms.Count == 1) return -terms[0]; 211 else return AutoDiff.TermBuilder.Sum(terms); 225 var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray(); 226 return BinaryOp((a, b) => a - b, a => -a, terms); 212 227 } 213 228 if (node.Symbol is Multiplication) { 214 List<AutoDiff.Term> terms = new List<Term>(); 215 foreach (var subTree in node.Subtrees) { 216 terms.Add(ConvertToAutoDiff(subTree)); 217 } 218 if (terms.Count == 1) return terms[0]; 219 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, b)); 229 var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray(); 230 return BinaryOp((a, b) => a * b, a => a, terms); 220 231 } 221 232 if (node.Symbol is Division) { 222 List<AutoDiff.Term> terms = new List<Term>(); 223 foreach (var subTree in node.Subtrees) { 224 terms.Add(ConvertToAutoDiff(subTree)); 225 } 226 if (terms.Count == 1) return 1.0 / terms[0]; 227 else return terms.Aggregate((a, b) => new AutoDiff.Product(a, 1.0 / b)); 233 var terms = node.Subtrees.Select(ConvertToAutoDiff).ToArray(); 234 return BinaryOp((a, b) => a / b, a => 1.0 / a, terms); 228 235 } 229 236 if (node.Symbol is Absolute) { 230 var x1 = ConvertToAutoDiff(node.GetSubtree(0));231 return abs(x1);232 } 233 if (node.Symbol is AnalyticQuotient) {234 var x1 = ConvertToAutoDiff(node.GetSubtree(0));235 var x2 = ConvertToAutoDiff(node.GetSubtree(1));236 return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5));237 }237 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 238 return UnaryOp(abs, term); 239 } 240 //if (node.Symbol is AnalyticQuotient) { 241 // var x1 = ConvertToAutoDiff(node.GetSubtree(0)); 242 // var x2 = ConvertToAutoDiff(node.GetSubtree(1)); 243 // return x1 / (TermBuilder.Power(1 + x2 * x2, 0.5)); 244 //} 238 245 if (node.Symbol is Logarithm) { 239 return AutoDiff.TermBuilder.Log(240 ConvertToAutoDiff(node.GetSubtree(0)));246 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 247 return UnaryOp(TermBuilder.Log, term); 241 248 } 242 249 if (node.Symbol is Exponential) { 243 return AutoDiff.TermBuilder.Exp(244 ConvertToAutoDiff(node.GetSubtree(0)));250 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 251 return UnaryOp(TermBuilder.Exp, term); 245 252 } 246 253 if (node.Symbol is Square) { 247 return AutoDiff.TermBuilder.Power(248 ConvertToAutoDiff(node.GetSubtree(0)), 2.0);254 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 255 return UnaryOp(t => TermBuilder.Power(t, 2.0), term); 249 256 } 250 257 if (node.Symbol is SquareRoot) { 251 return AutoDiff.TermBuilder.Power(252 ConvertToAutoDiff(node.GetSubtree(0)), 0.5);258 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 259 return UnaryOp(t => TermBuilder.Power(t, 0.5), term); 253 260 } 254 261 if (node.Symbol is Cube) { 255 return AutoDiff.TermBuilder.Power(256 ConvertToAutoDiff(node.GetSubtree(0)), 3.0);262 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 263 return UnaryOp(t => TermBuilder.Power(t, 3.0), term); 257 264 } 258 265 if (node.Symbol is CubeRoot) { 259 return cbrt(ConvertToAutoDiff(node.GetSubtree(0))); 266 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 267 return UnaryOp(cbrt, term); 260 268 } 261 269 if (node.Symbol is Sine) { 262 return sin(263 ConvertToAutoDiff(node.GetSubtree(0)));270 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 271 return UnaryOp(sin, term); 264 272 } 265 273 if (node.Symbol is Cosine) { 266 return cos(267 ConvertToAutoDiff(node.GetSubtree(0)));274 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 275 return UnaryOp(cos, term); 268 276 } 269 277 if (node.Symbol is Tangent) { 270 return tan(271 ConvertToAutoDiff(node.GetSubtree(0)));278 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 279 return UnaryOp(tan, term); 272 280 } 273 281 if (node.Symbol is HyperbolicTangent) { 274 return tanh(275 ConvertToAutoDiff(node.GetSubtree(0)));282 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 283 return UnaryOp(tanh, term); 276 284 } 277 285 if (node.Symbol is Erf) { 278 return erf(279 ConvertToAutoDiff(node.GetSubtree(0)));286 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 287 return UnaryOp(erf, term); 280 288 } 281 289 if (node.Symbol is Norm) { 282 return norm(283 ConvertToAutoDiff(node.GetSubtree(0)));290 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 291 return UnaryOp(norm, term); 284 292 } 285 293 if (node.Symbol is StartSymbol) { … … 291 299 variables.Add(alpha); 292 300 var t = ConvertToAutoDiff(node.GetSubtree(0)); 293 return t * alpha + beta; 301 if (t.Count > 1) throw new InvalidOperationException("Tree Result must be scalar value"); 302 return new[] { t[0] * alpha + beta }; 294 303 } else return ConvertToAutoDiff(node.GetSubtree(0)); 295 304 } 305 if (node.Symbol is Sum) { 306 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 307 return new[] { TermBuilder.Sum(term) }; 308 } 309 if (node.Symbol is Mean) { 310 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 311 return new[] { TermBuilder.Sum(term) / term.Count }; 312 } 313 if (node.Symbol is StandardDeviation) { 314 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 315 var mean = TermBuilder.Sum(term) / term.Count; 316 var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0))); 317 return new[] { TermBuilder.Power(ssd / term.Count, 0.5) }; 318 } 319 if (node.Symbol is Length) { 320 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 321 return new[] { TermBuilder.Constant(term.Count) }; 322 } 323 //if (node.Symbol is Min) { 324 //} 325 //if (node.Symbol is Max) { 326 //} 327 if (node.Symbol is Variance) { 328 var term = node.Subtrees.Select(ConvertToAutoDiff).Single(); 329 var mean = TermBuilder.Sum(term) / term.Count; 330 var ssd = TermBuilder.Sum(term.Select(t => TermBuilder.Power(t - mean, 2.0))); 331 return new[] { ssd / term.Count }; 332 } 333 //if (node.Symbol is Skewness) { 334 //} 335 //if (node.Symbol is Kurtosis) { 336 //} 337 //if (node.Symbol is EuclideanDistance) { 338 //} 339 //if (node.Symbol is Covariance) { 340 //} 341 342 296 343 throw new ConversionException(); 297 344 } … … 301 348 // each binary indicator is only necessary once. So we only create a parameter if this combination is not yet available 302 349 private static Term FindOrCreateParameter(Dictionary<DataForVariable, AutoDiff.Variable> parameters, 303 string varName, string varValue = "", int lag = 0 ) {304 var data = new DataForVariable(varName, varValue, lag );350 string varName, string varValue = "", int lag = 0, int index = -1) { 351 var data = new DataForVariable(varName, varValue, lag, index); 305 352 306 353 AutoDiff.Variable par = null; … … 319 366 !(n.Symbol is Variable) && 320 367 !(n.Symbol is BinaryFactorVariable) && 321 !(n.Symbol is FactorVariable) &&322 !(n.Symbol is LaggedVariable) &&368 //!(n.Symbol is FactorVariable) && 369 //!(n.Symbol is LaggedVariable) && 323 370 !(n.Symbol is Constant) && 324 371 !(n.Symbol is Addition) && … … 338 385 !(n.Symbol is StartSymbol) && 339 386 !(n.Symbol is Absolute) && 340 !(n.Symbol is AnalyticQuotient) &&387 //!(n.Symbol is AnalyticQuotient) && 341 388 !(n.Symbol is Cube) && 342 !(n.Symbol is CubeRoot) 389 !(n.Symbol is CubeRoot) && 390 !(n.Symbol is Sum) && 391 !(n.Symbol is Mean) && 392 !(n.Symbol is StandardDeviation) && 393 !(n.Symbol is Length) && 394 //!(n.Symbol is Min) && 395 //!(n.Symbol is Max) && 396 !(n.Symbol is Variance) 397 //!(n.Symbol is Skewness) && 398 //!(n.Symbol is Kurtosis) && 399 //!(n.Symbol is EuclideanDistance) && 400 //!(n.Symbol is Covariance) 343 401 select n).Any(); 344 402 return !containsUnknownSymbol;
Note: See TracChangeset
for help on using the changeset viewer.