
source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17240

Last change on this file: r17240, checked in by gkronber, 5 years ago

#2994: some minor improvements

File size: 28.2 KB
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.InteropServices;
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  internal class ConstrainedNLSInternal : IDisposable {
    private readonly int maxIterations;
    public int MaxIterations => maxIterations;

    private readonly string solver;
    public string Solver => solver;

    private readonly ISymbolicExpressionTree expr;
    public ISymbolicExpressionTree Expr => expr;

    private readonly IRegressionProblemData problemData;

    public IRegressionProblemData ProblemData => problemData;


    public event Action FunctionEvaluated;
    public event Action<int, double> ConstraintEvaluated;

    private double bestError = double.MaxValue;
    public double BestError => bestError;

    private double curError = double.MaxValue;
    public double CurError => curError;

    private double[] bestSolution;
    public double[] BestSolution => bestSolution;

    private ISymbolicExpressionTree bestTree;
    public ISymbolicExpressionTree BestTree => bestTree;

    private double[] bestConstraintValues;
    public double[] BestConstraintValues => bestConstraintValues;

    public bool CheckGradient { get; internal set; }

    // begin internal state
    private IntPtr nlopt;
    private SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
    private readonly NLOpt.nlopt_precond preconditionDelegate;
    private readonly IntPtr[] constraintDataPtr; // must hold the objects to prevent GC
    private readonly NLOpt.nlopt_func[] calculateConstraintDelegates; // must hold the delegates to prevent GC
    private readonly List<double> thetaValues;
    private readonly IDictionary<string, Interval> dataIntervals;
    private readonly int[] trainingRows;
    private readonly double[] target;
    private readonly ISymbolicExpressionTree preparedTree;
    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
    private readonly List<ConstantTreeNode>[] allThetaNodes;
    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)

    private readonly double[] fi_eval;
    private readonly double[,] jac_eval;
    private readonly ISymbolicExpressionTree scaledTree;
    private readonly VectorAutoDiffEvaluator autoDiffEval;
    private readonly VectorEvaluator eval;
    private readonly bool invalidProblem = false;

    // end internal state


    // for data exchange to/from optimizer in native code
    [StructLayout(LayoutKind.Sequential)]
    private struct ConstraintData {
      public int Idx;
      public ISymbolicExpressionTree Tree;
      public ISymbolicExpressionTreeNode[] ParameterNodes;
    }

    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
      this.solver = solver;
      this.expr = expr;
      this.maxIterations = maxIterations;
      this.problemData = problemData;
      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      this.autoDiffEval = new VectorAutoDiffEvaluator();
      this.eval = new VectorEvaluator();

      CheckGradient = false;

      var intervalConstraints = problemData.IntervalConstraints;
      dataIntervals = problemData.VariableRanges.GetIntervals();
      trainingRows = problemData.TrainingIndices.ToArray();
      // buffers
      target = problemData.TargetVariableTrainingValues.ToArray();
      var targetStDev = target.StandardDeviationPop();
      var targetVariance = targetStDev * targetStDev;
      var targetMean = target.Average();
      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();

      bestError = targetVariance;

      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) {
        invalidProblem = true;
      }

      #region linear scaling
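      // Scale and shift the model output so that its mean and standard deviation match the target:
      // with s = targetStDev / predStDev and o = targetMean - predMean * s the scaled model is s * f(x) + o.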
      var predStDev = pred.StandardDeviationPop();
      if (predStDev == 0) {
        invalidProblem = true;
      }
      var predMean = pred.Average();

      var scalingFactor = targetStDev / predStDev;
      var offset = targetMean - predMean * scalingFactor;

      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
      #endregion

      // convert constants to variables named theta...
      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree

      // create trees for relevant derivatives
      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
      constraintTrees = new List<ISymbolicExpressionTree>();
      foreach (var constraint in intervalConstraints.Constraints) {
        if (!constraint.Enabled) continue;
        if (constraint.IsDerivation) {
          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);

          // NLOpt requires constraint expressions of the form c(x) <= 0
          // -> we make two expressions, one for the lower bound and one for the upper bound
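          // e.g. for an interval constraint df/dx in [l, u] this yields
          //   df/dx - u <= 0   and   l - df/dx <= 0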

          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
        } else {
          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
        }
      }

      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);

      var dim = thetaValues.Count;
      fi_eval = new double[target.Length]; // init buffer
      jac_eval = new double[target.Length, dim]; // init buffer


      var minVal = Math.Min(-1000.0, thetaValues.Min());
      var maxVal = Math.Max(1000.0, thetaValues.Max());
      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);

      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
      NLOpt.nlopt_set_upper_bounds(nlopt, up);
      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning

      //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
      //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);


      constraintDataPtr = new IntPtr[constraintTrees.Count];
      calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects, see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
      for (int i = 0; i < constraintTrees.Count; i++) {
        var constraintData = new ConstraintData() { Idx = i, Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
        // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
      }

      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
    }


    ~ConstrainedNLSInternal() {
      Dispose();
    }


    internal void Optimize() {
      if (invalidProblem) return;
      var x = thetaValues.ToArray();  /* initial guess */
      double minf = double.MaxValue; /* minimum objective value upon return */
      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);

      if (res < 0 && res != NLOpt.nlopt_result.NLOPT_FORCED_STOP) {
        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
        return;
      } else if (minf <= bestError) {
        bestSolution = x;
        bestError = minf;

        // calculate constraints of final solution
        double[] _ = new double[x.Length];
        bestConstraintValues = new double[calculateConstraintDelegates.Length];
        for (int i = 0; i < calculateConstraintDelegates.Length; i++) {
          bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]);
        }

        // update parameters in tree
        var pIdx = 0;
        // here we lose the last two parameters (for linear scaling)
        foreach (var node in scaledTree.IterateNodesPostfix()) {
          if (node is ConstantTreeNode constTreeNode) {
            constTreeNode.Value = x[pIdx++];
          } else if (node is VariableTreeNode varTreeNode) {
            varTreeNode.Weight = x[pIdx++];
          }
        }
        if (pIdx != x.Length) throw new InvalidProgramException();
      }
      bestTree = scaledTree;
    }

    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
      UpdateThetaValues(curX);
      var sse = 0.0;

      if (grad != null) {
        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
          preparedTreeParameterNodes, fi_eval, jac_eval);

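        // The objective is the mean squared error
        //   MSE(theta) = 1/N * sum_i (y_i - f(x_i, theta))^2
        // with gradient
        //   dMSE/dtheta_j = -2/N * sum_i (y_i - f(x_i, theta)) * df(x_i, theta)/dtheta_j
        // where jac_eval[i, j] holds df(x_i, theta)/dtheta_j from the autodiff evaluator.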
        // calc sum of squared errors and gradient
        for (int j = 0; j < grad.Length; j++) grad[j] = 0;
        for (int i = 0; i < target.Length; i++) {
          var r = target[i] - fi_eval[i];
          sse += r * r;
          for (int j = 0; j < grad.Length; j++) {
            grad[j] -= 2 * r * jac_eval[i, j];
          }
        }
        // average
        for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }

        #region check gradient
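        // Sanity check: compare the analytic gradient with a central finite difference
        //   (MSE(theta + delta * e_i) - MSE(theta - delta * e_i)) / (2 * delta)
        // and fail fast if the relative deviation exceeds roughly 1e-2.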
        if (grad != null && CheckGradient) {
          for (int i = 0; i < dim; i++) {
            // make two additional evaluations
            var xForNumDiff = (double[])curX.Clone();
            double delta = Math.Abs(xForNumDiff[i] * 1e-5);
            xForNumDiff[i] += delta;
            UpdateThetaValues(xForNumDiff);
            var evalHigh = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
            var mseHigh = MSE(target, evalHigh);
            xForNumDiff[i] = curX[i] - delta;
            UpdateThetaValues(xForNumDiff);
            var evalLow = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
            var mseLow = MSE(target, evalLow);

            var numericDiff = (mseHigh - mseLow) / (2 * delta);
            var autoDiff = grad[i];
            if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
              || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
              throw new InvalidProgramException();
          }
        }
        #endregion
      } else {
        var eval = new VectorEvaluator();
        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);

        // calc sum of squared errors
        sse = 0.0;
        for (int i = 0; i < target.Length; i++) {
          var r = target[i] - prediction[i];
          sse += r * r;
        }
      }

      // UpdateBestSolution(sse / target.Length, curX);
      RaiseFunctionEvaluated();

      if (double.IsNaN(sse)) {
        if (grad != null) Array.Clear(grad, 0, grad.Length);
        return double.MaxValue;
      }
      return sse / target.Length;
    }

    // TODO
    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
    //   UpdateThetaValues(x); // calc H(x)
    //
    //   autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
    //     preparedTreeParameterNodes, fi_eval, jac_eval);
    //   var k = jac_eval.GetLength(0);
    //   var h = new double[n, n];
    //
    //   // calc residuals and scale jac_eval
    //   var f = 2.0 / (k*k);
    //
    //   // approximate hessian H(x) = J(x)^T * J(x)
    //   alglib.rmatrixgemm((int)n, (int)n, k,
    //     f, jac_eval, 0, 0, 1,  // transposed
    //     jac_eval, 0, 0, 0,
    //     0.0, ref h, 0, 0,
    //     null
    //     );
    //
    //
    //   // scale v
    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
    //
    //
    //   alglib.spdmatrixcholesky(ref h, (int)n, true);
    //
    //   var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial);
    // }


    private double MSE(double[] a, double[] b) {
      Trace.Assert(a.Length == b.Length);
      var sse = 0.0;
      for (int i = 0; i < a.Length; i++) sse += (a[i] - b[i]) * (a[i] - b[i]);
      return sse / a.Length;
    }

    private void UpdateBestSolution(double curF, double[] curX) {
      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
      else if (curF < bestError) {
        bestError = curF;
        bestSolution = (double[])curX.Clone();
      }
      curError = curF;
    }

    private void UpdateConstraintViolations(int constraintIdx, double value) {
      if (double.IsNaN(value) || double.IsInfinity(value)) return;
      RaiseConstraintEvaluated(constraintIdx, value);
      // else if (curF < bestError) {
      //   bestError = curF;
      //   bestSolution = (double[])curX.Clone();
      // }
    }

    double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
      UpdateThetaValues(curX);
      var intervalEvaluator = new IntervalEvaluator();
      var constraintData = Marshal.PtrToStructure<ConstraintData>(data);

      if (grad != null) Array.Clear(grad, 0, grad.Length);

      var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
        out double[] lowerGradient, out double[] upperGradient);

      // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
      if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }

      #region check gradient
      if (grad != null && CheckGradient)
        for (int i = 0; i < dim; i++) {
          // make two additional evaluations
          var xForNumDiff = (double[])curX.Clone();
          double delta = Math.Abs(xForNumDiff[i] * 1e-5);
          xForNumDiff[i] += delta;
          UpdateThetaValues(xForNumDiff);
          var evalHigh = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
            out double[] unusedLowerGradientHigh, out double[] unusedUpperGradientHigh);

          xForNumDiff[i] = curX[i] - delta;
          UpdateThetaValues(xForNumDiff);
          var evalLow = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
            out double[] unusedLowerGradientLow, out double[] unusedUpperGradientLow);

          var numericDiff = (evalHigh.UpperBound - evalLow.UpperBound) / (2 * delta);
          var autoDiff = grad[i];

          if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
            || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
            throw new InvalidProgramException();
        }
      #endregion


      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
      if (double.IsNaN(interval.UpperBound)) {
        if (grad != null) Array.Clear(grad, 0, grad.Length); // grad is null for derivative-free solvers
        return double.MaxValue;
      } else return interval.UpperBound;
    }


    void UpdateThetaValues(double[] theta) {
      for (int i = 0; i < theta.Length; ++i) {
        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
      }
    }

    internal void RequestStop() {
      NLOpt.nlopt_set_force_stop(nlopt, 1); // hopefully NLOpt is thread-safe; the value must be != 0, otherwise the call has no effect
    }

    private void RaiseFunctionEvaluated() {
      FunctionEvaluated?.Invoke();
    }

    private void RaiseConstraintEvaluated(int idx, double value) {
      ConstraintEvaluated?.Invoke(idx, value);
    }


    #region helper

    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
      var m = (ISymbolicExpressionTree)tree.Clone();

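      // Root.GetSubtree(0).GetSubtree(0) is the model expression below the program-root and start nodes;
      // wrap it in (expression * scalingFactor) + offset.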
      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
      m.Root.GetSubtree(0).RemoveSubtree(0);
      m.Root.GetSubtree(0).AddSubtree(add);
      return m;
    }

    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
      if (nodes.Length != constants.Length) throw new InvalidOperationException();
      for (int i = 0; i < nodes.Length; i++) {
        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
      }
    }

    private NLOpt.nlopt_algorithm GetSolver(string solver) {
      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;

      // longer algorithm names must be checked before names that are a prefix of them, otherwise the prefix wins
      if (solver.Contains("DIRECT_G")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
      if (solver.Contains("NLOPT_GN_DIRECT_L_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND;
      if (solver.Contains("NLOPT_GN_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L;
      if (solver.Contains("NLOPT_GN_ORIG_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L;
      if (solver.Contains("NLOPT_GN_ORIG_DIRECT")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
      if (solver.Contains("NLOPT_GD_STOGO_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND;
      if (solver.Contains("NLOPT_GD_STOGO")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO;
      if (solver.Contains("NLOPT_LD_LBFGS_NOCEDAL")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL;
      if (solver.Contains("NLOPT_LD_LBFGS")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS;
      if (solver.Contains("NLOPT_LN_PRAXIS")) return NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS;
      if (solver.Contains("NLOPT_LD_VAR1")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR1;
      if (solver.Contains("NLOPT_LD_VAR2")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR2;
      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART;
      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND;
      if (solver.Contains("NLOPT_LD_TNEWTON_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART;
      if (solver.Contains("NLOPT_LD_TNEWTON")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON;
      if (solver.Contains("NLOPT_GN_CRS2_LM")) return NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM;
      if (solver.Contains("NLOPT_GN_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS;
      if (solver.Contains("NLOPT_GD_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS;
      if (solver.Contains("NLOPT_GN_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL;
      if (solver.Contains("NLOPT_GD_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL;
      if (solver.Contains("NLOPT_LN_NEWUOA_BOUND")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND;
      if (solver.Contains("NLOPT_LN_NEWUOA")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA;
      if (solver.Contains("NLOPT_LN_NELDERMEAD")) return NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD;
      if (solver.Contains("NLOPT_LN_SBPLX")) return NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX;
      if (solver.Contains("NLOPT_LN_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG;
      if (solver.Contains("NLOPT_LD_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG;
      if (solver.Contains("NLOPT_LN_BOBYQA")) return NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA;
      if (solver.Contains("NLOPT_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_AUGLAG;
      if (solver.Contains("NLOPT_LD_SLSQP")) return NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP;
      if (solver.Contains("NLOPT_LD_CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
      if (solver.Contains("NLOPT_GN_ESCH")) return NLOpt.nlopt_algorithm.NLOPT_GN_ESCH;
      if (solver.Contains("NLOPT_GN_AGS")) return NLOpt.nlopt_algorithm.NLOPT_GN_AGS;

      throw new ArgumentException($"Unknown solver {solver}");
    }

    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
      // TODO better solution necessary
      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
      for (int i = 0; i < paramNodes.Length; i++) {
        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
      }
      return paramNodes;
    }

    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();
      for (int i = 0; i < nodes.Count; i++) {
        var n = nodes[i] as VariableTreeNode;
        if (n != null) {
          var thetaIdx = thetaNames.IndexOf(n.VariableName);
          if (thetaIdx >= 0) {
            var parent = n.Parent;
            if (thetaNodes[thetaIdx].Any()) {
              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
              // we use this trick to allow autodiff over thetas when thetas occur multiple times in the tree (e.g. in derived trees)
              var constNode = thetaNodes[thetaIdx].First();
              var childIdx = parent.IndexOfSubtree(n);
              parent.RemoveSubtree(childIdx);
              parent.InsertSubtree(childIdx, constNode);
            } else {
              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
              var childIdx = parent.IndexOfSubtree(n);
              parent.RemoveSubtree(childIdx);
              parent.InsertSubtree(childIdx, constNode);
              thetaNodes[thetaIdx].Add(constNode);
            }
          }
        }
      }
      return copy;
    }

    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
      thetaNames = new List<string>();
      thetaValues = new List<double>();
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();

      int n = 1;
      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        if (node is ConstantTreeNode constantTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(constantTreeNode.Value);

          var parent = constantTreeNode.Parent;
          if (parent != null) {
            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
            parent.RemoveSubtree(index);
            parent.InsertSubtree(index, thetaVar);
          }
        }
        if (node is VariableTreeNode varTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(varTreeNode.Weight);

          var parent = varTreeNode.Parent;
          if (parent != null) {
            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
            parent.RemoveSubtree(index);
            var prodNode = MakeNode<Multiplication>();
            varTreeNode.Weight = 1.0;
            prodNode.AddSubtree(varTreeNode);
            prodNode.AddSubtree(thetaVar);
            parent.InsertSubtree(index, prodNode);
          }
        }
      }
      return copy;
    }

    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      constantNode.Value = value;
      return constantNode;
    }

    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }

    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
      var node = new T().CreateTreeNode();
      foreach (var f in fs) node.AddSubtree(f);
      return node;
    }

    public void Dispose() {
      if (nlopt != IntPtr.Zero) {
        NLOpt.nlopt_destroy(nlopt);
        nlopt = IntPtr.Zero;
      }
      if (constraintDataPtr != null) {
        for (int i = 0; i < constraintDataPtr.Length; i++)
          if (constraintDataPtr[i] != IntPtr.Zero) {
            Marshal.FreeHGlobal(constraintDataPtr[i]);
            constraintDataPtr[i] = IntPtr.Zero;
          }
      }
    }
    #endregion
  }
}
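
For orientation, the following is a minimal, hypothetical usage sketch based only on the members visible above (the constructor, Optimize, the Best* properties, the two events, and Dispose); the callers in this branch may wire the class up differently. The solver name is matched by GetSolver via Contains, so a string such as "MMA" selects NLOPT_LD_MMA. The variables tree and problemData are assumed to be supplied by the surrounding algorithm.

// Hypothetical caller (the class is internal, so this code would live in the same assembly).
using (var nls = new ConstrainedNLSInternal("MMA", tree, 100, problemData)) { // 100 = max. objective evaluations (nlopt_set_maxeval)
  nls.FunctionEvaluated += () => { /* e.g. count objective evaluations or poll for cancellation */ };
  nls.ConstraintEvaluated += (idx, value) => { /* e.g. record the current violation of constraint idx */ };

  nls.Optimize();

  var mse = nls.BestError;                          // objective value (average squared error) of the best solution
  var constraintValues = nls.BestConstraintValues;  // values <= 0 mean the corresponding constraint is satisfied
  var optimizedTree = nls.BestTree;                 // scaled tree with the optimized constants/weights (may be null if the run was aborted)
}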