Changeset 17213
- Timestamp:
- 08/13/19 19:01:54 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs
r17204 r17213 7 7 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 8 8 9 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression .Extensions{10 internal class ConstrainedNLSInternal {9 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { 10 internal class ConstrainedNLSInternal : IDisposable { 11 11 private readonly int maxIterations; 12 12 public int MaxIterations => maxIterations; … … 58 58 private readonly List<ConstantTreeNode>[] allThetaNodes; 59 59 public List<ISymbolicExpressionTree> constraintTrees; // TODO make local in ctor (public for debugging) 60 60 61 61 private readonly double[] fi_eval; 62 62 private readonly double[,] jac_eval; … … 64 64 private readonly VectorAutoDiffEvaluator autoDiffEval; 65 65 private readonly VectorEvaluator eval; 66 private readonly bool invalidProblem = false; 66 67 67 68 // end internal state … … 97 98 var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray(); 98 99 99 if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) throw new ArgumentException("The expression produces NaN or infinite values."); 100 if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) { 101 bestError = targetVariance; 102 invalidProblem = true; 103 } 100 104 101 105 #region linear scaling 102 106 var predStDev = pred.StandardDeviationPop(); 103 if (predStDev == 0) throw new ArgumentException("The expression is constant."); 107 if (predStDev == 0) { 108 bestError = targetVariance; 109 invalidProblem = true; 110 } 104 111 var predMean = pred.Average(); 105 112 … … 118 125 constraintTrees = new List<ISymbolicExpressionTree>(); 119 126 foreach (var constraint in intervalConstraints.Constraints) { 127 if (!constraint.Enabled) continue; 120 128 if (constraint.IsDerivation) { 121 129 if (!problemData.AllowedInputVariables.Contains(constraint.Variable)) … … 130 138 // convert variables named theta back to constants 131 139 var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes); 132 constraintTrees.Add( (ISymbolicExpressionTree)df_prepared.Clone());140 constraintTrees.Add(df_prepared); 133 141 } 134 142 if (constraint.Interval.LowerBound > double.NegativeInfinity) { 135 143 var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone()); 136 144 // convert variables named theta back to constants 137 var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes); 138 constraintTrees.Add( (ISymbolicExpressionTree)df_prepared.Clone());145 var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes); 146 constraintTrees.Add(df_prepared); 139 147 } 140 148 } else { … … 143 151 // convert variables named theta back to constants 144 152 var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes); 145 constraintTrees.Add( (ISymbolicExpressionTree)df_prepared.Clone());153 constraintTrees.Add(df_prepared); 146 154 } 147 155 if (constraint.Interval.LowerBound > double.NegativeInfinity) { … … 149 157 // convert variables named theta back to constants 150 158 var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes); 151 constraintTrees.Add( (ISymbolicExpressionTree)df_prepared.Clone());159 constraintTrees.Add(df_prepared); 152 160 } 153 161 } … … 173 181 NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning 174 182 175 // 176 // 183 //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective); 184 //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero); 177 185 178 186 … … 185 193 calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint); 186 194 NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8); 187 // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);195 // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8); 188 196 } 189 197 … … 196 204 197 205 ~ConstrainedNLSInternal() { 198 if (nlopt != IntPtr.Zero) 199 NLOpt.nlopt_destroy(nlopt); 200 if (constraintDataPtr != null) { 201 for (int i = 0; i < constraintDataPtr.Length; i++) 202 if (constraintDataPtr[i] != IntPtr.Zero) 203 Marshal.FreeHGlobal(constraintDataPtr[i]); 204 } 206 Dispose(); 205 207 } 206 208 207 209 208 210 internal void Optimize() { 211 if (invalidProblem) return; 209 212 var x = thetaValues.ToArray(); /* initial guess */ 210 213 double minf = double.MaxValue; /* minimum objective value upon return */ … … 219 222 double[] _ = new double[x.Length]; 220 223 bestConstraintValues = new double[calculateConstraintDelegates.Length]; 221 for (int i=0;i<calculateConstraintDelegates.Length;i++) {224 for (int i = 0; i < calculateConstraintDelegates.Length; i++) { 222 225 bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]); 223 226 } … … 300 303 } 301 304 302 // // NOT WORKING YET? WHY is det(H) always zero?! 303 // // TODO 305 // TODO 304 306 // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) { 305 307 // UpdateThetaValues(x); // calc H(x) … … 311 313 // 312 314 // // calc residuals and scale jac_eval 313 // var f = -2.0 / k; 314 // for (int i = 0;i<k;i++) { 315 // var r = target[i] - fi_eval[i]; 316 // for(int j = 0;j<n;j++) { 317 // jac_eval[i, j] *= f * r; 318 // } 319 // } 320 // 315 // var f = 2.0 / (k*k); 316 // 321 317 // // approximate hessian H(x) = J(x)^T * J(x) 322 318 // alglib.rmatrixgemm((int)n, (int)n, k, 323 // 1.0, jac_eval, 0, 0, 1, // transposed319 // f, jac_eval, 0, 0, 1, // transposed 324 320 // jac_eval, 0, 0, 0, 325 321 // 0.0, ref h, 0, 0, … … 327 323 // ); 328 324 // 325 // 329 326 // // scale v 330 327 // alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial); 331 // var det = alglib.matdet.rmatrixdet(h, (int)n, alglib.serial); 328 // 329 // 330 // alglib.spdmatrixcholesky(ref h, (int)n, true); 331 // 332 // var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial); 332 333 // } 333 334 … … 446 447 if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ; 447 448 if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES; 449 450 if (solver.Contains("DIRECT_G")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT; 451 if (solver.Contains("NLOPT_GN_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L; 452 if (solver.Contains("NLOPT_GN_DIRECT_L_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND; 453 if (solver.Contains("NLOPT_GN_ORIG_DIRECT")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT; 454 if (solver.Contains("NLOPT_GN_ORIG_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L; 455 if (solver.Contains("NLOPT_GD_STOGO")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO; 456 if (solver.Contains("NLOPT_GD_STOGO_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND; 457 if (solver.Contains("NLOPT_LD_LBFGS_NOCEDAL")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL; 458 if (solver.Contains("NLOPT_LD_LBFGS")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS; 459 if (solver.Contains("NLOPT_LN_PRAXIS")) return NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS; 460 if (solver.Contains("NLOPT_LD_VAR1")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR1; 461 if (solver.Contains("NLOPT_LD_VAR2")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR2; 462 if (solver.Contains("NLOPT_LD_TNEWTON")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON; 463 if (solver.Contains("NLOPT_LD_TNEWTON_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART; 464 if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND; 465 if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART; 466 if (solver.Contains("NLOPT_GN_CRS2_LM")) return NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM; 467 if (solver.Contains("NLOPT_GN_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL; 468 if (solver.Contains("NLOPT_GD_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL; 469 if (solver.Contains("NLOPT_GN_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS; 470 if (solver.Contains("NLOPT_GD_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS; 471 if (solver.Contains("NLOPT_LN_NEWUOA")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA; 472 if (solver.Contains("NLOPT_LN_NEWUOA_BOUND")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND; 473 if (solver.Contains("NLOPT_LN_NELDERMEAD")) return NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD; 474 if (solver.Contains("NLOPT_LN_SBPLX")) return NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX; 475 if (solver.Contains("NLOPT_LN_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG; 476 if (solver.Contains("NLOPT_LD_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG; 477 if (solver.Contains("NLOPT_LN_BOBYQA")) return NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA; 478 if (solver.Contains("NLOPT_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_AUGLAG; 479 if (solver.Contains("NLOPT_LD_SLSQP")) return NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP; 480 if (solver.Contains("NLOPT_LD_CCSAQ))")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ; 481 if (solver.Contains("NLOPT_GN_ESCH")) return NLOpt.nlopt_algorithm.NLOPT_GN_ESCH; 482 if (solver.Contains("NLOPT_GN_AGS")) return NLOpt.nlopt_algorithm.NLOPT_GN_AGS; 483 448 484 throw new ArgumentException($"Unknown solver {solver}"); 449 485 } … … 559 595 return node; 560 596 } 597 598 public void Dispose() { 599 if (nlopt != IntPtr.Zero) { 600 NLOpt.nlopt_destroy(nlopt); 601 nlopt = IntPtr.Zero; 602 } 603 if (constraintDataPtr != null) { 604 for (int i = 0; i < constraintDataPtr.Length; i++) 605 if (constraintDataPtr[i] != IntPtr.Zero) { 606 Marshal.FreeHGlobal(constraintDataPtr[i]); 607 constraintDataPtr[i] = IntPtr.Zero; 608 } 609 } 610 } 561 611 #endregion 562 612 }
Note: See TracChangeset
for help on using the changeset viewer.