source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17325

Last change on this file since 17325 was 17325, checked in by gkronber, 12 months ago

#2994: worked on ConstrainedNLS

File size: 31.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.InteropServices;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions;
9
10namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
11  internal class ConstrainedNLSInternal : IDisposable {
12    private readonly int maxIterations;
13    public int MaxIterations => maxIterations;
14
15    private readonly string solver;
16    public string Solver => solver;
17
18    private readonly ISymbolicExpressionTree expr;
19    public ISymbolicExpressionTree Expr => expr;
20
21    private readonly IRegressionProblemData problemData;
22
23    public IRegressionProblemData ProblemData => problemData;
24
25
26    public event Action FunctionEvaluated;
27    public event Action<int, double> ConstraintEvaluated;
28
29    private double bestError = double.MaxValue;
30    public double BestError => bestError;
31
32    private double curError = double.MaxValue;
33    public double CurError => curError;
34
35    private double[] bestSolution;
36    public double[] BestSolution => bestSolution;
37
38    private ISymbolicExpressionTree bestTree;
39    public ISymbolicExpressionTree BestTree => bestTree;
40
41    private double[] bestConstraintValues;
42    public double[] BestConstraintValues => bestConstraintValues;
43
44    private bool disposed = false;
45
46
47    // for debugging (must be in the same order as processed below)
48    public IEnumerable<string> ConstraintDescriptions {
49      get {
50        foreach (var elem in problemData.IntervalConstraints.Constraints) {
51          if (!elem.Enabled) continue;
52          if (elem.Interval.UpperBound < double.PositiveInfinity) {
53            yield return elem.Expression + " < " + elem.Interval.UpperBound;
54          }
55          if (elem.Interval.LowerBound > double.NegativeInfinity) {
56            yield return "-" + elem.Expression + " < " + (-1) * elem.Interval.LowerBound;
57          }
58        }
59      }
60    }
61
62    public bool CheckGradient { get; internal set; }
63
64    // begin internal state
65    private IntPtr nlopt;
66    private SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
67    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
68    // private readonly NLOpt.nlopt_precond preconditionDelegate;
69    private readonly IntPtr[] constraintDataPtr; // must hold the objects to prevent GC
70    private readonly NLOpt.nlopt_func[] calculateConstraintDelegates; // must hold the delegates to prevent GC
71    private readonly List<double> thetaValues;
72    private readonly IDictionary<string, Interval> dataIntervals;
73    private readonly int[] trainingRows;
74    private readonly double[] target;
75    private readonly ISymbolicExpressionTree preparedTree;
76    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
77    private readonly List<ConstantTreeNode>[] allThetaNodes;
78    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)
79
80    private readonly double[] fi_eval;
81    private readonly double[,] jac_eval;
82    private readonly ISymbolicExpressionTree scaledTree;
83    private readonly VectorAutoDiffEvaluator autoDiffEval;
84    private readonly VectorEvaluator eval;
85    private readonly bool invalidProblem = false;
86
87    // end internal state
88
89
90    // for data exchange to/from optimizer in native code
91    [StructLayout(LayoutKind.Sequential)]
92    private struct ConstraintData {
93      public int Idx;
94      public ISymbolicExpressionTree Tree;
95      public ISymbolicExpressionTreeNode[] ParameterNodes;
96    }
97
98    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
99      this.solver = solver;
100      this.expr = expr;
101      this.maxIterations = maxIterations;
102      this.problemData = problemData;
103      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
104      this.autoDiffEval = new VectorAutoDiffEvaluator();
105      this.eval = new VectorEvaluator();
106
107      CheckGradient = false;
108
109      var intervalConstraints = problemData.IntervalConstraints;
110      dataIntervals = problemData.VariableRanges.GetIntervals();
111      trainingRows = problemData.TrainingIndices.ToArray();
112      // buffers
113      target = problemData.TargetVariableTrainingValues.ToArray();
114      var targetStDev = target.StandardDeviationPop();
115      var targetVariance = targetStDev * targetStDev;
116      var targetMean = target.Average();
117      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();
118
119      bestError = targetVariance;
120
121      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) {
122        invalidProblem = true;
123      }
124
125      // all trees are linearly scaled (to improve GP performance)
126      #region linear scaling
127      var predStDev = pred.StandardDeviationPop();
128      if (predStDev == 0) {
129        invalidProblem = true;
130      }
131      var predMean = pred.Average();
132
133      var scalingFactor = targetStDev / predStDev;
134      var offset = targetMean - predMean * scalingFactor;
135
136      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
137      #endregion
138
139      // convert constants to variables named theta...
140      var treeForDerivation = ReplaceAndExtractParameters(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
141
142      // create trees for relevant derivatives
143      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
144      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
145      constraintTrees = new List<ISymbolicExpressionTree>();
146      foreach (var constraint in intervalConstraints.Constraints) {
147        if (!constraint.Enabled) continue;
148        if (constraint.IsDerivation) {
149          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
150            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
151          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
152
153          // NLOpt requires constraint expressions of the form c(x) <= 0
154          // -> we make two expressions, one for the lower bound and one for the upper bound
155
156          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
157            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
158            // convert variables named theta back to constants
159            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
160            constraintTrees.Add(df_prepared);
161          }
162          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
163            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
164            // convert variables named theta back to constants
165            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
166            constraintTrees.Add(df_prepared);
167          }
168        } else {
169          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
170            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
171            // convert variables named theta back to constants
172            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
173            constraintTrees.Add(df_prepared);
174          }
175          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
176            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
177            // convert variables named theta back to constants
178            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
179            constraintTrees.Add(df_prepared);
180          }
181        }
182      }
183
184      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
185      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);
186
187      var dim = thetaValues.Count;
188      fi_eval = new double[target.Length]; // init buffer;
189      jac_eval = new double[target.Length, dim]; // init buffer
190
191
192      var minVal = Math.Min(-1000.0, thetaValues.Min());
193      var maxVal = Math.Max(1000.0, thetaValues.Max());
194      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
195      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
196      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);
197
198      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
199      NLOpt.nlopt_set_upper_bounds(nlopt, up);
200      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
201      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning
202
203      //preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
204      //NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);
205
206
207      constraintDataPtr = new IntPtr[constraintTrees.Count];
208      calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
209      for (int i = 0; i < constraintTrees.Count; i++) {
210        var constraintData = new ConstraintData() { Idx = i, Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
211        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
212        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
213        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
214        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
215        // NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
216      }
217
218      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
219      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
220      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
221      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
222    }
223
224
225    ~ConstrainedNLSInternal() {
226      Dispose(false);
227    }
228
229
230    public enum OptimizationMode { ReadOnly, UpdateParameters, UpdateParametersAndKeepLinearScaling };
231
232    internal void Optimize(OptimizationMode mode) {
233      if (invalidProblem) return;
234      var x = thetaValues.ToArray();  /* initial guess */
235      double minf = double.MaxValue; /* minimum objective value upon return */
236      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);
237
238      if (res < 0 && res != NLOpt.nlopt_result.NLOPT_FORCED_STOP) {
239        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
240        return;
241      } else /*if ( minf <= bestError ) */{
242        bestSolution = x;
243        bestError = minf;
244
245        // calculate constraints of final solution
246        double[] _ = new double[x.Length];
247        bestConstraintValues = new double[calculateConstraintDelegates.Length];
248        for (int i = 0; i < calculateConstraintDelegates.Length; i++) {
249          bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]);
250        }
251
252        // update parameters in tree
253        UpdateParametersInTree(scaledTree, x);
254
255        if (mode == OptimizationMode.UpdateParameters) {
256          // update original expression (when called from evaluator we want to write back optimized parameters)
257          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
258          expr.Root.GetSubtree(0).InsertSubtree(0,
259            scaledTree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).GetSubtree(0) // insert the optimized sub-tree (without scaling nodes)
260            );
261        } else if (mode == OptimizationMode.UpdateParametersAndKeepLinearScaling) {
262          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
263          expr.Root.GetSubtree(0).InsertSubtree(0, scaledTree.Root.GetSubtree(0).GetSubtree(0)); // insert the optimized sub-tree (including scaling nodes)
264        }
265      }
266      bestTree = expr;
267    }
268
269
270    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
271      UpdateThetaValues(curX);
272      var sse = 0.0;
273
274      if (grad != null) {
275        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
276          preparedTreeParameterNodes, fi_eval, jac_eval);
277
278        // calc sum of squared errors and gradient
279        for (int j = 0; j < grad.Length; j++) grad[j] = 0;
280        for (int i = 0; i < target.Length; i++) {
281          var r = target[i] - fi_eval[i];
282          sse += r * r;
283          for (int j = 0; j < grad.Length; j++) {
284            grad[j] -= 2 * r * jac_eval[i, j];
285          }
286        }
287        // average
288        for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }
289
290        #region check gradient
291        if (grad != null && CheckGradient) {
292          for (int i = 0; i < dim; i++) {
293            // make two additional evaluations
294            var xForNumDiff = (double[])curX.Clone();
295            double delta = Math.Abs(xForNumDiff[i] * 1e-5);
296            xForNumDiff[i] += delta;
297            UpdateThetaValues(xForNumDiff);
298            var evalHigh = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
299            var mseHigh = MSE(target, evalHigh);
300            xForNumDiff[i] = curX[i] - delta;
301            UpdateThetaValues(xForNumDiff);
302            var evalLow = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
303            var mseLow = MSE(target, evalLow);
304
305            var numericDiff = (mseHigh - mseLow) / (2 * delta);
306            var autoDiff = grad[i];
307            if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
308              || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
309              throw new InvalidProgramException();
310          }
311        }
312        #endregion
313      } else {
314        var eval = new VectorEvaluator();
315        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
316
317        // calc sum of squared errors
318        sse = 0.0;
319        for (int i = 0; i < target.Length; i++) {
320          var r = target[i] - prediction[i];
321          sse += r * r;
322        }
323      }
324
325      UpdateBestSolution(sse / target.Length, curX);
326      RaiseFunctionEvaluated();
327
328      if (double.IsNaN(sse)) {
329        if (grad != null) Array.Clear(grad, 0, grad.Length);
330        return double.MaxValue;
331      }
332      return sse / target.Length;
333    }
334
335    // TODO
336    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
337    //   UpdateThetaValues(x); // calc H(x)
338    //   
339    //   autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
340    //     preparedTreeParameterNodes, fi_eval, jac_eval);
341    //   var k = jac_eval.GetLength(0);
342    //   var h = new double[n, n];
343    //   
344    //   // calc residuals and scale jac_eval
345    //   var f = 2.0 / (k*k);
346    //
347    //   // approximate hessian H(x) = J(x)^T * J(x)
348    //   alglib.rmatrixgemm((int)n, (int)n, k,
349    //     f, jac_eval, 0, 0, 1,  // transposed
350    //     jac_eval, 0, 0, 0,
351    //     0.0, ref h, 0, 0,
352    //     null
353    //     );
354    //   
355    //
356    //   // scale v
357    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
358    //
359    //
360    //   alglib.spdmatrixcholesky(ref h, (int)n, true);
361    //
362    //   var det = alglib.matdet.spdmatrixcholeskydet(h, (int)n, alglib.serial);
363    // }
364
365
366    private double MSE(double[] a, double[] b) {
367      Trace.Assert(a.Length == b.Length);
368      var sse = 0.0;
369      for (int i = 0; i < a.Length; i++) sse += (a[i] - b[i]) * (a[i] - b[i]);
370      return sse / a.Length;
371    }
372
373    private void UpdateBestSolution(double curF, double[] curX) {
374      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
375      else if (curF < bestError) {
376        bestError = curF;
377        bestSolution = (double[])curX.Clone();
378      }
379      curError = curF;
380    }
381
382    private void UpdateConstraintViolations(int constraintIdx, double value) {
383      if (double.IsNaN(value) || double.IsInfinity(value)) return;
384      RaiseConstraintEvaluated(constraintIdx, value);
385      // else if (curF < bestError) {
386      //   bestError = curF;
387      //   bestSolution = (double[])curX.Clone();
388      // }
389    }
390
391    double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
392      UpdateThetaValues(curX);
393      var intervalEvaluator = new IntervalEvaluator();
394      var refIntervalEvaluator = new IntervalInterpreter();
395
396      var constraintData = Marshal.PtrToStructure<ConstraintData>(data);
397
398      if (grad != null) Array.Clear(grad, 0, grad.Length);
399
400      var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
401        out double[] lowerGradient, out double[] upperGradient);
402
403      var refInterval = refIntervalEvaluator.GetSymbolicExpressionTreeInterval(constraintData.Tree, dataIntervals);
404      if (Math.Abs(interval.LowerBound - refInterval.LowerBound) > Math.Abs(interval.LowerBound) * 1e-4) throw new InvalidProgramException($"Intervals don't match. {interval.LowerBound} <> {refInterval.LowerBound}");
405      if (Math.Abs(interval.UpperBound - refInterval.UpperBound) > Math.Abs(interval.UpperBound) * 1e-4) throw new InvalidProgramException($"Intervals don't match. {interval.UpperBound} <> {refInterval.UpperBound}");
406
407      // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
408      if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }
409
410      #region check gradient
411      if (grad != null && CheckGradient)
412        for (int i = 0; i < dim; i++) {
413          // make two additional evaluations
414          var xForNumDiff = (double[])curX.Clone();
415          double delta = Math.Abs(xForNumDiff[i] * 1e-5);
416          xForNumDiff[i] += delta;
417          UpdateThetaValues(xForNumDiff);
418          var evalHigh = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
419            out double[] unusedLowerGradientHigh, out double[] unusedUpperGradientHigh);
420
421          xForNumDiff[i] = curX[i] - delta;
422          UpdateThetaValues(xForNumDiff);
423          var evalLow = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
424            out double[] unusedLowerGradientLow, out double[] unusedUpperGradientLow);
425
426          var numericDiff = (evalHigh.UpperBound - evalLow.UpperBound) / (2 * delta);
427          var autoDiff = grad[i];
428
429          if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
430            || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
431            throw new InvalidProgramException();
432        }
433      #endregion
434
435
436      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
437      if (double.IsNaN(interval.UpperBound)) {
438        if (grad != null) Array.Clear(grad, 0, grad.Length);
439        return double.MaxValue;
440      } else return interval.UpperBound;
441    }
442
443
444    void UpdateThetaValues(double[] theta) {
445      for (int i = 0; i < theta.Length; ++i) {
446        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
447      }
448    }
449
450    internal void RequestStop() {
451      NLOpt.nlopt_set_force_stop(nlopt, 1); // hopefully NLOpt is thread safe  , val must be <> 0 otherwise no effect
452    }
453
454    private void RaiseFunctionEvaluated() {
455      FunctionEvaluated?.Invoke();
456    }
457
458    private void RaiseConstraintEvaluated(int idx, double value) {
459      ConstraintEvaluated?.Invoke(idx, value);
460    }
461
462
463    #region helper
464
465    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
466      var m = (ISymbolicExpressionTree)tree.Clone();
467
468      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
469      m.Root.GetSubtree(0).RemoveSubtree(0);
470      m.Root.GetSubtree(0).AddSubtree(add);
471      return m;
472    }
473
474
475    private NLOpt.nlopt_algorithm GetSolver(string solver) {
476      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
477      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
478      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
479      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
480
481      if (solver.Contains("DIRECT_G")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
482      if (solver.Contains("NLOPT_GN_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L;
483      if (solver.Contains("NLOPT_GN_DIRECT_L_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT_L_RAND;
484      if (solver.Contains("NLOPT_GN_ORIG_DIRECT")) return NLOpt.nlopt_algorithm.NLOPT_GN_DIRECT;
485      if (solver.Contains("NLOPT_GN_ORIG_DIRECT_L")) return NLOpt.nlopt_algorithm.NLOPT_GN_ORIG_DIRECT_L;
486      if (solver.Contains("NLOPT_GD_STOGO")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO;
487      if (solver.Contains("NLOPT_GD_STOGO_RAND")) return NLOpt.nlopt_algorithm.NLOPT_GD_STOGO_RAND;
488      if (solver.Contains("NLOPT_LD_LBFGS_NOCEDAL")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS_NOCEDAL;
489      if (solver.Contains("NLOPT_LD_LBFGS")) return NLOpt.nlopt_algorithm.NLOPT_LD_LBFGS;
490      if (solver.Contains("NLOPT_LN_PRAXIS")) return NLOpt.nlopt_algorithm.NLOPT_LN_PRAXIS;
491      if (solver.Contains("NLOPT_LD_VAR1")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR1;
492      if (solver.Contains("NLOPT_LD_VAR2")) return NLOpt.nlopt_algorithm.NLOPT_LD_VAR2;
493      if (solver.Contains("NLOPT_LD_TNEWTON")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON;
494      if (solver.Contains("NLOPT_LD_TNEWTON_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_RESTART;
495      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND;
496      if (solver.Contains("NLOPT_LD_TNEWTON_PRECOND_RESTART")) return NLOpt.nlopt_algorithm.NLOPT_LD_TNEWTON_PRECOND_RESTART;
497      if (solver.Contains("NLOPT_GN_CRS2_LM")) return NLOpt.nlopt_algorithm.NLOPT_GN_CRS2_LM;
498      if (solver.Contains("NLOPT_GN_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL;
499      if (solver.Contains("NLOPT_GD_MLSL")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL;
500      if (solver.Contains("NLOPT_GN_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GN_MLSL_LDS;
501      if (solver.Contains("NLOPT_GD_MLSL_LDS")) return NLOpt.nlopt_algorithm.NLOPT_GD_MLSL_LDS;
502      if (solver.Contains("NLOPT_LN_NEWUOA")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA;
503      if (solver.Contains("NLOPT_LN_NEWUOA_BOUND")) return NLOpt.nlopt_algorithm.NLOPT_LN_NEWUOA_BOUND;
504      if (solver.Contains("NLOPT_LN_NELDERMEAD")) return NLOpt.nlopt_algorithm.NLOPT_LN_NELDERMEAD;
505      if (solver.Contains("NLOPT_LN_SBPLX")) return NLOpt.nlopt_algorithm.NLOPT_LN_SBPLX;
506      if (solver.Contains("NLOPT_LN_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LN_AUGLAG;
507      if (solver.Contains("NLOPT_LD_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_LD_AUGLAG;
508      if (solver.Contains("NLOPT_LN_BOBYQA")) return NLOpt.nlopt_algorithm.NLOPT_LN_BOBYQA;
509      if (solver.Contains("NLOPT_AUGLAG")) return NLOpt.nlopt_algorithm.NLOPT_AUGLAG;
510      if (solver.Contains("NLOPT_LD_SLSQP")) return NLOpt.nlopt_algorithm.NLOPT_LD_SLSQP;
511      if (solver.Contains("NLOPT_LD_CCSAQ))")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
512      if (solver.Contains("NLOPT_GN_ESCH")) return NLOpt.nlopt_algorithm.NLOPT_GN_ESCH;
513      if (solver.Contains("NLOPT_GN_AGS")) return NLOpt.nlopt_algorithm.NLOPT_GN_AGS;
514
515      throw new ArgumentException($"Unknown solver {solver}");
516    }
517
518    // determines the nodes over which we can calculate the partial derivative
519    // this is different from the vector of all parameters because not every tree contains all parameters
520    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
521      // TODO better solution necessary
522      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
523      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
524      for (int i = 0; i < paramNodes.Length; i++) {
525        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
526      }
527      return paramNodes;
528    }
529
530    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
531      var copy = (ISymbolicExpressionTree)tree.Clone();
532      var nodes = copy.IterateNodesPostfix().ToList();
533      for (int i = 0; i < nodes.Count; i++) {
534        var n = nodes[i] as VariableTreeNode;
535        if (n != null) {
536          var thetaIdx = thetaNames.IndexOf(n.VariableName);
537          if (thetaIdx >= 0) {
538            var parent = n.Parent;
539            if (thetaNodes[thetaIdx].Any()) {
540              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
541              // we use this trick to allow autodiff over thetas when thetas occurr multiple times in the tree (e.g. in derived trees)
542              var constNode = thetaNodes[thetaIdx].First();
543              var childIdx = parent.IndexOfSubtree(n);
544              parent.RemoveSubtree(childIdx);
545              parent.InsertSubtree(childIdx, constNode);
546            } else {
547              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
548              var childIdx = parent.IndexOfSubtree(n);
549              parent.RemoveSubtree(childIdx);
550              parent.InsertSubtree(childIdx, constNode);
551              thetaNodes[thetaIdx].Add(constNode);
552            }
553          }
554        }
555      }
556      return copy;
557    }
558
559
560
561
562    private void UpdateParametersInTree(ISymbolicExpressionTree scaledTree, double[] x) {
563      var pIdx = 0;
564      // here we lose the two last parameters (for linear scaling)
565      foreach (var node in scaledTree.IterateNodesPostfix()) {
566        if (node is ConstantTreeNode constTreeNode) {
567          constTreeNode.Value = x[pIdx++];
568        } else if (node is VariableTreeNode varTreeNode) {
569          if (varTreeNode.Weight != 1.0) // see ReplaceAndExtractParameters
570            varTreeNode.Weight = x[pIdx++];
571        }
572      }
573      if (pIdx != x.Length) throw new InvalidProgramException();
574    }
575
576    private static ISymbolicExpressionTree ReplaceAndExtractParameters(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
577      thetaNames = new List<string>();
578      thetaValues = new List<double>();
579      var copy = (ISymbolicExpressionTree)tree.Clone();
580      var nodes = copy.IterateNodesPostfix().ToList();
581
582      int n = 1;
583      for (int i = 0; i < nodes.Count; ++i) {
584        var node = nodes[i];
585        if (node is ConstantTreeNode constantTreeNode) {
586          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
587          thetaVar.Weight = 1;
588          thetaVar.VariableName = $"θ{n++}";
589
590          thetaNames.Add(thetaVar.VariableName);
591          thetaValues.Add(constantTreeNode.Value);
592
593          var parent = constantTreeNode.Parent;
594          if (parent != null) {
595            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
596            parent.RemoveSubtree(index);
597            parent.InsertSubtree(index, thetaVar);
598          }
599        }
600        if (node is VariableTreeNode varTreeNode) {
601          if (varTreeNode.Weight == 1) continue; // NOTE: here we assume that we do not tune variable weights when they are originally exactly 1 because we assume that the tree has been parsed and the tree explicitly has the structure w * var
602
603          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
604          thetaVar.Weight = 1;
605          thetaVar.VariableName = $"θ{n++}";
606
607          thetaNames.Add(thetaVar.VariableName);
608          thetaValues.Add(varTreeNode.Weight);
609
610          var parent = varTreeNode.Parent;
611          if (parent != null) {
612            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
613            parent.RemoveSubtree(index);
614            var prodNode = MakeNode<Multiplication>();
615            varTreeNode.Weight = 1.0;
616            prodNode.AddSubtree(varTreeNode);
617            prodNode.AddSubtree(thetaVar);
618            parent.InsertSubtree(index, prodNode);
619          }
620        }
621      }
622      return copy;
623    }
624
625    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
626      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
627      constantNode.Value = value;
628      return constantNode;
629    }
630
631    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
632      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
633      t.Root.GetSubtree(0).RemoveSubtree(0);
634      t.Root.GetSubtree(0).InsertSubtree(0, sub);
635      return t;
636    }
637    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
638      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
639      t.Root.GetSubtree(0).RemoveSubtree(0);
640      t.Root.GetSubtree(0).InsertSubtree(0, sub);
641      return t;
642    }
643
644    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
645      var node = new T().CreateTreeNode();
646      foreach (var f in fs) node.AddSubtree(f);
647      return node;
648    }
649
650    public void Dispose() {
651      Dispose(true);
652      GC.SuppressFinalize(this);
653    }
654
655    protected virtual void Dispose(bool disposing) {
656      if (disposed)
657        return;
658
659      if (disposing) {
660        // Free any other managed objects here.
661      }
662
663      // Free any unmanaged objects here.
664      if (nlopt != IntPtr.Zero) {
665        NLOpt.nlopt_destroy(nlopt);
666        nlopt = IntPtr.Zero;
667      }
668      if (constraintDataPtr != null) {
669        for (int i = 0; i < constraintDataPtr.Length; i++)
670          if (constraintDataPtr[i] != IntPtr.Zero) {
671            Marshal.FreeHGlobal(constraintDataPtr[i]);
672            constraintDataPtr[i] = IntPtr.Zero;
673          }
674      }
675
676      disposed = true;
677    }
678    #endregion
679  }
680}
Note: See TracBrowser for help on using the repository browser.