Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/13/19 19:01:54 (5 years ago)
Author:
gkronber
Message:

#2994: fixed a bug caused by cloning of trees, support other NLOpt solvers, implement idisposable, experiment with preconditioning (still not working)

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs

    r17204 r17213  
    77using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    88
    9 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions {
    10   internal class ConstrainedNLSInternal {
     9namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     10  internal class ConstrainedNLSInternal : IDisposable {
    1111    private readonly int maxIterations;
    1212    public int MaxIterations => maxIterations;
     
    5858    private readonly List<ConstantTreeNode>[] allThetaNodes;
    5959    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)
    60    
     60
    6161    private readonly double[] fi_eval;
    6262    private readonly double[,] jac_eval;
     
    6464    private readonly VectorAutoDiffEvaluator autoDiffEval;
    6565    private readonly VectorEvaluator eval;
     66    private readonly bool invalidProblem = false;
    6667
    6768    // end internal state
     
    9798      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();
    9899
    99       if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) throw new ArgumentException("The expression produces NaN or infinite values.");
     100      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) {
     101        bestError = targetVariance;
     102        invalidProblem = true;
     103      }
    100104
    101105      #region linear scaling
    102106      var predStDev = pred.StandardDeviationPop();
    103       if (predStDev == 0) throw new ArgumentException("The expression is constant.");
     107      if (predStDev == 0) {
     108        bestError = targetVariance;
     109        invalidProblem = true;
     110      }
    104111      var predMean = pred.Average();
    105112
     
    118125      constraintTrees = new List<ISymbolicExpressionTree>();
    119126      foreach (var constraint in intervalConstraints.Constraints) {
     127        if (!constraint.Enabled) continue;
    120128        if (constraint.IsDerivation) {
    121129          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
     
    130138            // convert variables named theta back to constants
    131139            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
    132             constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
     140            constraintTrees.Add(df_prepared);
    133141          }
    134142          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
    135143            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
    136144            // convert variables named theta back to constants
    137             var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);           
    138             constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
     145            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
     146            constraintTrees.Add(df_prepared);
    139147          }
    140148        } else {
     
    143151            // convert variables named theta back to constants
    144152            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
    145             constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
     153            constraintTrees.Add(df_prepared);
    146154          }
    147155          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
     
    149157            // convert variables named theta back to constants
    150158            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
    151             constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
     159            constraintTrees.Add(df_prepared);
    152160          }
    153161        }
     
    173181      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning
    174182
    175       // preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
    176       // NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);
     183      //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
     184      //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);
    177185
    178186
     
    185193        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
    186194        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
    187         //NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
     195        // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
    188196      }
    189197
     
    196204
    197205    ~ConstrainedNLSInternal() {
    198       if (nlopt != IntPtr.Zero)
    199         NLOpt.nlopt_destroy(nlopt);
    200       if (constraintDataPtr != null) {
    201         for (int i = 0; i < constraintDataPtr.Length; i++)
    202           if (constraintDataPtr[i] != IntPtr.Zero)
    203             Marshal.FreeHGlobal(constraintDataPtr[i]);
    204       }
     206      Dispose();
    205207    }
    206208
    207209
    208210    internal void Optimize() {
     211      if (invalidProblem) return;
    209212      var x = thetaValues.ToArray();  /* initial guess */
    210213      double minf = double.MaxValue; /* minimum objective value upon return */
     
    219222        double[] _ = new double[x.Length];
    220223        bestConstraintValues = new double[calculateConstraintDelegates.Length];
    221         for(int i=0;i<calculateConstraintDelegates.Length;i++) {
     224        for (int i = 0; i < calculateConstraintDelegates.Length; i++) {
    222225          bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]);
    223226        }
     
    300303    }
    301304
    302     // // NOT WORKING YET? WHY is det(H) always zero?!
    303     // // TODO
     305    // TODO
    304306    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
    305307    //   UpdateThetaValues(x); // calc H(x)
     
    311313    //   
    312314    //   // calc residuals and scale jac_eval
    313     //   var f = -2.0 / k;
    314     //   for (int i = 0;i<k;i++) {
    315     //     var r = target[i] - fi_eval[i];
    316     //     for(int j = 0;j<n;j++) {
    317     //       jac_eval[i, j] *= f * r;
    318     //     }
    319     //   }
    320     //   
     315    //   var f = 2.0 / (k*k);
     316    //
    321317    //   // approximate hessian H(x) = J(x)^T * J(x)
    322318    //   alglib.rmatrixgemm((int)n, (int)n, k,
    323     //     1.0, jac_eval, 0, 0, 1,  // transposed
     319    //     f, jac_eval, 0, 0, 1,  // transposed
    324320    //     jac_eval, 0, 0, 0,
    325321    //     0.0, ref h, 0, 0,
     
    327323    //     );
    328324    //   
     325    //
    329326    //   // scale v
    330327    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
    331     //   var det = alglib.matdet.rmatrixdet(h, (int)n, alglib.serial);
     328    //
     329    //
     330    //   alglib.spdmatrixcholesky(ref h, (int)n, true);
     331    //
     332    //   var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial);
    332333    // }
    333334
     
    446447      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
    447448      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
     449
     450      if (solver.Contains("DIRECT_G")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
     451      if (solver.Contains("NLOPT_GN_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L;
     452      if (solver.Contains("NLOPT_GN_DIRECT_L_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND;
     453      if (solver.Contains("NLOPT_GN_ORIG_DIRECT")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
     454      if (solver.Contains("NLOPT_GN_ORIG_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L;
     455      if (solver.Contains("NLOPT_GD_STOGO")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO;
     456      if (solver.Contains("NLOPT_GD_STOGO_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND;
     457      if (solver.Contains("NLOPT_LD_LBFGS_NOCEDAL")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL;
     458      if (solver.Contains("NLOPT_LD_LBFGS")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS;
     459      if (solver.Contains("NLOPT_LN_PRAXIS")) return NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS;
     460      if (solver.Contains("NLOPT_LD_VAR1")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR1;
     461      if (solver.Contains("NLOPT_LD_VAR2")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR2;
     462      if (solver.Contains("NLOPT_LD_TNEWTON")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON;
     463      if (solver.Contains("NLOPT_LD_TNEWTON_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART;
     464      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND;
     465      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART;
     466      if (solver.Contains("NLOPT_GN_CRS2_LM")) return NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM;
     467      if (solver.Contains("NLOPT_GN_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL;
     468      if (solver.Contains("NLOPT_GD_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL;
     469      if (solver.Contains("NLOPT_GN_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS;
     470      if (solver.Contains("NLOPT_GD_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS;
     471      if (solver.Contains("NLOPT_LN_NEWUOA")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA;
     472      if (solver.Contains("NLOPT_LN_NEWUOA_BOUND")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND;
     473      if (solver.Contains("NLOPT_LN_NELDERMEAD")) return NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD;
     474      if (solver.Contains("NLOPT_LN_SBPLX")) return NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX;
     475      if (solver.Contains("NLOPT_LN_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG;
     476      if (solver.Contains("NLOPT_LD_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG;
     477      if (solver.Contains("NLOPT_LN_BOBYQA")) return NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA;
     478      if (solver.Contains("NLOPT_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_AUGLAG;
     479      if (solver.Contains("NLOPT_LD_SLSQP")) return NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP;
     480      if (solver.Contains("NLOPT_LD_CCSAQ))")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
     481      if (solver.Contains("NLOPT_GN_ESCH")) return NLOpt.nlopt_algorithm.NLOPT_GN_ESCH;
     482      if (solver.Contains("NLOPT_GN_AGS")) return NLOpt.nlopt_algorithm.NLOPT_GN_AGS;
     483
    448484      throw new ArgumentException($"Unknown solver {solver}");
    449485    }
     
    559595      return node;
    560596    }
     597
     598    public void Dispose() {
     599      if (nlopt != IntPtr.Zero) {
     600        NLOpt.nlopt_destroy(nlopt);
     601        nlopt = IntPtr.Zero;
     602      }
     603      if (constraintDataPtr != null) {
     604        for (int i = 0; i < constraintDataPtr.Length; i++)
     605          if (constraintDataPtr[i] != IntPtr.Zero) {
     606            Marshal.FreeHGlobal(constraintDataPtr[i]);
     607            constraintDataPtr[i] = IntPtr.Zero;
     608          }
     609      }
     610    }
    561611    #endregion
    562612  }
Note: See TracChangeset for help on using the changeset viewer.