source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17311

Last change on this file since 17311 was 17311, checked in by gkronber, 12 months ago

#2994: add names for constraints and derivatives, fix a problem: variableRanges are reset when creating a new solution. fix smaller bugs

File size: 29.4 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.InteropServices;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
9namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
10  internal class ConstrainedNLSInternal : IDisposable {
11    private readonly int maxIterations;
12    public int MaxIterations => maxIterations;
13
14    private readonly string solver;
15    public string Solver => solver;
16
17    private readonly ISymbolicExpressionTree expr;
18    public ISymbolicExpressionTree Expr => expr;
19
20    private readonly IRegressionProblemData problemData;
21
22    public IRegressionProblemData ProblemData => problemData;
23
24
25    public event Action FunctionEvaluated;
26    public event Action<int, double> ConstraintEvaluated;
27
28    private double bestError = double.MaxValue;
29    public double BestError => bestError;
30
31    private double curError = double.MaxValue;
32    public double CurError => curError;
33
34    private double[] bestSolution;
35    public double[] BestSolution => bestSolution;
36
37    private ISymbolicExpressionTree bestTree;
38    public ISymbolicExpressionTree BestTree => bestTree;
39
40    private double[] bestConstraintValues;
41    public double[] BestConstraintValues => bestConstraintValues;
42
43
44    // for debugging (must be in the same order as processed below)
45    public IEnumerable<string> ConstraintDescriptions {
46      get {
47        foreach (var elem in problemData.IntervalConstraints.Constraints) {
48          if (!elem.Enabled) continue;
49          if (elem.Interval.UpperBound < double.PositiveInfinity) {
50            yield return elem.Expression + " < " + elem.Interval.UpperBound;
51          }
52          if (elem.Interval.LowerBound > double.NegativeInfinity) {
53            yield return "-" + elem.Expression + " < " + (-1) * elem.Interval.LowerBound;
54          }
55        }
56      }
57    }
58
59    public bool CheckGradient { get; internal set; }
60
61    // begin internal state
62    private IntPtr nlopt;
63    private SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
64    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
65    // private readonly NLOpt.nlopt_precond preconditionDelegate;
66    private readonly IntPtr[] constraintDataPtr; // must hold the objects to prevent GC
67    private readonly NLOpt.nlopt_func[] calculateConstraintDelegates; // must hold the delegates to prevent GC
68    private readonly List<double> thetaValues;
69    private readonly IDictionary<string, Interval> dataIntervals;
70    private readonly int[] trainingRows;
71    private readonly double[] target;
72    private readonly ISymbolicExpressionTree preparedTree;
73    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
74    private readonly List<ConstantTreeNode>[] allThetaNodes;
75    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)
76
77    private readonly double[] fi_eval;
78    private readonly double[,] jac_eval;
79    private readonly ISymbolicExpressionTree scaledTree;
80    private readonly VectorAutoDiffEvaluator autoDiffEval;
81    private readonly VectorEvaluator eval;
82    private readonly bool invalidProblem = false;
83
84    // end internal state
85
86
87    // for data exchange to/from optimizer in native code
88    [StructLayout(LayoutKind.Sequential)]
89    private struct ConstraintData {
90      public int Idx;
91      public ISymbolicExpressionTree Tree;
92      public ISymbolicExpressionTreeNode[] ParameterNodes;
93    }
94
95    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
96      this.solver = solver;
97      this.expr = expr;
98      this.maxIterations = maxIterations;
99      this.problemData = problemData;
100      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
101      this.autoDiffEval = new VectorAutoDiffEvaluator();
102      this.eval = new VectorEvaluator();
103
104      CheckGradient = false;
105
106      var intervalConstraints = problemData.IntervalConstraints;
107      dataIntervals = problemData.VariableRanges.GetIntervals();
108      trainingRows = problemData.TrainingIndices.ToArray();
109      // buffers
110      target = problemData.TargetVariableTrainingValues.ToArray();
111      var targetStDev = target.StandardDeviationPop();
112      var targetVariance = targetStDev * targetStDev;
113      var targetMean = target.Average();
114      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();
115
116      bestError = targetVariance;
117
118      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) {
119        invalidProblem = true;
120      }
121
122      #region linear scaling
123      var predStDev = pred.StandardDeviationPop();
124      if (predStDev == 0) {
125        invalidProblem = true;
126      }
127      var predMean = pred.Average();
128
129      var scalingFactor = targetStDev / predStDev;
130      var offset = targetMean - predMean * scalingFactor;
131
132      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
133      #endregion
134
135      // convert constants to variables named theta...
136      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
137
138      // create trees for relevant derivatives
139      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
140      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
141      constraintTrees = new List<ISymbolicExpressionTree>();
142      foreach (var constraint in intervalConstraints.Constraints) {
143        if (!constraint.Enabled) continue;
144        if (constraint.IsDerivation) {
145          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
146            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
147          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
148
149          // NLOpt requires constraint expressions of the form c(x) <= 0
150          // -> we make two expressions, one for the lower bound and one for the upper bound
151
152          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
153            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
154            // convert variables named theta back to constants
155            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
156            constraintTrees.Add(df_prepared);
157          }
158          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
159            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
160            // convert variables named theta back to constants
161            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
162            constraintTrees.Add(df_prepared);
163          }
164        } else {
165          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
166            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
167            // convert variables named theta back to constants
168            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
169            constraintTrees.Add(df_prepared);
170          }
171          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
172            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
173            // convert variables named theta back to constants
174            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
175            constraintTrees.Add(df_prepared);
176          }
177        }
178      }
179
180      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
181      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);
182
183      var dim = thetaValues.Count;
184      fi_eval = new double[target.Length]; // init buffer;
185      jac_eval = new double[target.Length, dim]; // init buffer
186
187
188      var minVal = Math.Min(-1000.0, thetaValues.Min());
189      var maxVal = Math.Max(1000.0, thetaValues.Max());
190      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
191      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
192      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);
193
194      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
195      NLOpt.nlopt_set_upper_bounds(nlopt, up);
196      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
197      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning
198
199      //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
200      //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);
201
202
203      constraintDataPtr = new IntPtr[constraintTrees.Count];
204      calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
205      for (int i = 0; i < constraintTrees.Count; i++) {
206        var constraintData = new ConstraintData() { Idx = i, Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
207        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
208        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
209        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
210        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
211        // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
212      }
213
214      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
215      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
216      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
217      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
218    }
219
220
221    ~ConstrainedNLSInternal() {
222      Dispose();
223    }
224
225
226    internal void Optimize() {
227      if (invalidProblem) return;
228      var x = thetaValues.ToArray();  /* initial guess */
229      double minf = double.MaxValue; /* minimum objective value upon return */
230      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);
231
232      if (res < 0 && res != NLOpt.nlopt_result.NLOPT_FORCED_STOP) {
233        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
234        return;
235      } else if (minf <= bestError) {
236        bestSolution = x;
237        bestError = minf;
238
239        // calculate constraints of final solution
240        double[] _ = new double[x.Length];
241        bestConstraintValues = new double[calculateConstraintDelegates.Length];
242        for (int i = 0; i < calculateConstraintDelegates.Length; i++) {
243          bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]);
244        }
245
246        // update parameters in tree
247        var pIdx = 0;
248        // here we lose the two last parameters (for linear scaling)
249        foreach (var node in scaledTree.IterateNodesPostfix()) {
250          if (node is ConstantTreeNode constTreeNode) {
251            constTreeNode.Value = x[pIdx++];
252          } else if (node is VariableTreeNode varTreeNode) {
253            varTreeNode.Weight = x[pIdx++];
254          }
255        }
256        if (pIdx != x.Length) throw new InvalidProgramException();
257      }
258      bestTree = scaledTree;
259    }
260
261    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
262      UpdateThetaValues(curX);
263      var sse = 0.0;
264
265      if (grad != null) {
266        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
267          preparedTreeParameterNodes, fi_eval, jac_eval);
268
269        // calc sum of squared errors and gradient
270        for (int j = 0; j < grad.Length; j++) grad[j] = 0;
271        for (int i = 0; i < target.Length; i++) {
272          var r = target[i] - fi_eval[i];
273          sse += r * r;
274          for (int j = 0; j < grad.Length; j++) {
275            grad[j] -= 2 * r * jac_eval[i, j];
276          }
277        }
278        // average
279        for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }
280
281        #region check gradient
282        if (grad != null && CheckGradient) {
283          for (int i = 0; i < dim; i++) {
284            // make two additional evaluations
285            var xForNumDiff = (double[])curX.Clone();
286            double delta = Math.Abs(xForNumDiff[i] * 1e-5);
287            xForNumDiff[i] += delta;
288            UpdateThetaValues(xForNumDiff);
289            var evalHigh = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
290            var mseHigh = MSE(target, evalHigh);
291            xForNumDiff[i] = curX[i] - delta;
292            UpdateThetaValues(xForNumDiff);
293            var evalLow = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
294            var mseLow = MSE(target, evalLow);
295
296            var numericDiff = (mseHigh - mseLow) / (2 * delta);
297            var autoDiff = grad[i];
298            if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
299              || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
300              throw new InvalidProgramException();
301          }
302        }
303        #endregion
304      } else {
305        var eval = new VectorEvaluator();
306        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
307
308        // calc sum of squared errors
309        sse = 0.0;
310        for (int i = 0; i < target.Length; i++) {
311          var r = target[i] - prediction[i];
312          sse += r * r;
313        }
314      }
315
316      // UpdateBestSolution(sse / target.Length, curX);
317      RaiseFunctionEvaluated();
318
319      if (double.IsNaN(sse)) {
320        if (grad != null) Array.Clear(grad, 0, grad.Length);
321        return double.MaxValue;
322      }
323      return sse / target.Length;
324    }
325
326    // TODO
327    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
328    //   UpdateThetaValues(x); // calc H(x)
329    //   
330    //   autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
331    //     preparedTreeParameterNodes, fi_eval, jac_eval);
332    //   var k = jac_eval.GetLength(0);
333    //   var h = new double[n, n];
334    //   
335    //   // calc residuals and scale jac_eval
336    //   var f = 2.0 / (k*k);
337    //
338    //   // approximate hessian H(x) = J(x)^T * J(x)
339    //   alglib.rmatrixgemm((int)n, (int)n, k,
340    //     f, jac_eval, 0, 0, 1,  // transposed
341    //     jac_eval, 0, 0, 0,
342    //     0.0, ref h, 0, 0,
343    //     null
344    //     );
345    //   
346    //
347    //   // scale v
348    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
349    //
350    //
351    //   alglib.spdmatrixcholesky(ref h, (int)n, true);
352    //
353    //   var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial);
354    // }
355
356
357    private double MSE(double[] a, double[] b) {
358      Trace.Assert(a.Length == b.Length);
359      var sse = 0.0;
360      for (int i = 0; i < a.Length; i++) sse += (a[i] - b[i]) * (a[i] - b[i]);
361      return sse / a.Length;
362    }
363
364    private void UpdateBestSolution(double curF, double[] curX) {
365      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
366      else if (curF < bestError) {
367        bestError = curF;
368        bestSolution = (double[])curX.Clone();
369      }
370      curError = curF;
371    }
372
373    private void UpdateConstraintViolations(int constraintIdx, double value) {
374      if (double.IsNaN(value) || double.IsInfinity(value)) return;
375      RaiseConstraintEvaluated(constraintIdx, value);
376      // else if (curF < bestError) {
377      //   bestError = curF;
378      //   bestSolution = (double[])curX.Clone();
379      // }
380    }
381
382    double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
383      UpdateThetaValues(curX);
384      var intervalEvaluator = new IntervalEvaluator();
385      var refIntervalEvaluator = new IntervalInterpreter();
386
387      var constraintData = Marshal.PtrToStructure<ConstraintData>(data);
388
389      if (grad != null) Array.Clear(grad, 0, grad.Length);
390
391      var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
392        out double[] lowerGradient, out double[] upperGradient);
393
394      var refInterval = refIntervalEvaluator.GetSymbolicExpressionTreeInterval(constraintData.Tree, dataIntervals);
395      if (Math.Abs(interval.LowerBound - refInterval.LowerBound) > Math.Abs(interval.LowerBound) * 1e-4) throw new InvalidProgramException($"Intervals don't match. {interval.LowerBound} <> {refInterval.LowerBound}");
396      if (Math.Abs(interval.UpperBound - refInterval.UpperBound) > Math.Abs(interval.UpperBound) * 1e-4) throw new InvalidProgramException($"Intervals don't match. {interval.UpperBound} <> {refInterval.UpperBound}");
397
398      // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
399      if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }
400
401      #region check gradient
402      if (grad != null && CheckGradient)
403        for (int i = 0; i < dim; i++) {
404          // make two additional evaluations
405          var xForNumDiff = (double[])curX.Clone();
406          double delta = Math.Abs(xForNumDiff[i] * 1e-5);
407          xForNumDiff[i] += delta;
408          UpdateThetaValues(xForNumDiff);
409          var evalHigh = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
410            out double[] unusedLowerGradientHigh, out double[] unusedUpperGradientHigh);
411
412          xForNumDiff[i] = curX[i] - delta;
413          UpdateThetaValues(xForNumDiff);
414          var evalLow = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
415            out double[] unusedLowerGradientLow, out double[] unusedUpperGradientLow);
416
417          var numericDiff = (evalHigh.UpperBound - evalLow.UpperBound) / (2 * delta);
418          var autoDiff = grad[i];
419
420          if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
421            || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
422            throw new InvalidProgramException();
423        }
424      #endregion
425
426
427      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
428      if (double.IsNaN(interval.UpperBound)) {
429        if(grad!=null)Array.Clear(grad, 0, grad.Length);
430        return double.MaxValue;
431      } else return interval.UpperBound;
432    }
433
434
435    void UpdateThetaValues(double[] theta) {
436      for (int i = 0; i < theta.Length; ++i) {
437        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
438      }
439    }
440
441    internal void RequestStop() {
442      NLOpt.nlopt_set_force_stop(nlopt, 1); // hopefully NLOpt is thread safe  , val must be <> 0 otherwise no effect
443    }
444
445    private void RaiseFunctionEvaluated() {
446      FunctionEvaluated?.Invoke();
447    }
448
449    private void RaiseConstraintEvaluated(int idx, double value) {
450      ConstraintEvaluated?.Invoke(idx, value);
451    }
452
453
454    #region helper
455
456    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
457      var m = (ISymbolicExpressionTree)tree.Clone();
458
459      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
460      m.Root.GetSubtree(0).RemoveSubtree(0);
461      m.Root.GetSubtree(0).AddSubtree(add);
462      return m;
463    }
464
465    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
466      if (nodes.Length != constants.Length) throw new InvalidOperationException();
467      for (int i = 0; i < nodes.Length; i++) {
468        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
469        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
470      }
471    }
472
473    private NLOpt.nlopt_algorithm GetSolver(string solver) {
474      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
475      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
476      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
477      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
478
479      if (solver.Contains("DIRECT_G")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
480      if (solver.Contains("NLOPT_GN_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L;
481      if (solver.Contains("NLOPT_GN_DIRECT_L_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND;
482      if (solver.Contains("NLOPT_GN_ORIG_DIRECT")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
483      if (solver.Contains("NLOPT_GN_ORIG_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L;
484      if (solver.Contains("NLOPT_GD_STOGO")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO;
485      if (solver.Contains("NLOPT_GD_STOGO_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND;
486      if (solver.Contains("NLOPT_LD_LBFGS_NOCEDAL")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL;
487      if (solver.Contains("NLOPT_LD_LBFGS")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS;
488      if (solver.Contains("NLOPT_LN_PRAXIS")) return NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS;
489      if (solver.Contains("NLOPT_LD_VAR1")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR1;
490      if (solver.Contains("NLOPT_LD_VAR2")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR2;
491      if (solver.Contains("NLOPT_LD_TNEWTON")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON;
492      if (solver.Contains("NLOPT_LD_TNEWTON_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART;
493      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND;
494      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART;
495      if (solver.Contains("NLOPT_GN_CRS2_LM")) return NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM;
496      if (solver.Contains("NLOPT_GN_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL;
497      if (solver.Contains("NLOPT_GD_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL;
498      if (solver.Contains("NLOPT_GN_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS;
499      if (solver.Contains("NLOPT_GD_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS;
500      if (solver.Contains("NLOPT_LN_NEWUOA")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA;
501      if (solver.Contains("NLOPT_LN_NEWUOA_BOUND")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND;
502      if (solver.Contains("NLOPT_LN_NELDERMEAD")) return NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD;
503      if (solver.Contains("NLOPT_LN_SBPLX")) return NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX;
504      if (solver.Contains("NLOPT_LN_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG;
505      if (solver.Contains("NLOPT_LD_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG;
506      if (solver.Contains("NLOPT_LN_BOBYQA")) return NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA;
507      if (solver.Contains("NLOPT_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_AUGLAG;
508      if (solver.Contains("NLOPT_LD_SLSQP")) return NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP;
509      if (solver.Contains("NLOPT_LD_CCSAQ))")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
510      if (solver.Contains("NLOPT_GN_ESCH")) return NLOpt.nlopt_algorithm.NLOPT_GN_ESCH;
511      if (solver.Contains("NLOPT_GN_AGS")) return NLOpt.nlopt_algorithm.NLOPT_GN_AGS;
512
513      throw new ArgumentException($"Unknown solver {solver}");
514    }
515
516    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
517      // TODO better solution necessary
518      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
519      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
520      for (int i = 0; i < paramNodes.Length; i++) {
521        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
522      }
523      return paramNodes;
524    }
525
526    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
527      var copy = (ISymbolicExpressionTree)tree.Clone();
528      var nodes = copy.IterateNodesPostfix().ToList();
529      for (int i = 0; i < nodes.Count; i++) {
530        var n = nodes[i] as VariableTreeNode;
531        if (n != null) {
532          var thetaIdx = thetaNames.IndexOf(n.VariableName);
533          if (thetaIdx >= 0) {
534            var parent = n.Parent;
535            if (thetaNodes[thetaIdx].Any()) {
536              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
537              // we use this trick to allow autodiff over thetas when thetas occurr multiple times in the tree (e.g. in derived trees)
538              var constNode = thetaNodes[thetaIdx].First();
539              var childIdx = parent.IndexOfSubtree(n);
540              parent.RemoveSubtree(childIdx);
541              parent.InsertSubtree(childIdx, constNode);
542            } else {
543              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
544              var childIdx = parent.IndexOfSubtree(n);
545              parent.RemoveSubtree(childIdx);
546              parent.InsertSubtree(childIdx, constNode);
547              thetaNodes[thetaIdx].Add(constNode);
548            }
549          }
550        }
551      }
552      return copy;
553    }
554
555    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
556      thetaNames = new List<string>();
557      thetaValues = new List<double>();
558      var copy = (ISymbolicExpressionTree)tree.Clone();
559      var nodes = copy.IterateNodesPostfix().ToList();
560
561      int n = 1;
562      for (int i = 0; i < nodes.Count; ++i) {
563        var node = nodes[i];
564        if (node is ConstantTreeNode constantTreeNode) {
565          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
566          thetaVar.Weight = 1;
567          thetaVar.VariableName = $"θ{n++}";
568
569          thetaNames.Add(thetaVar.VariableName);
570          thetaValues.Add(constantTreeNode.Value);
571
572          var parent = constantTreeNode.Parent;
573          if (parent != null) {
574            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
575            parent.RemoveSubtree(index);
576            parent.InsertSubtree(index, thetaVar);
577          }
578        }
579        if (node is VariableTreeNode varTreeNode) {
580          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
581          thetaVar.Weight = 1;
582          thetaVar.VariableName = $"θ{n++}";
583
584          thetaNames.Add(thetaVar.VariableName);
585          thetaValues.Add(varTreeNode.Weight);
586
587          var parent = varTreeNode.Parent;
588          if (parent != null) {
589            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
590            parent.RemoveSubtree(index);
591            var prodNode = MakeNode<Multiplication>();
592            varTreeNode.Weight = 1.0;
593            prodNode.AddSubtree(varTreeNode);
594            prodNode.AddSubtree(thetaVar);
595            parent.InsertSubtree(index, prodNode);
596          }
597        }
598      }
599      return copy;
600    }
601
602    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
603      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
604      constantNode.Value = value;
605      return constantNode;
606    }
607
608    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
609      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
610      t.Root.GetSubtree(0).RemoveSubtree(0);
611      t.Root.GetSubtree(0).InsertSubtree(0, sub);
612      return t;
613    }
614    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
615      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
616      t.Root.GetSubtree(0).RemoveSubtree(0);
617      t.Root.GetSubtree(0).InsertSubtree(0, sub);
618      return t;
619    }
620
621    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
622      var node = new T().CreateTreeNode();
623      foreach (var f in fs) node.AddSubtree(f);
624      return node;
625    }
626
627    public void Dispose() {
628      if (nlopt != IntPtr.Zero) {
629        NLOpt.nlopt_destroy(nlopt);
630        nlopt = IntPtr.Zero;
631      }
632      if (constraintDataPtr != null) {
633        for (int i = 0; i < constraintDataPtr.Length; i++)
634          if (constraintDataPtr[i] != IntPtr.Zero) {
635            Marshal.FreeHGlobal(constraintDataPtr[i]);
636            constraintDataPtr[i] = IntPtr.Zero;
637          }
638      }
639    }
640    #endregion
641  }
642}
Note: See TracBrowser for help on using the repository browser.