Changeset 17325 for branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs
- Timestamp:
- 10/09/19 11:13:11 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs
r17311 r17325 6 6 using HeuristicLab.Common; 7 7 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 8 using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions; 8 9 9 10 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { … … 40 41 private double[] bestConstraintValues; 41 42 public double[] BestConstraintValues => bestConstraintValues; 43 44 private bool disposed = false; 42 45 43 46 … … 120 123 } 121 124 125 // all trees are linearly scaled (to improve GP performance) 122 126 #region linear scaling 123 127 var predStDev = pred.StandardDeviationPop(); … … 134 138 135 139 // convert constants to variables named theta... 136 var treeForDerivation = Replace ConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree140 var treeForDerivation = ReplaceAndExtractParameters(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree 137 141 138 142 // create trees for relevant derivatives … … 220 224 221 225 ~ConstrainedNLSInternal() { 222 Dispose(); 223 } 224 225 226 internal void Optimize() { 226 Dispose(false); 227 } 228 229 230 public enum OptimizationMode { ReadOnly, UpdateParameters, UpdateParametersAndKeepLinearScaling }; 231 232 internal void Optimize(OptimizationMode mode) { 227 233 if (invalidProblem) return; 228 234 var x = thetaValues.ToArray(); /* initial guess */ … … 233 239 // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}"); 234 240 return; 235 } else if (minf <= bestError){241 } else /*if ( minf <= bestError ) */{ 236 242 bestSolution = x; 237 243 bestError = minf; … … 245 251 246 252 // update parameters in tree 247 var pIdx = 0; 248 // here we lose the two last parameters (for linear scaling) 249 foreach (var node in scaledTree.IterateNodesPostfix()) { 250 if (node is ConstantTreeNode constTreeNode) { 251 constTreeNode.Value = x[pIdx++]; 252 } else if (node is VariableTreeNode varTreeNode) { 253 varTreeNode.Weight = x[pIdx++]; 254 } 255 } 256 if (pIdx != x.Length) throw new InvalidProgramException(); 257 } 258 bestTree = scaledTree; 259 } 253 UpdateParametersInTree(scaledTree, x); 254 255 if (mode == OptimizationMode.UpdateParameters) { 256 // update original expression (when called from evaluator we want to write back optimized parameters) 257 expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree 258 expr.Root.GetSubtree(0).InsertSubtree(0, 259 scaledTree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).GetSubtree(0) // insert the optimized sub-tree (without scaling nodes) 260 ); 261 } else if (mode == OptimizationMode.UpdateParametersAndKeepLinearScaling) { 262 expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree 263 expr.Root.GetSubtree(0).InsertSubtree(0, scaledTree.Root.GetSubtree(0).GetSubtree(0)); // insert the optimized sub-tree (including scaling nodes) 264 } 265 } 266 bestTree = expr; 267 } 268 260 269 261 270 double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) { … … 314 323 } 315 324 316 //UpdateBestSolution(sse / target.Length, curX);325 UpdateBestSolution(sse / target.Length, curX); 317 326 RaiseFunctionEvaluated(); 318 327 … … 427 436 UpdateConstraintViolations(constraintData.Idx, interval.UpperBound); 428 437 if (double.IsNaN(interval.UpperBound)) { 429 if (grad!=null)Array.Clear(grad, 0, grad.Length);438 if (grad != null) Array.Clear(grad, 0, grad.Length); 430 439 return double.MaxValue; 431 440 } else return interval.UpperBound; … … 463 472 } 464 473 465 private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {466 if (nodes.Length != constants.Length) throw new InvalidOperationException();467 for (int i = 0; i < nodes.Length; i++) {468 if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];469 else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];470 }471 }472 474 473 475 private NLOpt.nlopt_algorithm GetSolver(string solver) { … … 514 516 } 515 517 518 // determines the nodes over which we can calculate the partial derivative 519 // this is different from the vector of all parameters because not every tree contains all parameters 516 520 private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) { 517 521 // TODO better solution necessary … … 553 557 } 554 558 555 private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) { 559 560 561 562 private void UpdateParametersInTree(ISymbolicExpressionTree scaledTree, double[] x) { 563 var pIdx = 0; 564 // here we lose the two last parameters (for linear scaling) 565 foreach (var node in scaledTree.IterateNodesPostfix()) { 566 if (node is ConstantTreeNode constTreeNode) { 567 constTreeNode.Value = x[pIdx++]; 568 } else if (node is VariableTreeNode varTreeNode) { 569 if (varTreeNode.Weight != 1.0) // see ReplaceAndExtractParameters 570 varTreeNode.Weight = x[pIdx++]; 571 } 572 } 573 if (pIdx != x.Length) throw new InvalidProgramException(); 574 } 575 576 private static ISymbolicExpressionTree ReplaceAndExtractParameters(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) { 556 577 thetaNames = new List<string>(); 557 578 thetaValues = new List<double>(); … … 578 599 } 579 600 if (node is VariableTreeNode varTreeNode) { 601 if (varTreeNode.Weight == 1) continue; // NOTE: here we assume that we do not tune variable weights when they are originally exactly 1 because we assume that the tree has been parsed and the tree explicitly has the structure w * var 602 580 603 var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode(); 581 604 thetaVar.Weight = 1; … … 626 649 627 650 public void Dispose() { 651 Dispose(true); 652 GC.SuppressFinalize(this); 653 } 654 655 protected virtual void Dispose(bool disposing) { 656 if (disposed) 657 return; 658 659 if (disposing) { 660 // Free any other managed objects here. 661 } 662 663 // Free any unmanaged objects here. 628 664 if (nlopt != IntPtr.Zero) { 629 665 NLOpt.nlopt_destroy(nlopt); … … 637 673 } 638 674 } 675 676 disposed = true; 639 677 } 640 678 #endregion
Note: See TracChangeset
for help on using the changeset viewer.