Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17213

Last change on this file since 17213 was 17213, checked in by gkronber, 5 years ago

#2994: fixed a bug caused by cloning of trees, support other NLOpt solvers, implement IDisposable, experiment with preconditioning (still not working)

File size: 28.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.InteropServices;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Fits the numeric parameters (constants and variable weights) of a linearly scaled symbolic regression model
  /// with NLOpt. The objective is the mean squared error on the training partition. Every enabled interval constraint
  /// of the problem data (on the model output or on a partial derivative) is turned into inequality constraints
  /// c(x) &lt;= 0, where c(x) is the upper bound obtained by interval evaluation over the variable ranges.
  /// The instance owns a native NLOpt handle and should be disposed.
  /// </summary>
  internal class ConstrainedNLSInternal : IDisposable {
    private readonly int maxIterations;
    public int MaxIterations => maxIterations;

    private readonly string solver;
    public string Solver => solver;

    private readonly ISymbolicExpressionTree expr;
    public ISymbolicExpressionTree Expr => expr;

    private readonly IRegressionProblemData problemData;
    public IRegressionProblemData ProblemData => problemData;

    /// <summary>Raised after each evaluation of the objective function.</summary>
    public event Action FunctionEvaluated;

    /// <summary>
    /// Raised after each constraint evaluation that has a finite result, with (constraint index, value of c(x)).
    /// A value &lt;= 0 means the constraint is satisfied.
    /// </summary>
    public event Action<int, double> ConstraintEvaluated;

    private double bestError = double.MaxValue;
    /// <summary>
    /// Lowest finite training MSE seen so far (the target variance for models that cannot be fitted).
    /// After <see cref="Optimize"/> it holds the objective value NLOpt reports for its final solution.
    /// </summary>
    public double BestError => bestError;

    private double curError = double.MaxValue;
    /// <summary>Training MSE of the most recent objective evaluation with a finite result.</summary>
    public double CurError => curError;

    private double[] bestSolution;
    /// <summary>Parameter vector belonging to <see cref="BestError"/> (NLOpt's final point after <see cref="Optimize"/>).</summary>
    public double[] BestSolution => bestSolution;

    private ISymbolicExpressionTree bestTree;
    /// <summary>The linearly scaled model carrying the optimized parameters; set by a successful <see cref="Optimize"/>.</summary>
    public ISymbolicExpressionTree BestTree => bestTree;

    private double[] bestConstraintValues;
    /// <summary>Value of each constraint c(x) at the final solution (&lt;= 0 means satisfied).</summary>
    public double[] BestConstraintValues => bestConstraintValues;

    /// <summary>When set, analytic gradients are compared with central finite differences in every callback (slow; debugging aid).</summary>
    public bool CheckGradient { get; internal set; }

    // begin internal state
    private IntPtr nlopt; // native NLOpt handle; released by Dispose or the finalizer
    private readonly SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
    private readonly NLOpt.nlopt_precond preconditionDelegate; // only used by the (disabled) preconditioning experiment
    private readonly NLOpt.nlopt_func calculateConstraintDelegate; // must hold the delegate to prevent GC (shared by all constraints)
    private readonly ConstraintData[] constraintData; // per-constraint state; NLOpt only receives the array index as its opaque data pointer
    private readonly List<double> thetaValues;
    private readonly IDictionary<string, Interval> dataIntervals;
    private readonly int[] trainingRows;
    private readonly double[] target;
    private readonly ISymbolicExpressionTree preparedTree;
    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
    private readonly List<ConstantTreeNode>[] allThetaNodes;
    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)

    private readonly double[] fi_eval;
    private readonly double[,] jac_eval;
    private readonly ISymbolicExpressionTree scaledTree;
    private readonly VectorAutoDiffEvaluator autoDiffEval;
    private readonly VectorEvaluator eval;
    private readonly bool invalidProblem = false;
    // end internal state

    // Everything needed to evaluate one constraint c(x) <= 0. The data stays in managed memory;
    // only its index crosses into native code, so no unmanaged allocation or marshaling is needed.
    private struct ConstraintData {
      public int Idx;
      public ISymbolicExpressionTree Tree;
      public ISymbolicExpressionTreeNode[] ParameterNodes;
    }

    /// <summary>
    /// Prepares the objective and constraint functions and configures the NLOpt optimizer.
    /// </summary>
    /// <param name="solver">NLOpt algorithm, selected by substring match: a short alias ("MMA", "COBYLA", "CCSAQ", "ISRES", "DIRECT_G") or a full name such as "NLOPT_LD_SLSQP".</param>
    /// <param name="expr">Model whose parameters are fitted; it is copied, not modified.</param>
    /// <param name="maxIterations">Maximum number of evaluations (passed to nlopt_set_maxeval).</param>
    /// <param name="problemData">Training data, variable ranges and interval constraints.</param>
    /// <param name="ftol_rel">Passed to nlopt_set_ftol_rel.</param>
    /// <param name="ftol_abs">Passed to nlopt_set_ftol_abs.</param>
    /// <param name="maxTime">Passed to nlopt_set_maxtime.</param>
    /// <exception cref="ArgumentException">A derivative constraint names an unknown input variable, or the solver is unknown.</exception>
    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
      this.solver = solver;
      this.expr = expr;
      this.maxIterations = maxIterations;
      this.problemData = problemData;
      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      this.autoDiffEval = new VectorAutoDiffEvaluator();
      this.eval = new VectorEvaluator();

      CheckGradient = false;

      var intervalConstraints = problemData.IntervalConstraints;
      dataIntervals = problemData.VariableRanges.GetIntervals();
      trainingRows = problemData.TrainingIndices.ToArray();
      // buffers
      target = problemData.TargetVariableTrainingValues.ToArray();
      var targetStDev = target.StandardDeviationPop();
      var targetVariance = targetStDev * targetStDev;
      var targetMean = target.Average();
      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();

      // Models with non-finite or constant predictions cannot be scaled/fitted: Optimize() becomes a no-op
      // and BestError reports the target variance.
      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) {
        bestError = targetVariance;
        invalidProblem = true;
      }

      #region linear scaling
      var predStDev = pred.StandardDeviationPop();
      if (predStDev == 0) {
        bestError = targetVariance;
        invalidProblem = true;
      }
      var predMean = pred.Average();

      var scalingFactor = targetStDev / predStDev;
      var offset = targetMean - predMean * scalingFactor;

      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
      #endregion

      // convert constants and variable weights to variables named theta... (copies the tree)
      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues);

      // create constraint trees for the model output and for the relevant derivatives
      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
      constraintTrees = new List<ISymbolicExpressionTree>();
      foreach (var constraint in intervalConstraints.Constraints) {
        if (!constraint.Enabled) continue;
        ISymbolicExpressionTree constrainedTree;
        if (constraint.IsDerivation) {
          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
          constrainedTree = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
        } else {
          constrainedTree = treeForDerivation;
        }
        AddBoundConstraints(constrainedTree, constraint.Interval.LowerBound, constraint.Interval.UpperBound,
          thetaNames, thetaValues, allThetaNodes, constraintTrees);
      }

      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);

      var dim = thetaValues.Count;
      fi_eval = new double[target.Length]; // init buffer
      jac_eval = new double[target.Length, dim]; // init buffer

      var minVal = Math.Min(-100.0, thetaValues.Min());
      var maxVal = Math.Max(100.0, thetaValues.Max());
      var lb = Enumerable.Repeat(minVal, dim).ToArray();
      var up = Enumerable.Repeat(maxVal, dim).ToArray();
      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);
      if (nlopt == IntPtr.Zero)
        throw new InvalidOperationException($"NLOpt could not create an optimizer for solver {solver}.");

      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
      NLOpt.nlopt_set_upper_bounds(nlopt, up);
      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning

      //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
      //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);

      // Keep a reference to the delegate, otherwise the GC frees it while native code still calls it
      // (see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258).
      calculateConstraintDelegate = new NLOpt.nlopt_func(CalculateConstraint);
      constraintData = new ConstraintData[constraintTrees.Count];
      for (int i = 0; i < constraintTrees.Count; i++) {
        constraintData[i] = new ConstraintData {
          Idx = i,
          Tree = constraintTrees[i],
          ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes)
        };
        // the index i is NLOpt's opaque data pointer; CalculateConstraint maps it back to constraintData[i]
        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegate, new IntPtr(i), 1e-8);
        // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegate, preconditionDelegate, new IntPtr(i), 1e-8);
      }

      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
    }

    ~ConstrainedNLSInternal() {
      ReleaseNativeResources();
    }

    /// <summary>Destroys the native NLOpt optimizer. Safe to call more than once.</summary>
    public void Dispose() {
      ReleaseNativeResources();
      GC.SuppressFinalize(this);
    }

    // Shared by Dispose and the finalizer; idempotent.
    private void ReleaseNativeResources() {
      if (nlopt != IntPtr.Zero) {
        NLOpt.nlopt_destroy(nlopt);
        nlopt = IntPtr.Zero;
      }
    }

    /// <summary>
    /// Runs NLOpt from the initial parameter values. On success it sets <see cref="BestSolution"/>,
    /// <see cref="BestError"/>, <see cref="BestConstraintValues"/> and <see cref="BestTree"/>.
    /// Does nothing for models that were detected as unfittable in the constructor.
    /// </summary>
    /// <exception cref="InvalidOperationException">NLOpt reported a failure other than a forced stop.</exception>
    /// <exception cref="ObjectDisposedException">The instance has been disposed.</exception>
    internal void Optimize() {
      if (invalidProblem) return;
      if (nlopt == IntPtr.Zero) throw new ObjectDisposedException(GetType().FullName);

      var x = thetaValues.ToArray();  /* initial guess */
      double minf = double.MaxValue; /* minimum objective value upon return */
      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);
      bestSolution = x;
      bestError = minf;

      if (res < 0 && res != NLOpt.nlopt_result.NLOPT_FORCED_STOP) {
        throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
      }

      // evaluate the constraints at the final solution (value only, no gradient needed)
      bestConstraintValues = new double[constraintData.Length];
      for (int i = 0; i < constraintData.Length; i++) {
        bestConstraintValues[i] = CalculateConstraint((uint)x.Length, x, null, new IntPtr(i));
      }

      // Write the parameters back into scaledTree. ReplaceConstWithVar numbered the parameters in postfix order
      // over a copy of scaledTree (constants and variable weights, including the two linear-scaling constants),
      // so the same postfix walk maps x back onto the nodes.
      var pIdx = 0;
      foreach (var node in scaledTree.IterateNodesPostfix()) {
        if (node is ConstantTreeNode constTreeNode) {
          constTreeNode.Value = x[pIdx++];
        } else if (node is VariableTreeNode varTreeNode) {
          varTreeNode.Weight = x[pIdx++];
        }
      }
      if (pIdx != x.Length) throw new InvalidProgramException();
      bestTree = scaledTree;
    }

    // NLOpt objective callback. It sets the parameters to curX and returns the training MSE
    // (double.MaxValue when the MSE is NaN). When grad is non-null, it also writes d(MSE)/d(theta), computed by autodiff.
    // NOTE(review): exceptions thrown here (gradient check, event handlers) unwind through NLOpt's native frames;
    // consider capturing them, forcing a stop and rethrowing after nlopt_optimize returns.
    private double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
      UpdateThetaValues(curX);
      var sse = 0.0;

      if (grad != null) {
        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
          preparedTreeParameterNodes, fi_eval, jac_eval);

        // sum of squared errors and its gradient
        Array.Clear(grad, 0, grad.Length);
        for (int i = 0; i < target.Length; i++) {
          var r = target[i] - fi_eval[i];
          sse += r * r;
          for (int j = 0; j < grad.Length; j++) {
            grad[j] -= 2 * r * jac_eval[i, j];
          }
        }
        // average
        for (int j = 0; j < grad.Length; j++) grad[j] /= target.Length;

        if (CheckGradient) CheckObjectiveGradient(curX, grad);
      } else {
        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
        for (int i = 0; i < target.Length; i++) {
          var r = target[i] - prediction[i];
          sse += r * r;
        }
      }

      var mse = sse / target.Length;
      UpdateBestSolution(mse, curX);
      RaiseFunctionEvaluated();

      return double.IsNaN(sse) ? double.MaxValue : mse;
    }

    // Compares the analytic objective gradient with central differences of the MSE. Throws on mismatch and
    // restores the parameter nodes to curX afterwards.
    private void CheckObjectiveGradient(double[] curX, double[] grad) {
      var xForNumDiff = (double[])curX.Clone();
      for (int i = 0; i < curX.Length; i++) {
        var delta = NumDiffStep(curX[i]);
        xForNumDiff[i] = curX[i] + delta;
        UpdateThetaValues(xForNumDiff);
        var mseHigh = MSE(target, eval.Evaluate(preparedTree, problemData.Dataset, trainingRows));

        xForNumDiff[i] = curX[i] - delta;
        UpdateThetaValues(xForNumDiff);
        var mseLow = MSE(target, eval.Evaluate(preparedTree, problemData.Dataset, trainingRows));

        xForNumDiff[i] = curX[i];
        AssertGradientMatches((mseHigh - mseLow) / (2 * delta), grad[i], i);
      }
      UpdateThetaValues(curX);
    }

    // TODO
    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
    //   UpdateThetaValues(x); // calc H(x)
    //   
    //   autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
    //     preparedTreeParameterNodes, fi_eval, jac_eval);
    //   var k = jac_eval.GetLength(0);
    //   var h = new double[n, n];
    //   
    //   // calc residuals and scale jac_eval
    //   var f = 2.0 / (k*k);
    //
    //   // approximate hessian H(x) = J(x)^T * J(x)
    //   alglib.rmatrixgemm((int)n, (int)n, k,
    //     f, jac_eval, 0, 0, 1,  // transposed
    //     jac_eval, 0, 0, 0,
    //     0.0, ref h, 0, 0,
    //     null
    //     );
    //   
    //
    //   // scale v
    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
    //
    //
    //   alglib.spdmatrixcholesky(ref h, (int)n, true);
    //
    //   var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial);
    // }

    // Mean squared error between two equally long vectors (asserted via Trace.Assert).
    private double MSE(double[] a, double[] b) {
      Trace.Assert(a.Length == b.Length);
      var sse = 0.0;
      for (int i = 0; i < a.Length; i++) sse += (a[i] - b[i]) * (a[i] - b[i]);
      return sse / a.Length;
    }

    // Records curF as the current error and, when it improves on bestError, remembers a copy of curX.
    // Non-finite values are ignored entirely.
    private void UpdateBestSolution(double curF, double[] curX) {
      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
      if (curF < bestError) {
        bestError = curF;
        bestSolution = (double[])curX.Clone();
      }
      curError = curF;
    }

    // Forwards finite constraint values to ConstraintEvaluated; NaN/infinite values are not reported.
    private void UpdateConstraintViolations(int constraintIdx, double value) {
      if (double.IsNaN(value) || double.IsInfinity(value)) return;
      RaiseConstraintEvaluated(constraintIdx, value);
    }

    // NLOpt constraint callback. data is the index into constraintData (see the constructor).
    // It returns the upper bound of c(x) over the variable ranges (double.MaxValue when NaN) and, when grad is non-null,
    // writes the gradient of that upper bound.
    private double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
      UpdateThetaValues(curX);
      var constraint = constraintData[data.ToInt32()];
      var intervalEvaluator = new IntervalEvaluator();

      var interval = intervalEvaluator.Evaluate(constraint.Tree, dataIntervals, constraint.ParameterNodes,
        out double[] _, out double[] upperGradient);

      if (grad != null) {
        // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
        Array.Copy(upperGradient, grad, grad.Length);
        if (CheckGradient) CheckConstraintGradient(intervalEvaluator, constraint, curX, grad);
      }

      UpdateConstraintViolations(constraint.Idx, interval.UpperBound);
      return double.IsNaN(interval.UpperBound) ? double.MaxValue : interval.UpperBound;
    }

    // Compares the analytic gradient of the constraint's upper bound with central differences. Throws on mismatch
    // and restores the parameter nodes to curX afterwards.
    private void CheckConstraintGradient(IntervalEvaluator intervalEvaluator, ConstraintData constraint, double[] curX, double[] grad) {
      var xForNumDiff = (double[])curX.Clone();
      for (int i = 0; i < curX.Length; i++) {
        var delta = NumDiffStep(curX[i]);
        xForNumDiff[i] = curX[i] + delta;
        UpdateThetaValues(xForNumDiff);
        var upperHigh = intervalEvaluator.Evaluate(constraint.Tree, dataIntervals, constraint.ParameterNodes,
          out double[] _, out double[] _).UpperBound;

        xForNumDiff[i] = curX[i] - delta;
        UpdateThetaValues(xForNumDiff);
        var upperLow = intervalEvaluator.Evaluate(constraint.Tree, dataIntervals, constraint.ParameterNodes,
          out double[] _, out double[] _).UpperBound;

        xForNumDiff[i] = curX[i];
        AssertGradientMatches((upperHigh - upperLow) / (2 * delta), grad[i], i);
      }
      UpdateThetaValues(curX);
    }

    // Step for central differences: relative to |x|, with a floor so that x == 0 still yields a usable step
    // (a zero step would produce NaN and let the check pass silently).
    private static double NumDiffStep(double x) => Math.Max(Math.Abs(x) * 1e-5, 1e-6);

    // Throws when an analytic derivative and its finite-difference estimate disagree by more than 1 % (relative),
    // or when the analytic value is ~0 while the numeric one exceeds 1e-2.
    private static void AssertGradientMatches(double numericDiff, double autoDiff, int paramIdx) {
      var mismatch = Math.Abs(autoDiff) < 1e-10
        ? Math.Abs(numericDiff) > 1e-2
        : Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2;
      if (mismatch)
        throw new InvalidProgramException($"Gradient check failed for parameter {paramIdx}: analytic {autoDiff}, numeric {numericDiff}.");
    }

    // Writes theta[i] into every constant node that represents parameter i (these nodes are shared across trees).
    void UpdateThetaValues(double[] theta) {
      for (int i = 0; i < theta.Length; ++i) {
        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
      }
    }

    /// <summary>Asks a running <see cref="Optimize"/> to stop; NLOpt then returns NLOPT_FORCED_STOP. No-op after disposal.</summary>
    internal void RequestStop() {
      var handle = nlopt;
      if (handle == IntPtr.Zero) return; // disposed: nothing left to stop, and NLOpt must not see a null handle
      NLOpt.nlopt_set_force_stop(handle, 1); // hopefully NLOpt is thread safe, val must be <> 0 otherwise no effect
    }

    private void RaiseFunctionEvaluated() {
      FunctionEvaluated?.Invoke();
    }

    private void RaiseConstraintEvaluated(int idx, double value) {
      ConstraintEvaluated?.Invoke(idx, value);
    }

    #region helper

    // Returns a copy of tree whose body is replaced by (body * scalingFactor) + offset.
    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
      var m = (ISymbolicExpressionTree)tree.Clone();
      var start = m.Root.GetSubtree(0);
      var body = start.GetSubtree(0);
      // Detach first: RemoveSubtree resets the removed node's Parent, so removing after re-attaching
      // would leave the body with a null Parent.
      start.RemoveSubtree(0);
      start.AddSubtree(MakeNode<Addition>(MakeNode<Multiplication>(body, CreateConstant(scalingFactor)), CreateConstant(offset)));
      return m;
    }

    // NLOpt requires constraints of the form c(x) <= 0. For lower <= f <= upper, this adds f - upper (finite upper bound)
    // and lower - f (finite lower bound) to constraintTrees, with the theta variables turned back into shared constant nodes.
    private static void AddBoundConstraints(ISymbolicExpressionTree f, double lowerBound, double upperBound,
      List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes,
      List<ISymbolicExpressionTree> constraintTrees) {
      if (upperBound < double.PositiveInfinity) {
        var fSmallerUpper = Subtract((ISymbolicExpressionTree)f.Clone(), CreateConstant(upperBound));
        constraintTrees.Add(ReplaceVarWithConst(fSmallerUpper, thetaNames, thetaValues, thetaNodes));
      }
      if (lowerBound > double.NegativeInfinity) {
        var fLargerLower = Subtract(CreateConstant(lowerBound), (ISymbolicExpressionTree)f.Clone());
        constraintTrees.Add(ReplaceVarWithConst(fLargerLower, thetaNames, thetaValues, thetaNodes));
      }
    }

    // Short solver aliases, checked first and in this order.
    private static readonly KeyValuePair<string, NLOpt.nlopt_algorithm>[] solverAliases = {
      new KeyValuePair<string, NLOpt.nlopt_algorithm>("MMA", NLOpt.nlopt_algorithm.NLOPT_LD_MMA),
      new KeyValuePair<string, NLOpt.nlopt_algorithm>("COBYLA", NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA),
      new KeyValuePair<string, NLOpt.nlopt_algorithm>("CCSAQ", NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ),
      new KeyValuePair<string, NLOpt.nlopt_algorithm>("ISRES", NLOpt.nlopt_algorithm.NLOPT_GN_ISRES),
      new KeyValuePair<string, NLOpt.nlopt_algorithm>("DIRECT_G", NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT),
    };

    // Full NLOpt algorithm names.
    private static readonly Dictionary<string, NLOpt.nlopt_algorithm> nloptAlgorithmNames = new Dictionary<string, NLOpt.nlopt_algorithm> {
      { "NLOPT_GN_DIRECT", NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT },
      { "NLOPT_GN_DIRECT_L", NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L },
      { "NLOPT_GN_DIRECT_L_RAND", NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND },
      { "NLOPT_GN_ORIG_DIRECT", NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT },
      { "NLOPT_GN_ORIG_DIRECT_L", NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L },
      { "NLOPT_GD_STOGO", NLOpt.nlopt_algorithm.NLOPT_GD_STOGO },
      { "NLOPT_GD_STOGO_RAND", NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND },
      { "NLOPT_LD_LBFGS_NOCEDAL", NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL },
      { "NLOPT_LD_LBFGS", NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS },
      { "NLOPT_LN_PRAXIS", NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS },
      { "NLOPT_LD_VAR1", NLOpt.nlopt_algorithm.NLOPT_LD_VAR1 },
      { "NLOPT_LD_VAR2", NLOpt.nlopt_algorithm.NLOPT_LD_VAR2 },
      { "NLOPT_LD_TNEWTON", NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON },
      { "NLOPT_LD_TNEWTON_RESTART", NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART },
      { "NLOPT_LD_TNEWTON_PRECOND", NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND },
      { "NLOPT_LD_TNEWTON_PRECOND_RESTART", NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART },
      { "NLOPT_GN_CRS2_LM", NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM },
      { "NLOPT_GN_MLSL", NLOpt.nlopt_algorithm.NLOPT_GN_MLSL },
      { "NLOPT_GD_MLSL", NLOpt.nlopt_algorithm.NLOPT_GD_MLSL },
      { "NLOPT_GN_MLSL_LDS", NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS },
      { "NLOPT_GD_MLSL_LDS", NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS },
      { "NLOPT_LN_NEWUOA", NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA },
      { "NLOPT_LN_NEWUOA_BOUND", NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND },
      { "NLOPT_LN_NELDERMEAD", NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD },
      { "NLOPT_LN_SBPLX", NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX },
      { "NLOPT_LN_AUGLAG", NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG },
      { "NLOPT_LD_AUGLAG", NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG },
      { "NLOPT_LN_BOBYQA", NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA },
      { "NLOPT_AUGLAG", NLOpt.nlopt_algorithm.NLOPT_AUGLAG },
      { "NLOPT_LD_SLSQP", NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP },
      { "NLOPT_LD_CCSAQ", NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ },
      { "NLOPT_GN_ESCH", NLOpt.nlopt_algorithm.NLOPT_GN_ESCH },
      { "NLOPT_GN_AGS", NLOpt.nlopt_algorithm.NLOPT_GN_AGS },
    };

    // Matching order: aliases first, then full names longest-first. With substring matching, a name that is a
    // prefix of another (e.g. NLOPT_GN_DIRECT_L vs. NLOPT_GN_DIRECT_L_RAND) would otherwise shadow the longer one.
    private static readonly KeyValuePair<string, NLOpt.nlopt_algorithm>[] solverLookupOrder =
      solverAliases.Concat(nloptAlgorithmNames.OrderByDescending(kv => kv.Key.Length)).ToArray();

    // Maps a solver description to an NLOpt algorithm: the first entry of solverLookupOrder contained in it wins.
    private static NLOpt.nlopt_algorithm GetSolver(string solver) {
      if (solver == null) throw new ArgumentNullException(nameof(solver));
      foreach (var entry in solverLookupOrder) {
        if (solver.Contains(entry.Key)) return entry.Value;
      }
      throw new ArgumentException($"Unknown solver {solver}");
    }

    // For every parameter index i, returns the node from allNodes[i] that occurs in tree, or null if the tree
    // does not contain parameter i.
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
      // TODO better solution necessary
      var treeConstNodes = new HashSet<ConstantTreeNode>(tree.IterateNodesPostfix().OfType<ConstantTreeNode>());
      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
      for (int i = 0; i < paramNodes.Length; i++) {
        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
      }
      return paramNodes;
    }

    // Returns a copy of tree in which every variable named like a theta is replaced by a constant node. The first time
    // a theta is seen, a new node (initialized from thetaValues) is created and recorded in thetaNodes; later
    // occurrences, in this or other trees, reuse that same node instance.
    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();
      foreach (var node in nodes) {
        if (!(node is VariableTreeNode varNode)) continue;
        var thetaIdx = thetaNames.IndexOf(varNode.VariableName);
        if (thetaIdx < 0) continue;

        ConstantTreeNode constNode;
        if (thetaNodes[thetaIdx].Any()) {
          // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
          // we use this trick to allow autodiff over thetas when thetas occurr multiple times in the tree (e.g. in derived trees)
          constNode = thetaNodes[thetaIdx].First();
        } else {
          constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
          thetaNodes[thetaIdx].Add(constNode);
        }
        var parent = varNode.Parent;
        var childIdx = parent.IndexOfSubtree(varNode);
        parent.RemoveSubtree(childIdx);
        parent.InsertSubtree(childIdx, constNode);
      }
      return copy;
    }

    // Returns a copy of tree in which every constant becomes a variable θk and every variable x becomes (x * θk)
    // with weight 1. The θ names and initial values are returned in postfix order of the original nodes.
    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
      thetaNames = new List<string>();
      thetaValues = new List<double>();
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();

      int n = 1;
      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        if (node is ConstantTreeNode constantTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(constantTreeNode.Value);

          var parent = constantTreeNode.Parent;
          if (parent != null) {
            var index = parent.IndexOfSubtree(constantTreeNode);
            parent.RemoveSubtree(index);
            parent.InsertSubtree(index, thetaVar);
          }
        }
        if (node is VariableTreeNode varTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(varTreeNode.Weight);

          var parent = varTreeNode.Parent;
          if (parent != null) {
            var index = parent.IndexOfSubtree(varTreeNode);
            parent.RemoveSubtree(index);
            var prodNode = MakeNode<Multiplication>();
            varTreeNode.Weight = 1.0;
            prodNode.AddSubtree(varTreeNode);
            prodNode.AddSubtree(thetaVar);
            parent.InsertSubtree(index, prodNode);
          }
        }
      }
      return copy;
    }

    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      constantNode.Value = value;
      return constantNode;
    }

    // Replaces the body of t by (body - b) in place and returns t.
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
      var start = t.Root.GetSubtree(0);
      var body = start.GetSubtree(0);
      start.RemoveSubtree(0); // detach before re-attaching so body.Parent ends up pointing at the Subtraction node
      start.InsertSubtree(0, MakeNode<Subtraction>(body, b));
      return t;
    }

    // Replaces the body of t by (b - body) in place and returns t.
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
      var start = t.Root.GetSubtree(0);
      var body = start.GetSubtree(0);
      start.RemoveSubtree(0); // detach before re-attaching so body.Parent ends up pointing at the Subtraction node
      start.InsertSubtree(0, MakeNode<Subtraction>(b, body));
      return t;
    }

    // Creates a node for symbol T with the given subtrees.
    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
      var node = new T().CreateTreeNode();
      foreach (var f in fs) node.AddSubtree(f);
      return node;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.