source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/NLOptEvaluator.cs @ 17200

Last change on this file since 17200 was 17196, checked in by gkronber, 5 years ago

#2994 added NLOpt wrapper and evaluator

File size: 30.9 KB
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HEAL.Attic;
using System.Runtime.InteropServices;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  [Item("NLOpt Evaluator (with constraints)", "")]
  [StorableType("5FADAE55-3516-4539-8A36-BC9B0D00880D")]
  public class NLOptEvaluator : SymbolicRegressionSingleObjectiveEvaluator {
    private const string ConstantOptimizationIterationsParameterName = "ConstantOptimizationIterations";
    private const string ConstantOptimizationImprovementParameterName = "ConstantOptimizationImprovement";
    private const string ConstantOptimizationProbabilityParameterName = "ConstantOptimizationProbability";
    private const string ConstantOptimizationRowsPercentageParameterName = "ConstantOptimizationRowsPercentage";
    private const string UpdateConstantsInTreeParameterName = "UpdateConstantsInSymbolicExpressionTree";
    private const string UpdateVariableWeightsParameterName = "Update Variable Weights";

    private const string FunctionEvaluationsResultParameterName = "Constants Optimization Function Evaluations";
    private const string GradientEvaluationsResultParameterName = "Constants Optimization Gradient Evaluations";
    private const string CountEvaluationsParameterName = "Count Function and Gradient Evaluations";

    public IFixedValueParameter<IntValue> ConstantOptimizationIterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[ConstantOptimizationIterationsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> ConstantOptimizationImprovementParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[ConstantOptimizationImprovementParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationProbabilityParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationProbabilityParameterName]; }
    }
    public IFixedValueParameter<PercentValue> ConstantOptimizationRowsPercentageParameter {
      get { return (IFixedValueParameter<PercentValue>)Parameters[ConstantOptimizationRowsPercentageParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateConstantsInTreeParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateConstantsInTreeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> UpdateVariableWeightsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[UpdateVariableWeightsParameterName]; }
    }

    public IResultParameter<IntValue> FunctionEvaluationsResultParameter {
      get { return (IResultParameter<IntValue>)Parameters[FunctionEvaluationsResultParameterName]; }
    }
    public IResultParameter<IntValue> GradientEvaluationsResultParameter {
      get { return (IResultParameter<IntValue>)Parameters[GradientEvaluationsResultParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CountEvaluationsParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CountEvaluationsParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters["Solver"]; }
    }


    public IntValue ConstantOptimizationIterations {
      get { return ConstantOptimizationIterationsParameter.Value; }
    }
    public DoubleValue ConstantOptimizationImprovement {
      get { return ConstantOptimizationImprovementParameter.Value; }
    }
    public PercentValue ConstantOptimizationProbability {
      get { return ConstantOptimizationProbabilityParameter.Value; }
    }
    public PercentValue ConstantOptimizationRowsPercentage {
      get { return ConstantOptimizationRowsPercentageParameter.Value; }
    }
    public bool UpdateConstantsInTree {
      get { return UpdateConstantsInTreeParameter.Value.Value; }
      set { UpdateConstantsInTreeParameter.Value.Value = value; }
    }

    public bool UpdateVariableWeights {
      get { return UpdateVariableWeightsParameter.Value.Value; }
      set { UpdateVariableWeightsParameter.Value.Value = value; }
    }

    public bool CountEvaluations {
      get { return CountEvaluationsParameter.Value.Value; }
      set { CountEvaluationsParameter.Value.Value = value; }
    }

    public string Solver {
      get { return SolverParameter.Value.Value; }
    }
    public override bool Maximization {
      get { return false; }
    }

    [StorableConstructor]
    protected NLOptEvaluator(StorableConstructorFlag _) : base(_) { }
    protected NLOptEvaluator(NLOptEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }
    public NLOptEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<IntValue>(ConstantOptimizationIterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(ConstantOptimizationImprovementParameterName, "Determines the relative improvement which must be achieved in the constant optimization to continue with it (0 indicates other or default stopping criterion).", new DoubleValue(0)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationProbabilityParameterName, "Determines the probability that the constants are optimized", new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<PercentValue>(ConstantOptimizationRowsPercentageParameterName, "Determines the percentage of the rows which should be used for constant optimization", new PercentValue(1)));
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateConstantsInTreeParameterName, "Determines if the constants in the tree should be overwritten by the optimized constants.", new BoolValue(true)) { Hidden = true });
      Parameters.Add(new FixedValueParameter<BoolValue>(UpdateVariableWeightsParameterName, "Determines if the variable weights in the tree should be optimized.", new BoolValue(true)) { Hidden = true });

      Parameters.Add(new FixedValueParameter<BoolValue>(CountEvaluationsParameterName, "Determines if function and gradient evaluation should be counted.", new BoolValue(false)));

      var validSolvers = new ItemSet<StringValue>(new[] { "MMA", "COBYLA", "CCSAQ", "ISRES" }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>("Solver", "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new ResultParameter<IntValue>(FunctionEvaluationsResultParameterName, "The number of function evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
      Parameters.Add(new ResultParameter<IntValue>(GradientEvaluationsResultParameterName, "The number of gradient evaluations performed by the constants optimization evaluator", "Results", new IntValue()));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NLOptEvaluator(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() { }

    private static readonly object locker = new object();

    public override IOperation InstrumentedApply() {
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality;
      if (RandomParameter.ActualValue.NextDouble() < ConstantOptimizationProbability.Value) {
        IEnumerable<int> constantOptimizationRows = GenerateRowsToEvaluate(ConstantOptimizationRowsPercentage.Value);
        var counter = new EvaluationsCounter();
        quality = OptimizeConstants(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, ProblemDataParameter.ActualValue,
           constantOptimizationRows, ApplyLinearScalingParameter.ActualValue.Value, Solver, ConstantOptimizationIterations.Value, updateVariableWeights: UpdateVariableWeights, lowerEstimationLimit: EstimationLimitsParameter.ActualValue.Lower, upperEstimationLimit: EstimationLimitsParameter.ActualValue.Upper, updateConstantsInTree: UpdateConstantsInTree, counter: counter);

        if (ConstantOptimizationRowsPercentage.Value != RelativeNumberOfEvaluatedSamplesParameter.ActualValue.Value) {
          throw new NotSupportedException();
        }

        if (CountEvaluations) {
          lock (locker) {
            FunctionEvaluationsResultParameter.ActualValue.Value += counter.FunctionEvaluations;
            GradientEvaluationsResultParameter.ActualValue.Value += counter.GradientEvaluations;
          }
        }

      } else {
        throw new NotSupportedException();
      }
      QualityParameter.ActualValue = new DoubleValue(quality);

      return base.InstrumentedApply();
    }

    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      FunctionEvaluationsResultParameter.ExecutionContext = context;
      GradientEvaluationsResultParameter.ExecutionContext = context;

      // MSE evaluator is used on purpose instead of the const-opt evaluator,
      // because Evaluate() is used to get the quality of evolved models on
      // different partitions of the dataset (e.g., best validation model)
      double mse = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, double.MinValue, double.MaxValue, problemData, rows, applyLinearScaling: false);

      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
      ApplyLinearScalingParameter.ExecutionContext = null;
      FunctionEvaluationsResultParameter.ExecutionContext = null;
      GradientEvaluationsResultParameter.ExecutionContext = null;

      return mse;
    }

    public class EvaluationsCounter {
      public int FunctionEvaluations = 0;
      public int GradientEvaluations = 0;
    }

    private static void GetParameterNodes(ISymbolicExpressionTree tree, out List<ISymbolicExpressionTreeNode> thetaNodes, out List<double> thetaValues) {
      thetaNodes = new List<ISymbolicExpressionTreeNode>();
      thetaValues = new List<double>();

      var nodes = tree.IterateNodesPrefix().ToArray();
      for (int i = 0; i < nodes.Length; ++i) {
        var node = nodes[i];
        if (node is VariableTreeNode variableTreeNode) {
          thetaValues.Add(variableTreeNode.Weight);
          thetaNodes.Add(node);
        } else if (node is ConstantTreeNode constantTreeNode) {
          thetaNodes.Add(node);
          thetaValues.Add(constantTreeNode.Value);
        }
      }
    }

    // for data exchange to/from optimizer in native code
    [StructLayout(LayoutKind.Sequential)]
    private struct ConstraintData {
      public ISymbolicExpressionTree Tree;
      public ISymbolicExpressionTreeNode[] ParameterNodes;
    }

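    // Optimizes all numeric parameters of the tree (constants and variable weights)
    // with NLOpt, subject to the interval constraints defined in the problem data:
    //  1. linearly scale the tree to the target (CopyAndScaleTree),
    //  2. replace every constant and variable weight by an artificial variable θ1..θn (ReplaceConstWithVar),
    //  3. build one constraint tree of the form c(x) <= 0 per interval bound,
    //  4. minimize the averaged squared error over the training rows with NLOpt,
    //     evaluating the constraints via interval arithmetic.
    // Returns the achieved objective value, capped at the target variance.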
    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      ISymbolicExpressionTree tree, IRegressionProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
      string solver,
      int maxIterations, bool updateVariableWeights = true,
      double lowerEstimationLimit = double.MinValue, double upperEstimationLimit = double.MaxValue,
      bool updateConstantsInTree = true, Action<double[], double, object> iterationCallback = null, EvaluationsCounter counter = null) {

      if (!updateVariableWeights) throw new NotSupportedException("not updating variable weights is not supported");
      if (!updateConstantsInTree) throw new NotSupportedException("not updating tree parameters is not supported");
      if (!applyLinearScaling) throw new NotSupportedException("application without linear scaling is not supported");

      // we always update constants, so we don't need to calculate initial quality
      // double originalQuality = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling: false);

      if (counter == null) counter = new EvaluationsCounter();
      var rowEvaluationsCounter = new EvaluationsCounter();

      var intervalConstraints = problemData.IntervalConstraints;
      var dataIntervals = problemData.VariableRanges.GetIntervals();
      var trainingRows = problemData.TrainingIndices.ToArray();
      // buffers
      var target = problemData.TargetVariableTrainingValues.ToArray();
      var targetStDev = target.StandardDeviationPop();
      var targetVariance = targetStDev * targetStDev;
      var targetMean = target.Average();
      var pred = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, trainingRows).ToArray();
      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) return targetVariance;

      #region linear scaling
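      // scale the model f to the target: y ≈ scalingFactor * f(x) + offset,
      // with scalingFactor = sd(y) / sd(f) and offset = mean(y) - scalingFactor * mean(f)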
      var predStDev = pred.StandardDeviationPop();
      if (predStDev == 0) return targetVariance; // constant expression
      var predMean = pred.Average();

      var scalingFactor = targetStDev / predStDev;
      var offset = targetMean - predMean * scalingFactor;

      ISymbolicExpressionTree scaledTree = null;
      if (applyLinearScaling) scaledTree = CopyAndScaleTree(tree, scalingFactor, offset);
      #endregion

      // convert constants to variables named theta...
      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out List<double> thetaValues); // copies the tree

      // create trees for relevant derivatives
      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
      var allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
      var constraintTrees = new List<ISymbolicExpressionTree>();
      foreach (var constraint in intervalConstraints.Constraints) {
        if (constraint.IsDerivation) {
          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);

          // NLOpt requires inequality constraints of the form c(x) <= 0
          // -> we create two expressions, one for the lower bound and one for the upper bound
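          // e.g. the constraint df/dx ∈ [l, u] yields df/dx - u <= 0 and l - df/dx <= 0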

          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
        } else {
          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
            // convert variables named theta back to constants
            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
            constraintTrees.Add(df_prepared);
          }
        }
      }

      var preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
      var preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);

      // local function
      void UpdateThetaValues(double[] theta) {
        for (int i = 0; i < theta.Length; ++i) {
          foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
        }
      }

      var fi_eval = new double[target.Length];
      var jac_eval = new double[target.Length, thetaValues.Count];

      double calculate_obj(uint dim, double[] curX, double[] grad, IntPtr data) {
        UpdateThetaValues(curX);

        if (grad != null) {
          var autoDiffEval = new VectorAutoDiffEvaluator();
          autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
            preparedTreeParameterNodes, fi_eval, jac_eval);

          // calc sum of squared errors and gradient
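          // objective: mse(θ) = 1/N * Σ_i 0.5 * (y_i - f(x_i, θ))²
          // gradient:  ∂mse/∂θ_j = -1/N * Σ_i (y_i - f(x_i, θ)) * ∂f(x_i, θ)/∂θ_j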
          var sse = 0.0;
          for (int j = 0; j < grad.Length; j++) grad[j] = 0;
          for (int i = 0; i < target.Length; i++) {
            var r = target[i] - fi_eval[i];
            sse += 0.5 * r * r;
            for (int j = 0; j < grad.Length; j++) {
              grad[j] -= r * jac_eval[i, j];
            }
          }
          if (double.IsNaN(sse)) return double.MaxValue;
          // average
          for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }
          return sse / target.Length;
        } else {
          var eval = new VectorEvaluator();
          var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);

          // calc sum of squared errors (no gradient in this branch)
          var sse = 0.0;
          for (int i = 0; i < target.Length; i++) {
            var r = target[i] - prediction[i];
            sse += 0.5 * r * r;
          }
          if (double.IsNaN(sse)) return double.MaxValue;
          // average
          return sse / target.Length;
        }

      }

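      // evaluates one constraint tree over the variable ranges of the dataset using interval arithmetic;
      // the returned upper bound is the worst case of c(x) over the whole input domain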
      double calculate_constraint(uint dim, double[] curX, double[] grad, IntPtr data) {
        UpdateThetaValues(curX);
        var intervalEvaluator = new IntervalEvaluator();
        var constraintData = Marshal.PtrToStructure<ConstraintData>(data);

        if (grad != null) for (int j = 0; j < grad.Length; j++) grad[j] = 0; // clear grad

        var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
          out double[] lowerGradient, out double[] upperGradient);

        // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
        if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }
        if (double.IsNaN(interval.UpperBound)) return double.MaxValue;
        else return interval.UpperBound;
      }

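      // box constraints for θ: at least [-1000, 1000], widened when an initial value lies outside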
      var minVal = Math.Min(-1000.0, thetaValues.Min());
      var maxVal = Math.Max(1000.0, thetaValues.Max());
      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
      IntPtr nlopt_opt = NLOpt.nlopt_create(GetAlgFromIdentifier(solver), (uint)thetaValues.Count); /* algorithm and dimensionality */

      NLOpt.nlopt_set_lower_bounds(nlopt_opt, lb);
      NLOpt.nlopt_set_upper_bounds(nlopt_opt, up);
      var calculateObjectiveDelegate = new NLOpt.nlopt_func(calculate_obj); // keep a reference to the delegate (see below)
      NLOpt.nlopt_set_min_objective(nlopt_opt, calculateObjectiveDelegate, IntPtr.Zero);


      var constraintDataPtr = new IntPtr[constraintTrees.Count];
      var calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
      for (int i = 0; i < constraintTrees.Count; i++) {
        var constraintData = new ConstraintData() { Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(calculate_constraint);
        NLOpt.nlopt_add_inequality_constraint(nlopt_opt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
      }

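      // stopping criteria: relative change of θ, wall-clock time, and number of objective evaluations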
      NLOpt.nlopt_set_xtol_rel(nlopt_opt, 1e-4);
      NLOpt.nlopt_set_maxtime(nlopt_opt, 10.0); // 10 secs
      NLOpt.nlopt_set_maxeval(nlopt_opt, maxIterations);

      var x = thetaValues.ToArray();  /* initial guess */
      double minf = double.MaxValue; /* minimum objective value upon return */
      var res = NLOpt.nlopt_optimize(nlopt_opt, x, ref minf);
      if (res < 0) {
        throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt_opt)}");
      } else {
        // update parameters in tree
        var pIdx = 0;
        // here we lose the two last parameters (for linear scaling)
        foreach (var node in tree.IterateNodesPostfix()) {
          if (node is ConstantTreeNode constTreeNode) {
            constTreeNode.Value = x[pIdx++];
          } else if (node is VariableTreeNode varTreeNode) {
            varTreeNode.Weight = x[pIdx++];
          }
        }
        // note: we keep the optimized constants even when the tree is worse.
        // assert that we lose the last two parameters
        if (pIdx != x.Length - 2) throw new InvalidProgramException();
      }

      // read the evaluation count before releasing the optimizer (the handle is invalid after nlopt_destroy)
      counter.FunctionEvaluations += NLOpt.nlopt_get_numevals(nlopt_opt);
      // counter.GradientEvaluations += NLOpt.nlopt_get; // TODO

      NLOpt.nlopt_destroy(nlopt_opt);
      for (int i = 0; i < constraintDataPtr.Length; i++)
        Marshal.FreeHGlobal(constraintDataPtr[i]);

      return Math.Min(minf, targetVariance);
    }

    private static NLOpt.nlopt_algorithm GetAlgFromIdentifier(string solver) {
      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
      throw new ArgumentException($"Unknown solver {solver}");
    }

    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
      var m = (ISymbolicExpressionTree)tree.Clone();

      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
      m.Root.GetSubtree(0).RemoveSubtree(0);
      m.Root.GetSubtree(0).AddSubtree(add);
      return m;
    }

    #region helper
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
      // TODO better solution necessary
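      // select, for each θ, the single constant node instance that occurs in the given tree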
      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
      for (int i = 0; i < paramNodes.Length; i++) {
        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
      }
      return paramNodes;
    }

    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();
      for (int i = 0; i < nodes.Count; i++) {
        var n = nodes[i] as VariableTreeNode;
        if (n != null) {
          var thetaIdx = thetaNames.IndexOf(n.VariableName);
          if (thetaIdx >= 0) {
            var parent = n.Parent;
            if (thetaNodes[thetaIdx].Any()) {
              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
              // we use this trick to allow autodiff over thetas when a theta occurs multiple times in the tree (e.g. in derived trees)
              var constNode = thetaNodes[thetaIdx].First();
              var childIdx = parent.IndexOfSubtree(n);
              parent.RemoveSubtree(childIdx);
              parent.InsertSubtree(childIdx, constNode);
            } else {
              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
              var childIdx = parent.IndexOfSubtree(n);
              parent.RemoveSubtree(childIdx);
              parent.InsertSubtree(childIdx, constNode);
              thetaNodes[thetaIdx].Add(constNode);
            }
          }
        }
      }
      return copy;
    }

    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
      thetaNames = new List<string>();
      thetaValues = new List<double>();
      var copy = (ISymbolicExpressionTree)tree.Clone();
      var nodes = copy.IterateNodesPostfix().ToList();

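      // replace each constant c by a fresh variable θ_i (recording c as its initial value) and
      // each variable with weight w by the product (variable * θ_i) with the weight reset to 1
      // -> afterwards all tunable parameters appear as variables named θ1..θn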
      int n = 1;
      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        if (node is ConstantTreeNode constantTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(constantTreeNode.Value);

          var parent = constantTreeNode.Parent;
          if (parent != null) {
            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
            parent.RemoveSubtree(index);
            parent.InsertSubtree(index, thetaVar);
          }
        }
        if (node is VariableTreeNode varTreeNode) {
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
          thetaVar.VariableName = $"θ{n++}";

          thetaNames.Add(thetaVar.VariableName);
          thetaValues.Add(varTreeNode.Weight);

          var parent = varTreeNode.Parent;
          if (parent != null) {
            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
            parent.RemoveSubtree(index);
            var prodNode = MakeNode<Multiplication>();
            varTreeNode.Weight = 1.0;
            prodNode.AddSubtree(varTreeNode);
            prodNode.AddSubtree(thetaVar);
            parent.InsertSubtree(index, prodNode);
          }
        }
      }
      return copy;
    }

    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      constantNode.Value = value;
      return constantNode;
    }

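    // in-place helpers: replace the body f of tree t by (f - b) or (b - f), respectively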
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }

    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
      var node = new T().CreateTreeNode();
      foreach (var f in fs) node.AddSubtree(f);
      return node;
    }
    #endregion

    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
      if (nodes.Length != constants.Length) throw new InvalidOperationException();
      for (int i = 0; i < nodes.Length; i++) {
        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
      }
    }

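    // factories for alglib-style residual and Jacobian callbacks (residuals fi = f(x_i) - y_i);
    // not used by the NLOpt code path above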
    private static alglib.ndimensional_fvec CreateFunc(ISymbolicExpressionTree tree, VectorEvaluator eval, ISymbolicExpressionTreeNode[] parameterNodes, IDataset ds, string targetVar, int[] rows) {
      var y = ds.GetDoubleValues(targetVar, rows).ToArray();
      return (double[] c, double[] fi, object o) => {
        UpdateConstants(parameterNodes, c);
        var pred = eval.Evaluate(tree, ds, rows);
        for (int i = 0; i < fi.Length; i++)
          fi[i] = pred[i] - y[i];

        var counter = (EvaluationsCounter)o;
        counter.FunctionEvaluations++;
      };
    }

    private static alglib.ndimensional_jac CreateJac(ISymbolicExpressionTree tree, VectorAutoDiffEvaluator eval, ISymbolicExpressionTreeNode[] parameterNodes, IDataset ds, string targetVar, int[] rows) {
      var y = ds.GetDoubleValues(targetVar, rows).ToArray();
      return (double[] c, double[] fi, double[,] jac, object o) => {
        UpdateConstants(parameterNodes, c);
        eval.Evaluate(tree, ds, rows, parameterNodes, fi, jac);

        for (int i = 0; i < fi.Length; i++)
          fi[i] -= y[i];

        var counter = (EvaluationsCounter)o;
        counter.GradientEvaluations++;
      };
    }
    public static bool CanOptimizeConstants(ISymbolicExpressionTree tree) {
      return TreeToAutoDiffTermConverter.IsCompatible(tree);
    }
  }
}