Changeset 16692 for branches/2521_ProblemRefactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs
Timestamp: 03/18/19 17:24:30
Location: branches/2521_ProblemRefactoring
Files: 4 edited
branches/2521_ProblemRefactoring
  Property svn:ignore changed:
    @@ -24,2 +24,3 @@
     protoc.exe
     obj
    +.vs
  Property svn:mergeinfo changed
branches/2521_ProblemRefactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression
  Property svn:mergeinfo changed

branches/2521_ProblemRefactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4
  Property svn:mergeinfo changed
branches/2521_ProblemRefactoring/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SingleObjective/Evaluators/SymbolicRegressionConstantOptimizationEvaluator.cs
--- SymbolicRegressionConstantOptimizationEvaluator.cs (r13300)
+++ SymbolicRegressionConstantOptimizationEvaluator.cs (r16692)

@@ -1,5 +1,5 @@
 #region License Information
 /* HeuristicLab
- * Copyright (C) 2002-2015 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
+ * Copyright (C) 2002-2018 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
  *
  * This file is part of HeuristicLab.
@@ -23,9 +23,9 @@
 using System.Collections.Generic;
 using System.Linq;
-using AutoDiff;
 using HeuristicLab.Common;
 using HeuristicLab.Core;
 using HeuristicLab.Data;
 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
+using HeuristicLab.Optimization;
 using HeuristicLab.Parameters;
 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
@@ -40,4 +40,9 @@
     private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
     private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
+    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";
+
+    private const string FunctionEvaluationsResultParameterName = "Constants Optimization Function Evaluations";
+    private const string GradientEvaluationsResultParameterName = "Constants Optimization Gradient Evaluations";
+    private const string CountEvaluationsParameterName = "Count Function and Gradient Evaluations";

     public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
@@ -56,4 +61,18 @@
       get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
     }
+    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
+    }
+
+    public IResultParameter<IntValue> FunctionEvaluationsResultParameter {
+      get { return (IResultParameter<IntValue>)Parameters[FunctionEvaluationsResultParameterName]; }
+    }
+    public IResultParameter<IntValue> GradientEvaluationsResultParameter {
+      get { return (IResultParameter<IntValue>)Parameters[GradientEvaluationsResultParameterName]; }
+    }
+    public IFixedValueParameter<BoolValue> CountEvaluationsParameter {
+      get { return (IFixedValueParameter<BoolValue>)Parameters[CountEvaluationsParameterName]; }
+    }
+

     public IntValue ConstantOptimizationIterations {
@@ -72,3 +91,13 @@
       get { return UpdateConstantsInTreeParameter.Value.Value; }
       set { UpdateConstantsInTreeParameter.Value.Value = value; }
+    }
+
+    public bool UpdateVariableWeights {
+      get { return UpdateVariableWeightsParameter.Value.Value; }
+      set { UpdateVariableWeightsParameter.Value.Value = value; }
+    }
+
+    public bool CountEvaluations {
+      get { return CountEvaluationsParameter.Value.Value; }
+      set { CountEvaluationsParameter.Value.Value = value; }
     }
@@ -86,7 +115,12 @@
       : base() {
       Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10), true));
-      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true));
+      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0), true) { Hidden = true });
       Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1), true));
       Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1), true));
-      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
+      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
+      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be optimized.", new BoolValue(true)) { Hidden = true });
+
+      Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
+      Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+      Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
     }
@@ -100,6 +134,17 @@
       if (!Parameters.ContainsKey(UpdateConstantsInTreeParameterName))
         Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)));
-    }
-
+      if (!Parameters.ContainsKey(UpdateVariableWeightsParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be optimized.", new BoolValue(true)));
+
+      if (!Parameters.ContainsKey(CountEvaluationsParameterName))
+        Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));
+
+      if (!Parameters.ContainsKey(FunctionEvaluationsResultParameterName))
+        Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+      if (!Parameters.ContainsKey(GradientEvaluationsResultParameterName))
+        Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
+    }
+
+    private static readonly object locker = new object();
     public override IOperation InstrumentedApply() {
       var solution = SymbolicExpressionTreeParameter.ActualValue;
@@ -107,7 +152,7 @@
       if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
         IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
+        var counter = new EvaluationsCounter();
         quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
-          constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value,
-          EstimationLimitsParameter.ActualValue.Upper, EstimationLimitsParameter.ActualValue.Lower, UpdateConstantsInTree);
+          constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree, counter: counter);

         if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
@@ -115,4 +160,12 @@
           quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, evaluationRows, ApplyLinearScalingParameter.ActualValue.Value);
         }
+
+        if (CountEvaluations) {
+          lock (locker) {
+            FunctionEvaluationsResultParameter.ActualValue.Value += counter.FunctionEvaluations;
+            GradientEvaluationsResultParameter.ActualValue.Value += counter.GradientEvaluations;
+          }
+        }
+
       } else {
         var evaluationRows = GenerateRowsToEvaluate();
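The CountEvaluations block in the hunk above merges each evaluation's private EvaluationsCounter into the shared result parameters under the static locker, since the evaluator may be applied concurrently by parallel engines. A minimal, self-contained sketch of that accumulation pattern (toy types and names, not the HeuristicLab API):

using System;
using System.Threading.Tasks;

// Stand-in for the evaluator's per-evaluation counter.
class EvaluationsCounter {
  public int FunctionEvaluations;
  public int GradientEvaluations;
}

class CounterAggregationSketch {
  static readonly object locker = new object();
  static int totalFunctionEvaluations; // plays the role of the shared result parameter

  static void Main() {
    // Many solutions are evaluated in parallel; each evaluation keeps a
    // private counter, so the hot path needs no synchronization...
    Parallel.For(0, 100, i => {
      var counter = new EvaluationsCounter();
      counter.FunctionEvaluations += 10; // pretend one fit cost 10 evaluations

      // ...and merges it into the shared total once, under the lock, so no
      // increment is lost to a racy read-modify-write.
      lock (locker) {
        totalFunctionEvaluations += counter.FunctionEvaluations;
      }
    });

    Console.WriteLine(totalFunctionEvaluations); // always 1000
  }
}

Taking the lock once per evaluation, rather than around every callback increment, keeps contention low while still producing exact totals.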
@@ -128,4 +181,6 @@
       EstimationLimitsParameter.ExecutionContext = context;
       ApplyLinearScalingParameter.ExecutionContext = context;
+      FunctionEvaluationsResultParameter.ExecutionContext = context;
+      GradientEvaluationsResultParameter.ExecutionContext = context;

       // Pearson R² evaluator is used on purpose instead of the const-opt evaluator,
@@ -137,76 +192,68 @@
       EstimationLimitsParameter.ExecutionContext = null;
       ApplyLinearScalingParameter.ExecutionContext = null;
+      FunctionEvaluationsResultParameter.ExecutionContext = null;
+      GradientEvaluationsResultParameter.ExecutionContext = null;

       return r2;
     }

-    #region derivations of functions
-    // create function factory for arctangent
-    private readonly Func<Term, UnaryFunc> arctan = UnaryFunc.Factory(
-      eval: Math.Atan,
-      diff: x => 1 / (1 + x * x));
-    private static readonly Func<Term, UnaryFunc> sin = UnaryFunc.Factory(
-      eval: Math.Sin,
-      diff: Math.Cos);
-    private static readonly Func<Term, UnaryFunc> cos = UnaryFunc.Factory(
-      eval: Math.Cos,
-      diff: x => -Math.Sin(x));
-    private static readonly Func<Term, UnaryFunc> tan = UnaryFunc.Factory(
-      eval: Math.Tan,
-      diff: x => 1 + Math.Tan(x) * Math.Tan(x));
-    private static readonly Func<Term, UnaryFunc> erf = UnaryFunc.Factory(
-      eval: alglib.errorfunction,
-      diff: x => 2.0 * Math.Exp(-(x * x)) / Math.Sqrt(Math.PI));
-    private static readonly Func<Term, UnaryFunc> norm = UnaryFunc.Factory(
-      eval: alglib.normaldistribution,
-      diff: x => -(Math.Exp(-(x * x)) * Math.Sqrt(Math.Exp(x * x)) * x) / Math.Sqrt(2 * Math.PI));
-    #endregion
-
-
-    // TODO: swap positions of lowerEstimationLimit and upperEstimationLimit parameters
-    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, IRegressionProblemData problemData,
-      IEnumerable<int> rows, bool applyLinearScaling, int maxIterations, double upperEstimationLimit = double.MaxValue, double lowerEstimationLimit = double.MinValue, bool updateConstantsInTree = true) {
-
-      List<AutoDiff.Variable> variables = new List<AutoDiff.Variable>();
-      List<AutoDiff.Variable> parameters = new List<AutoDiff.Variable>();
-      List<string> variableNames = new List<string>();
-
-      AutoDiff.Term func;
-      if (!TryTransformToAutoDiff(tree.Root.GetSubtree(0), variables, parameters, variableNames, out func))
+    public class EvaluationsCounter {
+      public int FunctionEvaluations = 0;
+      public int GradientEvaluations = 0;
+    }
+
+    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
+      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
+      int maxIterations, bool updateVariableWeights = true,
+      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
+      bool updateConstantsInTree = true, Action<double[], double, object> iterationCallback = null, EvaluationsCounter counter = null) {
+
+      // numeric constants in the tree become variables for constant opt
+      // variables in the tree become parameters (fixed values) for constant opt
+      // for each parameter (variable in the original tree) we store the
+      // variable name, variable value (for factor vars) and lag as a DataForVariable object.
+      // A dictionary is used to find parameters
+      double[] initialConstants;
+      var parameters = new List<TreeToAutoDiffTermConverter.DataForVariable>();
+
+      TreeToAutoDiffTermConverter.ParametricFunction func;
+      TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad;
+      if (!TreeToAutoDiffTermConverter.TryConvertToAutoDiff(tree, updateVariableWeights, applyLinearScaling, out parameters, out initialConstants, out func, out func_grad))
         throw new NotSupportedException("Could not optimize constants of symbolic expression tree due to not supported symbols used in the tree.");
-      if (variableNames.Count == 0) return 0.0;
-
-      AutoDiff.IParametricCompiledTerm compiledFunc = AutoDiff.TermUtils.Compile(func, variables.ToArray(), parameters.ToArray());
-
-      List<SymbolicExpressionTreeTerminalNode> terminalNodes = tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>().ToList();
-      double[] c = new double[variables.Count];
-
-      {
+      if (parameters.Count == 0) return 0.0; // gkronber: constant expressions always have a R² of 0.0
+      var parameterEntries = parameters.ToArray(); // order of entries must be the same for x
+
+      //extract inital constants
+      double[] c;
+      if (applyLinearScaling) {
+        c = new double[initialConstants.Length + 2];
         c[0] = 0.0;
         c[1] = 1.0;
-        //extract inital constants
-        int i = 2;
-        foreach (var node in terminalNodes) {
-          ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
-          VariableTreeNode variableTreeNode = node as VariableTreeNode;
-          if (constantTreeNode != null)
-            c[i++] = constantTreeNode.Value;
-          else if (variableTreeNode != null)
-            c[i++] = variableTreeNode.Weight;
-        }
-      }
-      double[] originalConstants = (double[])c.Clone();
+        Array.Copy(initialConstants, 0, c, 2, initialConstants.Length);
+      } else {
+        c = (double[])initialConstants.Clone();
+      }
+
       double originalQuality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);
+
+      if (counter == null) counter = new EvaluationsCounter();
+      var rowEvaluationsCounter = new EvaluationsCounter();

       alglib.lsfitstate state;
       alglib.lsfitreport rep;
-      int info;
+      int retVal;

       IDataset ds = problemData.Dataset;
-      double[,] x = new double[rows.Count(), variableNames.Count];
+      double[,] x = new double[rows.Count(), parameters.Count];
       int row = 0;
       foreach (var r in rows) {
-        for (int col = 0; col < variableNames.Count; col++) {
-          x[row, col] = ds.GetDoubleValue(variableNames[col], r);
+        int col = 0;
+        foreach (var info in parameterEntries) {
+          if (ds.VariableHasType<double>(info.variableName)) {
+            x[row, col] = ds.GetDoubleValue(info.variableName, r + info.lag);
+          } else if (ds.VariableHasType<string>(info.variableName)) {
+            x[row, col] = ds.GetStringValue(info.variableName, r) == info.variableValue ? 1 : 0;
+          } else throw new InvalidProgramException("found a variable of unknown type");
+          col++;
         }
         row++;
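In the data-matrix loop above, factor (string) variables are dichotomized: each factor level referenced by the converted term gets its own 0/1 indicator column, while double variables are copied directly (optionally lagged). A self-contained sketch of that encoding, with illustrative names rather than HeuristicLab's IDataset API:

using System;
using System.Linq;

class FactorEncodingSketch {
  static void Main() {
    // A categorical column with three rows (stand-in for a string variable
    // in the dataset; the real code reads values via GetStringValue).
    string[] color = { "red", "green", "red" };
    string[] levels = color.Distinct().ToArray(); // each level becomes one column

    double[,] x = new double[color.Length, levels.Length];
    for (int row = 0; row < color.Length; row++)
      for (int col = 0; col < levels.Length; col++)
        x[row, col] = color[row] == levels[col] ? 1 : 0; // same test as in the hunk

    // x is now { {1,0}, {0,1}, {1,0} }: one independently fitted weight per level.
    for (int row = 0; row < color.Length; row++)
      Console.WriteLine($"{color[row],-5} -> [{x[row, 0]}, {x[row, 1]}]");
  }
}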
@@ -217,28 +264,38 @@
       int k = c.Length;

-      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(compiledFunc);
-      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(compiledFunc);
+      alglib.ndimensional_pfunc function_cx_1_func = CreatePFunc(func);
+      alglib.ndimensional_pgrad function_cx_1_grad = CreatePGrad(func_grad);
+      alglib.ndimensional_rep xrep = (p, f, obj) => iterationCallback(p, f, obj);

       try {
         alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
         alglib.lsfitsetcond(state, 0.0, 0.0, maxIterations);
+        alglib.lsfitsetxrep(state, iterationCallback != null);
         //alglib.lsfitsetgradientcheck(state, 0.001);
-        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, null, null);
-        alglib.lsfitresults(state, out info, out c, out rep);
-      }
-      catch (ArithmeticException) {
+        alglib.lsfitfit(state, function_cx_1_func, function_cx_1_grad, xrep, rowEvaluationsCounter);
+        alglib.lsfitresults(state, out retVal, out c, out rep);
+      } catch (ArithmeticException) {
         return originalQuality;
-      }
-      catch (alglib.alglibexception) {
+      } catch (alglib.alglibexception) {
         return originalQuality;
       }

-      //info == -7 => constant optimization failed due to wrong gradient
-      if (info != -7) UpdateConstants(tree, c.Skip(2).ToArray());
+      counter.FunctionEvaluations += rowEvaluationsCounter.FunctionEvaluations / n;
+      counter.GradientEvaluations += rowEvaluationsCounter.GradientEvaluations / n;
+
+      //retVal == -7 => constant optimization failed due to wrong gradient
+      if (retVal != -7) {
+        if (applyLinearScaling) {
+          var tmp = new double[c.Length - 2];
+          Array.Copy(c, 2, tmp, 0, tmp.Length);
+          UpdateConstants(tree, tmp, updateVariableWeights);
+        } else UpdateConstants(tree, c, updateVariableWeights);
+      }
       var quality = SymbolicRegressionSingleObjectivePearsonRSquaredEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling);

-      if (!updateConstantsInTree) UpdateConstants(tree, originalConstants.Skip(2).ToArray());
+      if (!updateConstantsInTree) UpdateConstants(tree, initialConstants, updateVariableWeights);
+
       if (originalQuality - quality > 0.001 || double.IsNaN(quality)) {
-        UpdateConstants(tree, originalConstants.Skip(2).ToArray());
+        UpdateConstants(tree, initialConstants, updateVariableWeights);
         return originalQuality;
       }
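The try block above is the usual ALGLIB least-squares driver sequence (lsfitcreatefg, lsfitsetcond, lsfitfit, lsfitresults), and it threads rowEvaluationsCounter through the callbacks' user object; because ALGLIB invokes the callbacks once per data row, the raw counts are divided by n afterwards. A compact sketch of the same sequence on a toy model; the model y = c0 * exp(c1 * x) and all names are illustrative, and the four-argument lsfitsetcond overload is the one used in this revision's bundled ALGLIB (newer releases drop the epsf argument):

using System;

// Assumes the ALGLIB .NET library (alglib namespace) is referenced.
class LsfitSketch {
  class Counter { public int FuncCalls, GradCalls; }

  static void Main() {
    // toy data for the illustrative model y = c0 * exp(c1 * x)
    double[,] x = { { -1.0 }, { -0.5 }, { 0.0 }, { 0.5 }, { 1.0 } };
    double[] y = { 0.22, 0.47, 1.0, 2.1, 4.5 };
    double[] c = { 0.3, 0.3 }; // initial parameter guess

    // function callback: called once per data row, counted via the user object
    alglib.ndimensional_pfunc f = (double[] p, double[] xi, ref double fx, object o) => {
      fx = p[0] * Math.Exp(p[1] * xi[0]);
      ((Counter)o).FuncCalls++;
    };
    // gradient callback: analytic partial derivatives w.r.t. c0 and c1
    alglib.ndimensional_pgrad g = (double[] p, double[] xi, ref double fx, double[] grad, object o) => {
      fx = p[0] * Math.Exp(p[1] * xi[0]);
      grad[0] = Math.Exp(p[1] * xi[0]);
      grad[1] = p[0] * xi[0] * Math.Exp(p[1] * xi[0]);
      ((Counter)o).GradCalls++;
    };

    var counter = new Counter();
    alglib.lsfitstate state;
    alglib.lsfitreport rep;
    int retVal;
    int n = x.GetLength(0), m = x.GetLength(1), k = c.Length;

    alglib.lsfitcreatefg(x, y, c, n, m, k, false, out state);
    alglib.lsfitsetcond(state, 0.0, 0.0, 10);    // at most 10 iterations, like maxIterations above
    alglib.lsfitfit(state, f, g, null, counter); // counter arrives as "o" in the callbacks
    alglib.lsfitresults(state, out retVal, out c, out rep);

    // The callbacks run once per row, so FuncCalls / n is the number of full
    // function evaluations; this is exactly the division done in the hunk.
    Console.WriteLine($"c0={c[0]:F2}, c1={c[1]:F2}, {counter.FuncCalls / n} function evaluations");
  }
}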
@@ -246,240 +303,40 @@
     }

-    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants) {
+    private static void UpdateConstants(ISymbolicExpressionTree tree, double[] constants, bool updateVariableWeights) {
       int i = 0;
       foreach (var node in tree.Root.IterateNodesPrefix().OfType<SymbolicExpressionTreeTerminalNode>()) {
         ConstantTreeNode constantTreeNode = node as ConstantTreeNode;
-        VariableTreeNode variableTreeNode = node as VariableTreeNode;
+        VariableTreeNodeBase variableTreeNodeBase = node as VariableTreeNodeBase;
+        FactorVariableTreeNode factorVarTreeNode = node as FactorVariableTreeNode;
         if (constantTreeNode != null)
           constantTreeNode.Value = constants[i++];
-        else if (variableTreeNode != null)
-          variableTreeNode.Weight = constants[i++];
-      }
-    }
-
-    private static alglib.ndimensional_pfunc CreatePFunc(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, object o) => {
-        func = compiledFunc.Evaluate(c, x);
+        else if (updateVariableWeights && variableTreeNodeBase != null)
+          variableTreeNodeBase.Weight = constants[i++];
+        else if (factorVarTreeNode != null) {
+          for (int j = 0; j < factorVarTreeNode.Weights.Length; j++)
+            factorVarTreeNode.Weights[j] = constants[i++];
+        }
+      }
+    }
+
+    private static alglib.ndimensional_pfunc CreatePFunc(TreeToAutoDiffTermConverter.ParametricFunction func) {
+      return (double[] c, double[] x, ref double fx, object o) => {
+        fx = func(c, x);
+        var counter = (EvaluationsCounter)o;
+        counter.FunctionEvaluations++;
       };
     }

-    private static alglib.ndimensional_pgrad CreatePGrad(AutoDiff.IParametricCompiledTerm compiledFunc) {
-      return (double[] c, double[] x, ref double func, double[] grad, object o) => {
-        var tupel = compiledFunc.Differentiate(c, x);
-        func = tupel.Item2;
-        Array.Copy(tupel.Item1, grad, grad.Length);
+    private static alglib.ndimensional_pgrad CreatePGrad(TreeToAutoDiffTermConverter.ParametricFunctionGradient func_grad) {
+      return (double[] c, double[] x, ref double fx, double[] grad, object o) => {
+        var tuple = func_grad(c, x);
+        fx = tuple.Item2;
+        Array.Copy(tuple.Item1, grad, grad.Length);
+        var counter = (EvaluationsCounter)o;
+        counter.GradientEvaluations++;
       };
     }
-
-    private static bool TryTransformToAutoDiff(ISymbolicExpressionTreeNode node, List<AutoDiff.Variable> variables, List<AutoDiff.Variable> parameters, List<string> variableNames, out AutoDiff.Term term) {
-      if (node.Symbol is Constant) {
-        var var = new AutoDiff.Variable();
-        variables.Add(var);
-        term = var;
-        return true;
-      }
-      if (node.Symbol is Variable) {
-        var varNode = node as VariableTreeNode;
-        var par = new AutoDiff.Variable();
-        parameters.Add(par);
-        variableNames.Add(varNode.VariableName);
-        var w = new AutoDiff.Variable();
-        variables.Add(w);
-        term = AutoDiff.TermBuilder.Product(w, par);
-        return true;
-      }
-      if (node.Symbol is Addition) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        foreach (var subTree in node.Subtrees) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out t)) {
-            term = null;
-            return false;
-          }
-          terms.Add(t);
-        }
-        term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Subtraction) {
-        List<AutoDiff.Term> terms = new List<Term>();
-        for (int i = 0; i < node.SubtreeCount; i++) {
-          AutoDiff.Term t;
-          if (!TryTransformToAutoDiff(node.GetSubtree(i), variables, parameters, variableNames, out t)) {
-            term = null;
-            return false;
-          }
-          if (i > 0) t = -t;
-          terms.Add(t);
-        }
-        term = AutoDiff.TermBuilder.Sum(terms);
-        return true;
-      }
-      if (node.Symbol is Multiplication) {
-        AutoDiff.Term a, b;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
-          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
-          term = null;
-          return false;
-        } else {
-          List<AutoDiff.Term> factors = new List<Term>();
-          foreach (var subTree in node.Subtrees.Skip(2)) {
-            AutoDiff.Term f;
-            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
-              term = null;
-              return false;
-            }
-            factors.Add(f);
-          }
-          term = AutoDiff.TermBuilder.Product(a, b, factors.ToArray());
-          return true;
-        }
-      }
-      if (node.Symbol is Division) {
-        // only works for at least two subtrees
-        AutoDiff.Term a, b;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out a) ||
-          !TryTransformToAutoDiff(node.GetSubtree(1), variables, parameters, variableNames, out b)) {
-          term = null;
-          return false;
-        } else {
-          List<AutoDiff.Term> factors = new List<Term>();
-          foreach (var subTree in node.Subtrees.Skip(2)) {
-            AutoDiff.Term f;
-            if (!TryTransformToAutoDiff(subTree, variables, parameters, variableNames, out f)) {
-              term = null;
-              return false;
-            }
-            factors.Add(1.0 / f);
-          }
-          term = AutoDiff.TermBuilder.Product(a, 1.0 / b, factors.ToArray());
-          return true;
-        }
-      }
-      if (node.Symbol is Logarithm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Log(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Exponential) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Exp(t);
-          return true;
-        }
-      }
-      if (node.Symbol is Square) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Power(t, 2.0);
-          return true;
-        }
-      } if (node.Symbol is SquareRoot) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = AutoDiff.TermBuilder.Power(t, 0.5);
-          return true;
-        }
-      } if (node.Symbol is Sine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = sin(t);
-          return true;
-        }
-      } if (node.Symbol is Cosine) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = cos(t);
-          return true;
-        }
-      } if (node.Symbol is Tangent) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = tan(t);
-          return true;
-        }
-      } if (node.Symbol is Erf) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = erf(t);
-          return true;
-        }
-      } if (node.Symbol is Norm) {
-        AutoDiff.Term t;
-        if (!TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out t)) {
-          term = null;
-          return false;
-        } else {
-          term = norm(t);
-          return true;
-        }
-      }
-      if (node.Symbol is StartSymbol) {
-        var alpha = new AutoDiff.Variable();
-        var beta = new AutoDiff.Variable();
-        variables.Add(beta);
-        variables.Add(alpha);
-        AutoDiff.Term branchTerm;
-        if (TryTransformToAutoDiff(node.GetSubtree(0), variables, parameters, variableNames, out branchTerm)) {
-          term = branchTerm * alpha + beta;
-          return true;
-        } else {
-          term = null;
-          return false;
-        }
-      }
-      term = null;
-      return false;
-    }
-
     public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
-      var containsUnknownSymbol = (
-        from n in tree.Root.GetSubtree(0).IterateNodesPrefix()
-        where
-          !(n.Symbol is Variable) &&
-          !(n.Symbol is Constant) &&
-          !(n.Symbol is Addition) &&
-          !(n.Symbol is Subtraction) &&
-          !(n.Symbol is Multiplication) &&
-          !(n.Symbol is Division) &&
-          !(n.Symbol is Logarithm) &&
-          !(n.Symbol is Exponential) &&
-          !(n.Symbol is SquareRoot) &&
-          !(n.Symbol is Square) &&
-          !(n.Symbol is Sine) &&
-          !(n.Symbol is Cosine) &&
-          !(n.Symbol is Tangent) &&
-          !(n.Symbol is Erf) &&
-          !(n.Symbol is Norm) &&
-          !(n.Symbol is StartSymbol)
-        select n).
-        Any();
-      return !containsUnknownSymbol;
+      return TreeToAutoDiffTermConverter.IsCompatible(tree);
     }
   }
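The new UpdateConstants depends on an ordering invariant: TreeToAutoDiffTermConverter extracts the initial constants in prefix order, so the write-back must walk the tree in the same prefix order, consuming one slot per constant node, one per variable weight (only when weights were optimized, mirroring the updateVariableWeights flag passed to the converter), and one per factor-variable level. A toy sketch of that invariant with stand-in node classes, not HeuristicLab's:

using System;
using System.Collections.Generic;

// Stand-in node types; the real code uses HeuristicLab's tree classes.
abstract class Node { public List<Node> Children = new List<Node>(); }
class ConstantNode : Node { public double Value; }
class VariableNode : Node { public double Weight; }
class FactorNode : Node { public double[] Weights; }

static class TreeConstantsSketch {
  static IEnumerable<Node> Prefix(Node n) {
    yield return n;
    foreach (var child in n.Children)
      foreach (var m in Prefix(child))
        yield return m;
  }

  // Writes optimized values back in the same prefix order in which they were
  // extracted; when weights were not optimized, variable nodes are skipped on
  // both the extraction and the write-back side, so the indices still line up.
  public static void Update(Node root, double[] constants, bool updateVariableWeights) {
    int i = 0;
    foreach (var node in Prefix(root)) {
      if (node is ConstantNode c) c.Value = constants[i++];
      else if (updateVariableWeights && node is VariableNode v) v.Weight = constants[i++];
      else if (node is FactorNode f)
        for (int j = 0; j < f.Weights.Length; j++) f.Weights[j] = constants[i++];
    }
  }

  static void Main() {
    var root = new VariableNode { Weight = 1.0 };
    root.Children.Add(new FactorNode { Weights = new double[] { 0.0, 0.0 } });
    Update(root, new double[] { 2.5, -1.0, 1.0 }, updateVariableWeights: true);
    Console.WriteLine(root.Weight); // 2.5; the factor levels received -1.0 and 1.0
  }
}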