Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs @ 17204

Last change on this file since 17204 was 17204, checked in by gkronber, 5 years ago

#2994: added parameter for gradient checks and experimented with preconditioning

File size: 24.9 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Runtime.InteropServices;
6using HeuristicLab.Common;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
9namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions {
10  internal class ConstrainedNLSInternal {
11    private readonly int maxIterations;
12    public int MaxIterations => maxIterations;
13
14    private readonly string solver;
15    public string Solver => solver;
16
17    private readonly ISymbolicExpressionTree expr;
18    public ISymbolicExpressionTree Expr => expr;
19
20    private readonly IRegressionProblemData problemData;
21
22    public IRegressionProblemData ProblemData => problemData;
23
24
25    public event Action FunctionEvaluated;
26    public event Action<int, double> ConstraintEvaluated;
27
28    private double bestError = double.MaxValue;
29    public double BestError => bestError;
30
31    private double curError = double.MaxValue;
32    public double CurError => curError;
33
34    private double[] bestSolution;
35    public double[] BestSolution => bestSolution;
36
37    private ISymbolicExpressionTree bestTree;
38    public ISymbolicExpressionTree BestTree => bestTree;
39
40    private double[] bestConstraintValues;
41    public double[] BestConstraintValues => bestConstraintValues;
42
43    public bool CheckGradient { get; internal set; }
44
45    // begin internal state
46    private IntPtr nlopt;
47    private SymbolicDataAnalysisExpressionTreeLinearInterpreter interpreter;
48    private readonly NLOpt.nlopt_func calculateObjectiveDelegate; // must hold the delegate to prevent GC
49    private readonly NLOpt.nlopt_precond preconditionDelegate;
50    private readonly IntPtr[] constraintDataPtr; // must hold the objects to prevent GC
51    private readonly NLOpt.nlopt_func[] calculateConstraintDelegates; // must hold the delegates to prevent GC
52    private readonly List<double> thetaValues;
53    private readonly IDictionary<string, Interval> dataIntervals;
54    private readonly int[] trainingRows;
55    private readonly double[] target;
56    private readonly ISymbolicExpressionTree preparedTree;
57    private readonly ISymbolicExpressionTreeNode[] preparedTreeParameterNodes;
58    private readonly List<ConstantTreeNode>[] allThetaNodes;
59    public List<ISymbolicExpressionTree> constraintTrees;    // TODO make local in ctor (public for debugging)
60   
61    private readonly double[] fi_eval;
62    private readonly double[,] jac_eval;
63    private readonly ISymbolicExpressionTree scaledTree;
64    private readonly VectorAutoDiffEvaluator autoDiffEval;
65    private readonly VectorEvaluator eval;
66
67    // end internal state
68
69
70    // for data exchange to/from optimizer in native code
71    [StructLayout(LayoutKind.Sequential)]
72    private struct ConstraintData {
73      public int Idx;
74      public ISymbolicExpressionTree Tree;
75      public ISymbolicExpressionTreeNode[] ParameterNodes;
76    }
77
78    internal ConstrainedNLSInternal(string solver, ISymbolicExpressionTree expr, int maxIterations, IRegressionProblemData problemData, double ftol_rel = 0, double ftol_abs = 0, double maxTime = 0) {
79      this.solver = solver;
80      this.expr = expr;
81      this.maxIterations = maxIterations;
82      this.problemData = problemData;
83      this.interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
84      this.autoDiffEval = new VectorAutoDiffEvaluator();
85      this.eval = new VectorEvaluator();
86
87      CheckGradient = false;
88
89      var intervalConstraints = problemData.IntervalConstraints;
90      dataIntervals = problemData.VariableRanges.GetIntervals();
91      trainingRows = problemData.TrainingIndices.ToArray();
92      // buffers
93      target = problemData.TargetVariableTrainingValues.ToArray();
94      var targetStDev = target.StandardDeviationPop();
95      var targetVariance = targetStDev * targetStDev;
96      var targetMean = target.Average();
97      var pred = interpreter.GetSymbolicExpressionTreeValues(expr, problemData.Dataset, trainingRows).ToArray();
98
99      if (pred.Any(pi => double.IsInfinity(pi) || double.IsNaN(pi))) throw new ArgumentException("The expression produces NaN or infinite values.");
100
101      #region linear scaling
102      var predStDev = pred.StandardDeviationPop();
103      if (predStDev == 0) throw new ArgumentException("The expression is constant.");
104      var predMean = pred.Average();
105
106      var scalingFactor = targetStDev / predStDev;
107      var offset = targetMean - predMean * scalingFactor;
108
109      scaledTree = CopyAndScaleTree(expr, scalingFactor, offset);
110      #endregion
111
112      // convert constants to variables named theta...
113      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
114
115      // create trees for relevant derivatives
116      Dictionary<string, ISymbolicExpressionTree> derivatives = new Dictionary<string, ISymbolicExpressionTree>();
117      allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
118      constraintTrees = new List<ISymbolicExpressionTree>();
119      foreach (var constraint in intervalConstraints.Constraints) {
120        if (constraint.IsDerivation) {
121          if (!problemData.AllowedInputVariables.Contains(constraint.Variable))
122            throw new ArgumentException($"Invalid constraint: the variable {constraint.Variable} does not exist in the dataset.");
123          var df = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
124
125          // NLOpt requires constraint expressions of the form c(x) <= 0
126          // -> we make two expressions, one for the lower bound and one for the upper bound
127
128          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
129            var df_smaller_upper = Subtract((ISymbolicExpressionTree)df.Clone(), CreateConstant(constraint.Interval.UpperBound));
130            // convert variables named theta back to constants
131            var df_prepared = ReplaceVarWithConst(df_smaller_upper, thetaNames, thetaValues, allThetaNodes);
132            constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
133          }
134          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
135            var df_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)df.Clone());
136            // convert variables named theta back to constants
137            var df_prepared = ReplaceVarWithConst(df_larger_lower, thetaNames, thetaValues, allThetaNodes);           
138            constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
139          }
140        } else {
141          if (constraint.Interval.UpperBound < double.PositiveInfinity) {
142            var f_smaller_upper = Subtract((ISymbolicExpressionTree)treeForDerivation.Clone(), CreateConstant(constraint.Interval.UpperBound));
143            // convert variables named theta back to constants
144            var df_prepared = ReplaceVarWithConst(f_smaller_upper, thetaNames, thetaValues, allThetaNodes);
145            constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
146          }
147          if (constraint.Interval.LowerBound > double.NegativeInfinity) {
148            var f_larger_lower = Subtract(CreateConstant(constraint.Interval.LowerBound), (ISymbolicExpressionTree)treeForDerivation.Clone());
149            // convert variables named theta back to constants
150            var df_prepared = ReplaceVarWithConst(f_larger_lower, thetaNames, thetaValues, allThetaNodes);
151            constraintTrees.Add((ISymbolicExpressionTree)df_prepared.Clone());
152          }
153        }
154      }
155
156      preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);
157      preparedTreeParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);
158
159      var dim = thetaValues.Count;
160      fi_eval = new double[target.Length]; // init buffer;
161      jac_eval = new double[target.Length, dim]; // init buffer
162
163
164      var minVal = Math.Min(-100.0, thetaValues.Min());
165      var maxVal = Math.Max(100.0, thetaValues.Max());
166      var lb = Enumerable.Repeat(minVal, thetaValues.Count).ToArray();
167      var up = Enumerable.Repeat(maxVal, thetaValues.Count).ToArray();
168      nlopt = NLOpt.nlopt_create(GetSolver(solver), (uint)dim);
169
170      NLOpt.nlopt_set_lower_bounds(nlopt, lb);
171      NLOpt.nlopt_set_upper_bounds(nlopt, up);
172      calculateObjectiveDelegate = new NLOpt.nlopt_func(CalculateObjective); // keep a reference to the delegate (see below)
173      NLOpt.nlopt_set_min_objective(nlopt, calculateObjectiveDelegate, IntPtr.Zero); // --> without preconditioning
174
175      // preconditionDelegate = new NLOpt.nlopt_precond(PreconditionObjective);
176      // NLOpt.nlopt_set_precond_min_objective(nlopt, calculateObjectiveDelegate, preconditionDelegate, IntPtr.Zero);
177
178
179      constraintDataPtr = new IntPtr[constraintTrees.Count];
180      calculateConstraintDelegates = new NLOpt.nlopt_func[constraintTrees.Count]; // make sure we keep a reference to the delegates (otherwise GC will free delegate objects see https://stackoverflow.com/questions/7302045/callback-delegates-being-collected#7302258)
181      for (int i = 0; i < constraintTrees.Count; i++) {
182        var constraintData = new ConstraintData() { Idx = i, Tree = constraintTrees[i], ParameterNodes = GetParameterNodes(constraintTrees[i], allThetaNodes) };
183        constraintDataPtr[i] = Marshal.AllocHGlobal(Marshal.SizeOf<ConstraintData>());
184        Marshal.StructureToPtr(constraintData, constraintDataPtr[i], fDeleteOld: false);
185        calculateConstraintDelegates[i] = new NLOpt.nlopt_func(CalculateConstraint);
186        NLOpt.nlopt_add_inequality_constraint(nlopt, calculateConstraintDelegates[i], constraintDataPtr[i], 1e-8);
187        //NLOpt.nlopt_add_precond_inequality_constraint(nlopt, calculateConstraintDelegates[i], preconditionDelegate, constraintDataPtr[i], 1e-8);
188      }
189
190      NLOpt.nlopt_set_ftol_rel(nlopt, ftol_rel);
191      NLOpt.nlopt_set_ftol_abs(nlopt, ftol_abs);
192      NLOpt.nlopt_set_maxtime(nlopt, maxTime);
193      NLOpt.nlopt_set_maxeval(nlopt, maxIterations);
194    }
195
196
197    ~ConstrainedNLSInternal() {
198      if (nlopt != IntPtr.Zero)
199        NLOpt.nlopt_destroy(nlopt);
200      if (constraintDataPtr != null) {
201        for (int i = 0; i < constraintDataPtr.Length; i++)
202          if (constraintDataPtr[i] != IntPtr.Zero)
203            Marshal.FreeHGlobal(constraintDataPtr[i]);
204      }
205    }
206
207
208    internal void Optimize() {
209      var x = thetaValues.ToArray();  /* initial guess */
210      double minf = double.MaxValue; /* minimum objective value upon return */
211      var res = NLOpt.nlopt_optimize(nlopt, x, ref minf);
212      bestSolution = x;
213      bestError = minf;
214
215      if (res < 0 && res != NLOpt.nlopt_result.NLOPT_FORCED_STOP) {
216        throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
217      } else {
218        // calculate constraints of final solution
219        double[] _ = new double[x.Length];
220        bestConstraintValues = new double[calculateConstraintDelegates.Length];
221        for(int i=0;i<calculateConstraintDelegates.Length;i++) {
222          bestConstraintValues[i] = calculateConstraintDelegates[i].Invoke((uint)x.Length, x, _, constraintDataPtr[i]);
223        }
224
225        // update parameters in tree
226        var pIdx = 0;
227        // here we lose the two last parameters (for linear scaling)
228        foreach (var node in scaledTree.IterateNodesPostfix()) {
229          if (node is ConstantTreeNode constTreeNode) {
230            constTreeNode.Value = x[pIdx++];
231          } else if (node is VariableTreeNode varTreeNode) {
232            varTreeNode.Weight = x[pIdx++];
233          }
234        }
235        if (pIdx != x.Length) throw new InvalidProgramException();
236      }
237      bestTree = scaledTree;
238    }
239
240    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
241      UpdateThetaValues(curX);
242      var sse = 0.0;
243
244      if (grad != null) {
245        autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
246          preparedTreeParameterNodes, fi_eval, jac_eval);
247
248        // calc sum of squared errors and gradient
249        for (int j = 0; j < grad.Length; j++) grad[j] = 0;
250        for (int i = 0; i < target.Length; i++) {
251          var r = target[i] - fi_eval[i];
252          sse += r * r;
253          for (int j = 0; j < grad.Length; j++) {
254            grad[j] -= 2 * r * jac_eval[i, j];
255          }
256        }
257        // average
258        for (int j = 0; j < grad.Length; j++) { grad[j] /= target.Length; }
259
260        #region check gradient
261        if (grad != null && CheckGradient) {
262          for (int i = 0; i < dim; i++) {
263            // make two additional evaluations
264            var xForNumDiff = (double[])curX.Clone();
265            double delta = Math.Abs(xForNumDiff[i] * 1e-5);
266            xForNumDiff[i] += delta;
267            UpdateThetaValues(xForNumDiff);
268            var evalHigh = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
269            var mseHigh = MSE(target, evalHigh);
270            xForNumDiff[i] = curX[i] - delta;
271            UpdateThetaValues(xForNumDiff);
272            var evalLow = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
273            var mseLow = MSE(target, evalLow);
274
275            var numericDiff = (mseHigh - mseLow) / (2 * delta);
276            var autoDiff = grad[i];
277            if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
278              || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
279              throw new InvalidProgramException();
280          }
281        }
282        #endregion
283      } else {
284        var eval = new VectorEvaluator();
285        var prediction = eval.Evaluate(preparedTree, problemData.Dataset, trainingRows);
286
287        // calc sum of squared errors
288        sse = 0.0;
289        for (int i = 0; i < target.Length; i++) {
290          var r = target[i] - prediction[i];
291          sse += r * r;
292        }
293      }
294
295      UpdateBestSolution(sse / target.Length, curX);
296      RaiseFunctionEvaluated();
297
298      if (double.IsNaN(sse)) return double.MaxValue;
299      return sse / target.Length;
300    }
301
302    // // NOT WORKING YET? WHY is det(H) always zero?!
303    // // TODO
304    // private void PreconditionObjective(uint n, double[] x, double[] v, double[] vpre, IntPtr data) {
305    //   UpdateThetaValues(x); // calc H(x)
306    //   
307    //   autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows,
308    //     preparedTreeParameterNodes, fi_eval, jac_eval);
309    //   var k = jac_eval.GetLength(0);
310    //   var h = new double[n, n];
311    //   
312    //   // calc residuals and scale jac_eval
313    //   var f = -2.0 / k;
314    //   for (int i = 0;i<k;i++) {
315    //     var r = target[i] - fi_eval[i];
316    //     for(int j = 0;j<n;j++) {
317    //       jac_eval[i, j] *= f * r;
318    //     }
319    //   }
320    //   
321    //   // approximate hessian H(x) = J(x)^T * J(x)
322    //   alglib.rmatrixgemm((int)n, (int)n, k,
323    //     1.0, jac_eval, 0, 0, 1,  // transposed
324    //     jac_eval, 0, 0, 0,
325    //     0.0, ref h, 0, 0,
326    //     null
327    //     );
328    //   
329    //   // scale v
330    //   alglib.rmatrixmv((int)n, (int)n, h, 0, 0, 0, v, 0, ref vpre, 0, alglib.serial);
331    //   var det = alglib.matdet.rmatrixdet(h, (int)n, alglib.serial);
332    // }
333
334
335    private double MSE(double[] a, double[] b) {
336      Trace.Assert(a.Length == b.Length);
337      var sse = 0.0;
338      for (int i = 0; i < a.Length; i++) sse += (a[i] - b[i]) * (a[i] - b[i]);
339      return sse / a.Length;
340    }
341
342    private void UpdateBestSolution(double curF, double[] curX) {
343      if (double.IsNaN(curF) || double.IsInfinity(curF)) return;
344      else if (curF < bestError) {
345        bestError = curF;
346        bestSolution = (double[])curX.Clone();
347      }
348      curError = curF;
349    }
350
351    private void UpdateConstraintViolations(int constraintIdx, double value) {
352      if (double.IsNaN(value) || double.IsInfinity(value)) return;
353      RaiseConstraintEvaluated(constraintIdx, value);
354      // else if (curF < bestError) {
355      //   bestError = curF;
356      //   bestSolution = (double[])curX.Clone();
357      // }
358    }
359
360    double CalculateConstraint(uint dim, double[] curX, double[] grad, IntPtr data) {
361      UpdateThetaValues(curX);
362      var intervalEvaluator = new IntervalEvaluator();
363      var constraintData = Marshal.PtrToStructure<ConstraintData>(data);
364
365      if (grad != null) for (int j = 0; j < grad.Length; j++) grad[j] = 0; // clear grad
366
367      var interval = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
368        out double[] lowerGradient, out double[] upperGradient);
369
370      // we transformed this to a constraint c(x) <= 0, so only the upper bound is relevant for us
371      if (grad != null) for (int j = 0; j < grad.Length; j++) { grad[j] = upperGradient[j]; }
372
373      #region check gradient
374      if (grad != null && CheckGradient)
375        for (int i = 0; i < dim; i++) {
376          // make two additional evaluations
377          var xForNumDiff = (double[])curX.Clone();
378          double delta = Math.Abs(xForNumDiff[i] * 1e-5);
379          xForNumDiff[i] += delta;
380          UpdateThetaValues(xForNumDiff);
381          var evalHigh = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
382            out double[] unusedLowerGradientHigh, out double[] unusedUpperGradientHigh);
383
384          xForNumDiff[i] = curX[i] - delta;
385          UpdateThetaValues(xForNumDiff);
386          var evalLow = intervalEvaluator.Evaluate(constraintData.Tree, dataIntervals, constraintData.ParameterNodes,
387            out double[] unusedLowerGradientLow, out double[] unusedUpperGradientLow);
388
389          var numericDiff = (evalHigh.UpperBound - evalLow.UpperBound) / (2 * delta);
390          var autoDiff = grad[i];
391
392          if ((Math.Abs(autoDiff) < 1e-10 && Math.Abs(numericDiff) > 1e-2)
393            || (Math.Abs(autoDiff) >= 1e-10 && Math.Abs((numericDiff - autoDiff) / numericDiff) > 1e-2))
394            throw new InvalidProgramException();
395        }
396      #endregion
397
398
399      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
400      if (double.IsNaN(interval.UpperBound)) return double.MaxValue;
401      else return interval.UpperBound;
402    }
403
404
405    void UpdateThetaValues(double[] theta) {
406      for (int i = 0; i < theta.Length; ++i) {
407        foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
408      }
409    }
410
411    internal void RequestStop() {
412      NLOpt.nlopt_set_force_stop(nlopt, 1); // hopefully NLOpt is thread safe  , val must be <> 0 otherwise no effect
413    }
414
415    private void RaiseFunctionEvaluated() {
416      FunctionEvaluated?.Invoke();
417    }
418
419    private void RaiseConstraintEvaluated(int idx, double value) {
420      ConstraintEvaluated?.Invoke(idx, value);
421    }
422
423
424    #region helper
425
426    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
427      var m = (ISymbolicExpressionTree)tree.Clone();
428
429      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
430      m.Root.GetSubtree(0).RemoveSubtree(0);
431      m.Root.GetSubtree(0).AddSubtree(add);
432      return m;
433    }
434
435    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
436      if (nodes.Length != constants.Length) throw new InvalidOperationException();
437      for (int i = 0; i < nodes.Length; i++) {
438        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
439        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
440      }
441    }
442
443    private NLOpt.nlopt_algorithm GetSolver(string solver) {
444      if (solver.Contains("MMA")) return NLOpt.nlopt_algorithm.NLOPT_LD_MMA;
445      if (solver.Contains("COBYLA")) return NLOpt.nlopt_algorithm.NLOPT_LN_COBYLA;
446      if (solver.Contains("CCSAQ")) return NLOpt.nlopt_algorithm.NLOPT_LD_CCSAQ;
447      if (solver.Contains("ISRES")) return NLOpt.nlopt_algorithm.NLOPT_GN_ISRES;
448      throw new ArgumentException($"Unknown solver {solver}");
449    }
450
451    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
452      // TODO better solution necessary
453      var treeConstNodes = tree.IterateNodesPostfix().OfType<ConstantTreeNode>().ToArray();
454      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
455      for (int i = 0; i < paramNodes.Length; i++) {
456        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
457      }
458      return paramNodes;
459    }
460
461    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
462      var copy = (ISymbolicExpressionTree)tree.Clone();
463      var nodes = copy.IterateNodesPostfix().ToList();
464      for (int i = 0; i < nodes.Count; i++) {
465        var n = nodes[i] as VariableTreeNode;
466        if (n != null) {
467          var thetaIdx = thetaNames.IndexOf(n.VariableName);
468          if (thetaIdx >= 0) {
469            var parent = n.Parent;
470            if (thetaNodes[thetaIdx].Any()) {
471              // HACK: REUSE CONSTANT TREE NODE IN SEVERAL TREES
472              // we use this trick to allow autodiff over thetas when thetas occurr multiple times in the tree (e.g. in derived trees)
473              var constNode = thetaNodes[thetaIdx].First();
474              var childIdx = parent.IndexOfSubtree(n);
475              parent.RemoveSubtree(childIdx);
476              parent.InsertSubtree(childIdx, constNode);
477            } else {
478              var constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
479              var childIdx = parent.IndexOfSubtree(n);
480              parent.RemoveSubtree(childIdx);
481              parent.InsertSubtree(childIdx, constNode);
482              thetaNodes[thetaIdx].Add(constNode);
483            }
484          }
485        }
486      }
487      return copy;
488    }
489
490    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
491      thetaNames = new List<string>();
492      thetaValues = new List<double>();
493      var copy = (ISymbolicExpressionTree)tree.Clone();
494      var nodes = copy.IterateNodesPostfix().ToList();
495
496      int n = 1;
497      for (int i = 0; i < nodes.Count; ++i) {
498        var node = nodes[i];
499        if (node is ConstantTreeNode constantTreeNode) {
500          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
501          thetaVar.Weight = 1;
502          thetaVar.VariableName = $"θ{n++}";
503
504          thetaNames.Add(thetaVar.VariableName);
505          thetaValues.Add(constantTreeNode.Value);
506
507          var parent = constantTreeNode.Parent;
508          if (parent != null) {
509            var index = constantTreeNode.Parent.IndexOfSubtree(constantTreeNode);
510            parent.RemoveSubtree(index);
511            parent.InsertSubtree(index, thetaVar);
512          }
513        }
514        if (node is VariableTreeNode varTreeNode) {
515          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
516          thetaVar.Weight = 1;
517          thetaVar.VariableName = $"θ{n++}";
518
519          thetaNames.Add(thetaVar.VariableName);
520          thetaValues.Add(varTreeNode.Weight);
521
522          var parent = varTreeNode.Parent;
523          if (parent != null) {
524            var index = varTreeNode.Parent.IndexOfSubtree(varTreeNode);
525            parent.RemoveSubtree(index);
526            var prodNode = MakeNode<Multiplication>();
527            varTreeNode.Weight = 1.0;
528            prodNode.AddSubtree(varTreeNode);
529            prodNode.AddSubtree(thetaVar);
530            parent.InsertSubtree(index, prodNode);
531          }
532        }
533      }
534      return copy;
535    }
536
537    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
538      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
539      constantNode.Value = value;
540      return constantNode;
541    }
542
543    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
544      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
545      t.Root.GetSubtree(0).RemoveSubtree(0);
546      t.Root.GetSubtree(0).InsertSubtree(0, sub);
547      return t;
548    }
549    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
550      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
551      t.Root.GetSubtree(0).RemoveSubtree(0);
552      t.Root.GetSubtree(0).InsertSubtree(0, sub);
553      return t;
554    }
555
556    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
557      var node = new T().CreateTreeNode();
558      foreach (var f in fs) node.AddSubtree(f);
559      return node;
560    }
561    #endregion
562  }
563}
Note: See TracBrowser for help on using the repository browser.