Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17197

Last change on this file since 17197 was 17197, checked in by gkronber, 5 years ago

#2994: made an algorithm to experiment with NLOpt solvers

File size: 19.7 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Runtime.InteropServices;
5using HeuristicLab.Common;
6using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
7
8namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions {
9  internal class ConstrainedNLSInternal {
10    private readonly int maxIterations;
11    public int MaxIterations => maxIterations;
12
13    private readonly string solver;
14    public string Solver => solver;
15
16    private readonly ISymbolicExpressionTree expr;
17    public ISymbolicExpressionTree Expr => expr;
18
19    private readonly IRegressionProblemData problemData;
20
21    public IRegressionProblemData ProblemData => problemData;
22
23
24    public event Action FunctionEvaluated;
25    public event Action<int, double> ConstraintEvaluated;
26
27    private double bestError = double.MaxValue;
28    public double BestError => bestError;
29
30    private double curError = double.MaxValue;
31    public double CurError => curError;
32
33    private double[] bestSolution;
34    public double[] BestSolution => bestSolution;
35
36    private ISymbolicExpressionTree bestTree;
37    public ISymbolicExpressionTree BestTree => bestTree;
38
39    // begin internal state
40    private IntPtr nlopt;
41    private SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
42    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
43    private readonly IntPtr[] constraintDataPtr; // must hold the objects to prevent GC
44    private readonly NLOpt.nlopt_func[] calculateConstraintDelegates; // must hold the delegates to prevent GC
45    private readonly List<double> thetaValues;
46    private readonly IDictionary<string, Interval> dataIntervals;
47    private readonly int[] trainingRows;
48    private readonly double[] target;
49    private readonly ISymbolicExpressionTree preparedTree;
50    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
51    private readonly List<ConstantTreeNode>[] allThetaNodes;
52    private readonly double[] fi_eval;
53    private readonly double[,] jac_eval;
54    private readonly ISymbolicExpressionTree scaledTree;
55
56    // end internal state
57
58
59    // for data exchange to/from optimizer in native code
60    [StructLayout(LayoutKind.Sequential)]
61    private struct ConstraintData {
62      public int Idx;
63      public ISymbolicExpressionTree Tree;
64      public ISymbolicExpressionTreeNode[] ParameterNodes;
65    }
66
67    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
68      this.solver = solver;
69      this.expr = expr;
70      this.maxIterations = maxIterations;
71      this.problemData = problemData;
72      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
73
74
75      var intervalConstraints = problemData.IntervalConstraints;
76      dataIntervals = problemData.VariableRanges.GetIntervals();
77      trainingRows = problemData.TrainingIndices.ToArray();
78      // buffers
79      target = problemData.TargetVariableTrainingValues.ToArray();
80      var targetStDev = target.StandardDeviationPop();
81      var targetVariance = targetStDev * targetStDev;
82      var targetMean = target.Average();
83      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();
84
85      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) throw new ArgumentException("The expression produces NaN or infinite values.");
86
87      #region linear scaling
88      var predStDev = pred.StandardDeviationPop();
89      if (predStDev == 0) throw new ArgumentException("The expression is constant.");
90      var predMean = pred.Average();
91
92      var scalingFactor = targetStDev / predStDev;
93      var offset = targetMean - predMean * scalingFactor;
94
95      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
96      #endregion
97
98      // convert constants to variables named theta...
99      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
100
101      // create trees for relevant derivatives
102      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
103      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
104      var constraintTrees = new List<ISymbolicExpressionTree>();
105      foreach (var constraint in intervalConstraints.Constraints) {
106        if (constraint.IsDerivation) {
107          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
108            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
109          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
110
111          // NLOpt requires constraint expressions of the form c(x) <= 0
112          // -> we make two expressions, one for the lower bound and one for the upper bound
113
114          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
115            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
116            // convert variables named theta back to constants
117            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
118            constraintTrees.Add(df_prepared);
119          }
120          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
121            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
122            // convert variables named theta back to constants
123            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
124            constraintTrees.Add(df_prepared);
125          }
126        } else {
127          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
128            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
129            // convert variables named theta back to constants
130            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
131            constraintTrees.Add(df_prepared);
132          }
133          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
134            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
135            // convert variables named theta back to constants
136            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
137            constraintTrees.Add(df_prepared);
138          }
139        }
140      }
141
142      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
143      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);
144
145      var dim = thetaValues.Count;
146      fi_eval = new double[target.Length]; // init buffer;
147      jac_eval = new double[target.Length, dim]; // init buffer
148
149
150      var minVal = Math.Min(-1000.0, thetaValues.Min());
151      var maxVal = Math.Max(1000.0, thetaValues.Max());
152      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
153      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
154      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);
155
156      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
157      NLOpt.nlopt_set_upper_bounds(nlopt, up);
158      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
159      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero);
160
161
162      constraintDataPtr = new IntPtr[constraintTrees.Count];
163      calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
164      for (int i = 0; i < constraintTrees.Count; i++) {
165        var constraintData = new ConstraintData() { Idx = i, Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
166        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
167        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
168        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
169        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
170      }
171
172      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
173      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
174      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
175      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
176    }
177
178    ~ConstrainedNLSInternal() {
179      if (nlopt != IntPtr.Zero)
180        NLOpt.nlopt_destroy(nlopt);
181      if (constraintDataPtr != null) {
182        for (int i = 0; i < constraintDataPtr.Length; i++)
183          if (constraintDataPtr[i] != IntPtr.Zero)
184            Marshal.FreeHGlobal(constraintDataPtr[i]);
185      }
186    }
187
188
189    internal void Optimize() {
190      var x = thetaValues.ToArray();  /* initial guess */
191      double minf = double.MaxValue; /* minimum objective value upon return */
192      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);
193      bestSolution = x;
194      bestError = minf;
195     
196      if (res < 0) {
197        throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
198      } else {
199        // update parameters in tree
200        var pIdx = 0;
201        // here we lose the two last parameters (for linear scaling)
202        foreach (var node in scaledTree.IterateNodesPostfix()) {
203          if (node is ConstantTreeNode constTreeNode) {
204            constTreeNode.Value = x[pIdx++];
205          } else if (node is VariableTreeNode varTreeNode) {
206            varTreeNode.Weight = x[pIdx++];
207          }
208        }
209        if (pIdx != x.Length) throw new InvalidProgramException();
210      }
211      bestTree = scaledTree;
212    }
213
214    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
215      UpdateThetaValues(curX);
216      var sse = 0.0;
217
218      if (grad != null) {
219        var autoDiffEval = new VectorAutoDiffEvaluator();
220        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
221          preparedTreeParameterNodes, fi_eval, jac_eval);
222
223        // calc sum of squared errors and gradient
224        for (int j = 0; j < grad.Length; j++) grad[j] = 0;
225        for (int i = 0; i < target.Length; i++) {
226          var r = target[i] - fi_eval[i];
227          sse += r * r;
228          for (int j = 0; j < grad.Length; j++) {
229            grad[j] -= 2 * r * jac_eval[i, j];
230          }
231        }
232        // average
233        for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }
234      } else {
235        var eval = new VectorEvaluator();
236        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
237
238        // calc sum of squared errors and gradient
239        sse = 0.0;
240        for (int i = 0; i < target.Length; i++) {
241          var r = target[i] - prediction[i];
242          sse += r * r;
243        }
244      }
245
246      UpdateBestSolution(sse / target.Length, curX);
247      RaiseFunctionEvaluated();
248
249      if (double.IsNaN(sse)) return double.MaxValue;
250      return sse / target.Length;
251    }
252
253    private void UpdateBestSolution(double curF, double[] curX) {
254      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
255      else if (curF < bestError) {
256        bestError = curF;
257        bestSolution = (double[])curX.Clone();
258      }
259    }
260
261    private void UpdateConstraintViolations(int constraintIdx, double value) {
262      if (double.IsNaN(value) || double.IsInfinity(value)) return;
263      RaiseConstraintEvaluated(constraintIdx, value);
264      // else if (curF < bestError) {
265      //   bestError = curF;
266      //   bestSolution = (double[])curX.Clone();
267      // }
268    }
269
270    double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
271      UpdateThetaValues(curX);
272      var intervalEvaluator = new IntervalEvaluator();
273      var constraintData = Marshal.PtrToStructure<ConstraintData>(data);
274
275      if (grad != null) for (int j = 0; j < grad.Length; j++) grad[j] = 0; // clear grad
276
277      var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
278        out double[] lowerGradient, out double[] upperGradient);
279
280      // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
281      if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }
282      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
283      if (double.IsNaN(interval.UpperBound)) return double.MaxValue;
284      else return interval.UpperBound;
285    }
286
287
288    void UpdateThetaValues(double[] theta) {
289      for (int i = 0; i < theta.Length; ++i) {
290        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
291      }
292    }
293
294    internal void RequestStop() {
295      NLOpt.nlopt_set_force_stop(nlopt, 1); // hopefully NLOpt is thread safe  , val must be <> 0 otherwise no effect
296    }
297
298    private void RaiseFunctionEvaluated() {
299      FunctionEvaluated?.Invoke();
300    }
301
302    private void RaiseConstraintEvaluated(int idx, double value) {
303      ConstraintEvaluated?.Invoke(idx, value);
304    }
305
306
307    #region helper
308
309    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
310      var m = (ISymbolicExpressionTree)tree.Clone();
311
312      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
313      m.Root.GetSubtree(0).RemoveSubtree(0);
314      m.Root.GetSubtree(0).AddSubtree(add);
315      return m;
316    }
317
318    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
319      if (nodes.Length != constants.Length) throw new InvalidOperationException();
320      for (int i = 0; i < nodes.Length; i++) {
321        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
322        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
323      }
324    }
325
326    private NLOpt.nlopt_algorithm GetSolver(string solver) {
327      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
328      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
329      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
330      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
331      throw new ArgumentException($"Unknown solver {solver}");
332    }
333
334    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
335      // TODO better solution necessary
336      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
337      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
338      for (int i = 0; i < paramNodes.Length; i++) {
339        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
340      }
341      return paramNodes;
342    }
343
344    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
345      var copy = (ISymbolicExpressionTree)tree.Clone();
346      var nodes = copy.IterateNodesPostfix().ToList();
347      for (int i = 0; i < nodes.Count; i++) {
348        var n = nodes[i] as VariableTreeNode;
349        if (n != null) {
350          var thetaIdx = thetaNames.IndexOf(n.VariableName);
351          if (thetaIdx >= 0) {
352            var parent = n.Parent;
353            if (thetaNodes[thetaIdx].Any()) {
354              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
355              // we use this trick to allow autodiff over thetas when thetas occurr multiple times in the tree (e.g. in derived trees)
356              var constNode = thetaNodes[thetaIdx].First();
357              var childIdx = parent.IndexOfSubtree(n);
358              parent.RemoveSubtree(childIdx);
359              parent.InsertSubtree(childIdx, constNode);
360            } else {
361              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
362              var childIdx = parent.IndexOfSubtree(n);
363              parent.RemoveSubtree(childIdx);
364              parent.InsertSubtree(childIdx, constNode);
365              thetaNodes[thetaIdx].Add(constNode);
366            }
367          }
368        }
369      }
370      return copy;
371    }
372
373    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
374      thetaNames = new List<string>();
375      thetaValues = new List<double>();
376      var copy = (ISymbolicExpressionTree)tree.Clone();
377      var nodes = copy.IterateNodesPostfix().ToList();
378
379      int n = 1;
380      for (int i = 0; i < nodes.Count; ++i) {
381        var node = nodes[i];
382        if (node is ConstantTreeNode constantTreeNode) {
383          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
384          thetaVar.Weight = 1;
385          thetaVar.VariableName = $"θ{n++}";
386
387          thetaNames.Add(thetaVar.VariableName);
388          thetaValues.Add(constantTreeNode.Value);
389
390          var parent = constantTreeNode.Parent;
391          if (parent != null) {
392            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
393            parent.RemoveSubtree(index);
394            parent.InsertSubtree(index, thetaVar);
395          }
396        }
397        if (node is VariableTreeNode varTreeNode) {
398          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
399          thetaVar.Weight = 1;
400          thetaVar.VariableName = $"θ{n++}";
401
402          thetaNames.Add(thetaVar.VariableName);
403          thetaValues.Add(varTreeNode.Weight);
404
405          var parent = varTreeNode.Parent;
406          if (parent != null) {
407            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
408            parent.RemoveSubtree(index);
409            var prodNode = MakeNode<Multiplication>();
410            varTreeNode.Weight = 1.0;
411            prodNode.AddSubtree(varTreeNode);
412            prodNode.AddSubtree(thetaVar);
413            parent.InsertSubtree(index, prodNode);
414          }
415        }
416      }
417      return copy;
418    }
419
420    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
421      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
422      constantNode.Value = value;
423      return constantNode;
424    }
425
426    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
427      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
428      t.Root.GetSubtree(0).RemoveSubtree(0);
429      t.Root.GetSubtree(0).InsertSubtree(0, sub);
430      return t;
431    }
432    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
433      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
434      t.Root.GetSubtree(0).RemoveSubtree(0);
435      t.Root.GetSubtree(0).InsertSubtree(0, sub);
436      return t;
437    }
438
439    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
440      var node = new T().CreateTreeNode();
441      foreach (var f in fs) node.AddSubtree(f);
442      return node;
443    }
444    #endregion
445  }
446}
Note: See TracBrowser for help on using the repository browser.