Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Algorithms.DataAnalysis/3.4/NonlinearRegression/NonlinearConstrainedRegression.cs @ 17300

Last change on this file since 17300 was 16697, checked in by gkronber, 6 years ago

#2994: missing files for r16696 (+ svn:ignore)

File size: 25.0 KB
RevLine 
[16697]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Analysis;
27using HeuristicLab.Common;
28using HeuristicLab.Core;
29using HeuristicLab.Data;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HEAL.Attic;
33using HeuristicLab.Problems.DataAnalysis;
34using HeuristicLab.Problems.DataAnalysis.Symbolic;
35using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression;
36using HeuristicLab.Random;
37using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
38
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Nonlinear regression data analysis algorithm: fits the numeric constants of a user-specified model
  /// structure to the training data while respecting the interval constraints of the problem data
  /// (constraints on the model output and on its partial derivatives).
  /// </summary>
  [Item("Nonlinear Regression with Constraints (NLR)", "Nonlinear regression (curve fitting) data analysis algorithm that supports interval constraints.")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [StorableType("B235DB6E-591F-4537-8D2F-C2D1232AAEFD")]
  public sealed class NonlinearConstrainedRegression : FixedDataAnalysisAlgorithm<IRegressionProblem> {
    private const string RegressionSolutionResultName = "Regression solution";
    private const string ModelStructureParameterName = "Model structure";
    private const string IterationsParameterName = "Iterations";
    private const string RestartsParameterName = "Restarts";
    private const string SetSeedRandomlyParameterName = "SetSeedRandomly";
    private const string SeedParameterName = "Seed";
    private const string InitParamsRandomlyParameterName = "InitializeParametersRandomly";
    private const string ApplyLinearScalingParameterName = "Apply linear scaling";

    #region parameter properties
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> SetSeedRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[SetSeedRandomlyParameterName]; }
    }

    public IFixedValueParameter<IntValue> SeedParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[SeedParameterName]; }
    }

    public IFixedValueParameter<IntValue> RestartsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[RestartsParameterName]; }
    }

    public IFixedValueParameter<BoolValue> InitParametersRandomlyParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[InitParamsRandomlyParameterName]; }
    }

    public IFixedValueParameter<BoolValue> ApplyLinearScalingParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[ApplyLinearScalingParameterName]; }
    }
    #endregion

    #region properties
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    public int Restarts {
      get { return RestartsParameter.Value.Value; }
      set { RestartsParameter.Value.Value = value; }
    }

    public int Seed {
      get { return SeedParameter.Value.Value; }
      set { SeedParameter.Value.Value = value; }
    }

    public bool SetSeedRandomly {
      get { return SetSeedRandomlyParameter.Value.Value; }
      set { SetSeedRandomlyParameter.Value.Value = value; }
    }

    public bool InitializeParametersRandomly {
      get { return InitParametersRandomlyParameter.Value.Value; }
      set { InitParametersRandomlyParameter.Value.Value = value; }
    }

    public bool ApplyLinearScaling {
      get { return ApplyLinearScalingParameter.Value.Value; }
      set { ApplyLinearScalingParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private NonlinearConstrainedRegression(StorableConstructorFlag _) : base(_) { }

    private NonlinearConstrainedRegression(NonlinearConstrainedRegression original, Cloner cloner)
      : base(original, cloner) {
      // the cloned parameters are new objects; without a fresh subscription the clone's parameter
      // visibility would no longer follow the "InitializeParametersRandomly" switch
      RegisterEventHandlers();
    }

    public NonlinearConstrainedRegression()
      : base() {
      Problem = new RegressionProblem();
      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "The maximum number of iterations for constants optimization.", new IntValue(200)));
      Parameters.Add(new FixedValueParameter<IntValue>(RestartsParameterName, "The number of independent random restarts (>0)", new IntValue(10)));
      Parameters.Add(new FixedValueParameter<IntValue>(SeedParameterName, "The PRNG seed value.", new IntValue()));
      Parameters.Add(new FixedValueParameter<BoolValue>(SetSeedRandomlyParameterName, "Switch to determine if the random number seed should be initialized randomly.", new BoolValue(true)));
      Parameters.Add(new FixedValueParameter<BoolValue>(InitParamsRandomlyParameterName, "Switch to determine if the real-valued model parameters should be initialized randomly in each restart.", new BoolValue(false)));
      Parameters.Add(new FixedValueParameter<BoolValue>(ApplyLinearScalingParameterName, "Switch to determine if linear scaling terms should be added to the model", new BoolValue(true)));

      SetParameterHiddenState();
      RegisterEventHandlers();
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      SetParameterHiddenState();
      RegisterEventHandlers();
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new NonlinearConstrainedRegression(this, cloner);
    }

    /// <summary>
    /// Subscribes to changes of the "InitializeParametersRandomly" switch so that the restart-related
    /// parameters are shown or hidden accordingly. Must be called once per instance
    /// (construction, cloning, deserialization).
    /// </summary>
    private void RegisterEventHandlers() {
      InitParametersRandomlyParameter.Value.ValueChanged += (sender, args) => SetParameterHiddenState();
    }

    /// <summary>
    /// Hides Restarts, Seed and SetSeedRandomly unless random parameter initialization is enabled,
    /// because only then are they used by <see cref="Run"/>.
    /// </summary>
    private void SetParameterHiddenState() {
      var hide = !InitializeParametersRandomly;
      RestartsParameter.Hidden = hide;
      SeedParameter.Hidden = hide;
      SetSeedRandomlyParameter.Hidden = hide;
    }

    #region nonlinear regression
    /// <summary>
    /// Fits the model structure to the problem data. With random initialization, one initial run plus
    /// <see cref="Restarts"/> further runs are performed; the run with the lowest training RMSE wins.
    /// Without random initialization a single run starts from the constants in the model structure.
    /// If cancellation is requested between restarts, the best solution found so far is reported.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var problemData = (RegressionProblemData)Problem.ProblemData;
      IRegressionSolution bestSolution;
      if (InitializeParametersRandomly) {
        var qualityTable = new DataTable("RMSE table");
        qualityTable.VisualProperties.YAxisLogScale = true;
        var trainRMSERow = new DataRow("RMSE (train)");
        trainRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;
        var testRMSERow = new DataRow("RMSE test");
        testRMSERow.VisualProperties.ChartType = DataRowVisualProperties.DataRowChartType.Points;

        qualityTable.Rows.Add(trainRMSERow);
        qualityTable.Rows.Add(testRMSERow);
        Results.Add(new Result(qualityTable.Name, qualityTable.Name + " for all restarts", qualityTable));
        if (SetSeedRandomly) Seed = RandomSeedGenerator.GetSeed();
        var rand = new MersenneTwister((uint)Seed);

        // the initial run is always performed; Restarts additional runs follow
        bestSolution = CreateRegressionSolution(problemData, ModelStructure, Iterations, ApplyLinearScaling, rand);
        trainRMSERow.Values.Add(bestSolution.TrainingRootMeanSquaredError);
        testRMSERow.Values.Add(bestSolution.TestRootMeanSquaredError);
        for (int r = 0; r < Restarts; r++) {
          // stop restarting on cancellation but still report the best solution found so far
          if (cancellationToken.IsCancellationRequested) break;
          var solution = CreateRegressionSolution(problemData, ModelStructure, Iterations, ApplyLinearScaling, rand);
          trainRMSERow.Values.Add(solution.TrainingRootMeanSquaredError);
          testRMSERow.Values.Add(solution.TestRootMeanSquaredError);
          if (solution.TrainingRootMeanSquaredError < bestSolution.TrainingRootMeanSquaredError) {
            bestSolution = solution;
          }
        }
      } else {
        bestSolution = CreateRegressionSolution(problemData, ModelStructure, Iterations, ApplyLinearScaling);
      }

      Results.Add(new Result(RegressionSolutionResultName, "The nonlinear regression solution.", bestSolution));
      Results.Add(new Result("Root mean square error (train)", "The root of the mean of squared errors of the regression solution on the training set.", new DoubleValue(bestSolution.TrainingRootMeanSquaredError)));
      Results.Add(new Result("Root mean square error (test)", "The root of the mean of squared errors of the regression solution on the test set.", new DoubleValue(bestSolution.TestRootMeanSquaredError)));
    }

    /// <summary>
    /// Fits a model to the training data by optimizing its numeric constants subject to the interval
    /// constraints stored in <c>problemData.IntervalConstraints</c>.
    /// The model is specified as an infix expression containing variable names and numbers. The sum of
    /// squared errors on the training rows is minimized with ALGLIB's nonlinearly constrained optimizer
    /// (SLP); every finite constraint bound becomes an inequality constraint whose upper interval bound,
    /// evaluated over the variable ranges of the dataset, must be &lt;= 0.
    /// If a random number generator is given, every constant c of the model structure is first replaced by
    /// s * c * exp(z) with a random sign s and z ~ N(0,1); constants that are zero therefore stay zero.
    /// Otherwise the user-specified constants are used as the starting point.
    /// </summary>
    /// <param name="problemData">Training and test data including the interval constraints.</param>
    /// <param name="modelStructure">The function as infix expression.</param>
    /// <param name="maxIterations">Maximum number of optimizer iterations (0 = no iteration limit).</param>
    /// <param name="applyLinearScaling">If true, the fitted model is linearly scaled to the training target.</param>
    /// <param name="rand">Optional random number generator for random initialization of numeric constants.</param>
    /// <returns>The fitted symbolic regression solution.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="problemData"/> or <paramref name="modelStructure"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="maxIterations"/> is negative.</exception>
    public static ISymbolicRegressionSolution CreateRegressionSolution(RegressionProblemData problemData, string modelStructure, int maxIterations, bool applyLinearScaling, IRandom rand = null) {
      if (problemData == null) throw new ArgumentNullException(nameof(problemData));
      if (modelStructure == null) throw new ArgumentNullException(nameof(modelStructure));
      if (maxIterations < 0) throw new ArgumentOutOfRangeException(nameof(maxIterations), "The number of iterations must not be negative.");

      // minimal step length used as stopping criterion of the optimizer
      const double stepTolerance = 1e-6;

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(modelStructure);
      ConvertStringVariablesToFactorVariables(tree, problemData);

      var constraints = new IntervalConstraintsParser().Parse(problemData.IntervalConstraints.Value);
      var dataIntervals = problemData.VariableRanges.VariableIntervals;

      // convert constants to variables named theta... so that they survive symbolic differentiation
      var treeForDerivation = ReplaceConstWithVar(tree, out List<string> thetaNames, out List<double> thetaValues);

      // allThetaNodes[i] receives the single ConstantTreeNode that represents theta_i in all prepared trees
      var allThetaNodes = thetaNames.Select(_ => new List<ConstantTreeNode>()).ToArray();
      var constraintTrees = new List<ISymbolicExpressionTree>();
      foreach (var constraint in constraints) {
        ISymbolicExpressionTree constrainedFunction;
        if (constraint.IsDerivation) {
          constrainedFunction = DerivativeCalculator.Derive(treeForDerivation, constraint.Variable);
        } else {
          constrainedFunction = treeForDerivation;
        }
        AddBoundConstraintTrees(constrainedFunction, constraint.Interval.LowerBound, constraint.Interval.UpperBound,
          thetaNames, thetaValues, allThetaNodes, constraintTrees);
      }

      // must be prepared after the constraint trees: the shared theta nodes end up parented in this tree
      var preparedTree = ReplaceVarWithConst(treeForDerivation, thetaNames, thetaValues, allThetaNodes);

      // initialize constants randomly (random sign and lognormal factor applied to the given value)
      if (rand != null) {
        for (int i = 0; i < allThetaNodes.Length; i++) {
          double f = Math.Exp(NormalDistributedRandom.NextDouble(rand, 0, 1));
          double scale = rand.NextDouble() < 0.5 ? -1 : 1;
          thetaValues[i] = scale * thetaValues[i] * f;
          foreach (var constNode in allThetaNodes[i]) constNode.Value = thetaValues[i];
        }
      }

      void UpdateThetaValues(double[] theta) {
        for (int i = 0; i < theta.Length; ++i) {
          foreach (var constNode in allThetaNodes[i]) constNode.Value = theta[i];
        }
      }

      var x0 = thetaValues.ToArray();
      // ALGLIB requires at least one decision variable; a model without constants has nothing to fit
      if (x0.Length > 0) {
        // these do not change between optimizer callbacks, so compute them once
        var trainingRows = problemData.TrainingIndices.ToArray();
        var target = problemData.TargetVariableTrainingValues.ToArray();
        var modelParameterNodes = GetParameterNodes(preparedTree, allThetaNodes);
        var constraintParameterNodes = constraintTrees.Select(t => GetParameterNodes(t, allThetaNodes)).ToArray();

        // callback for the alglib optimizer: fi[0]/jac[0,*] is the SSE objective and its gradient,
        // fi[k+1]/jac[k+1,*] the k-th inequality constraint c_k(theta) <= 0 and its gradient.
        // The x argument of this callback represents our theta.
        void CalculateObjectiveAndJacobian(double[] x, double[] fi, double[,] jac, object obj) {
          UpdateThetaValues(x);

          var autoDiffEval = new VectorAutoDiffEvaluator();
          autoDiffEval.Evaluate(preparedTree, problemData.Dataset, trainingRows, modelParameterNodes,
            out double[] fiEval, out double[,] jacEval);

          // sum of squared errors and its gradient w.r.t. theta
          var sse = 0.0;
          var gradient = new double[x.Length];
          for (int i = 0; i < target.Length; i++) {
            var res = target[i] - fiEval[i];
            sse += res * res;
            for (int j = 0; j < gradient.Length; j++) {
              gradient[j] += -2.0 * res * jacEval[i, j];
            }
          }

          fi[0] = sse;
          for (int j = 0; j < x.Length; j++) { jac[0, j] = gradient[j]; }

          var intervalEvaluator = new IntervalEvaluator();
          for (int i = 0; i < constraintTrees.Count; i++) {
            var interval = intervalEvaluator.Evaluate(constraintTrees[i], dataIntervals, constraintParameterNodes[i],
              out double[] _, out double[] upperGradient);

            // every constraint was transformed to c(x) <= 0, so only the upper bound is relevant
            fi[i + 1] = interval.UpperBound;
            for (int j = 0; j < x.Length; j++) {
              jac[i + 1, j] = upperGradient[j];
            }
          }
        }

        alglib.minnlccreate(x0.Length, x0, out alglib.minnlcstate state);
        alglib.minnlcsetalgoslp(state);
        alglib.minnlcsetcond(state, stepTolerance, maxIterations);
        alglib.minnlcsetscale(state, Enumerable.Repeat(1d, x0.Length).ToArray()); // unit scale for all thetas
        // 0 nonlinear equality constraints, one inequality constraint per constraint tree
        alglib.minnlcsetnlc(state, 0, constraintTrees.Count);

        alglib.minnlcoptimize(state, CalculateObjectiveAndJacobian, null, null);
        alglib.minnlcresults(state, out double[] xOpt, out alglib.minnlcreport _);
        UpdateThetaValues(xOpt);
      }

      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      var model = new SymbolicRegressionModel(problemData.TargetVariable, (ISymbolicExpressionTree)preparedTree.Clone(), interpreter);
      // NOTE(review): scaling happens after the constrained fit, so the scaled model is not guaranteed
      // to satisfy the interval constraints (e.g. a negative slope flips monotonicity) — confirm intent.
      if (applyLinearScaling)
        model.Scale(problemData);

      SymbolicRegressionSolution solution = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
      solution.Model.Name = "Regression Model";
      solution.Name = "Regression Solution";
      return solution;
    }

    /// <summary>
    /// The infix parser creates a VariableTreeNode for double and string variables alike. This replaces
    /// all nodes that refer to string variables of the dataset with FactorVariableTreeNodes (unit weight
    /// per factor value) and validates the weight count of explicitly written factor variables.
    /// </summary>
    private static void ConvertStringVariablesToFactorVariables(ISymbolicExpressionTree tree, RegressionProblemData problemData) {
      var factorSymbol = new FactorVariable();
      factorSymbol.VariableNames =
        problemData.AllowedInputVariables.Where(name => problemData.Dataset.VariableHasType<string>(name));
      factorSymbol.AllVariableNames = factorSymbol.VariableNames;
      factorSymbol.VariableValues =
        factorSymbol.VariableNames.Select(name =>
        new KeyValuePair<string, Dictionary<string, int>>(name,
        problemData.Dataset.GetReadOnlyStringValues(name).Distinct()
        .Select((n, i) => Tuple.Create(n, i))
        .ToDictionary(tup => tup.Item1, tup => tup.Item2)));

      foreach (var parent in tree.IterateNodesPrefix().ToArray()) {
        for (int i = 0; i < parent.SubtreeCount; i++) {
          var child = parent.GetSubtree(i);
          if (child is VariableTreeNode varChild && factorSymbol.VariableNames.Contains(varChild.VariableName)) {
            parent.RemoveSubtree(i);
            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
            factorTreeNode.VariableName = varChild.VariableName;
            // weight = 1.0 for each value
            factorTreeNode.Weights =
              factorTreeNode.Symbol.GetVariableValues(factorTreeNode.VariableName).Select(_ => 1.0).ToArray();
            parent.InsertSubtree(i, factorTreeNode);
          } else if (child is FactorVariableTreeNode factorVarChild && factorSymbol.VariableNames.Contains(factorVarChild.VariableName)) {
            var valueCount = factorSymbol.GetVariableValues(factorVarChild.VariableName).Count();
            if (valueCount != factorVarChild.Weights.Length)
              throw new ArgumentException(
                string.Format("Factor variable {0} needs exactly {1} weights",
                factorVarChild.VariableName,
                valueCount));
            parent.RemoveSubtree(i);
            var factorTreeNode = (FactorVariableTreeNode)factorSymbol.CreateTreeNode();
            factorTreeNode.VariableName = factorVarChild.VariableName;
            factorTreeNode.Weights = factorVarChild.Weights;
            parent.InsertSubtree(i, factorTreeNode);
          }
        }
      }
    }

    /// <summary>
    /// Appends the constraint trees for lowerBound &lt;= f &lt;= upperBound in the form required by
    /// alglib (c(theta) &lt;= 0): "f - upper" for a finite upper bound, then "lower - f" for a finite lower
    /// bound. Theta variables are converted back to the shared constant nodes.
    /// </summary>
    private static void AddBoundConstraintTrees(ISymbolicExpressionTree f, double lowerBound, double upperBound,
      List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes,
      List<ISymbolicExpressionTree> constraintTrees) {
      if (upperBound < double.PositiveInfinity) {
        var fMinusUpper = Subtract((ISymbolicExpressionTree)f.Clone(), CreateConstant(upperBound));
        constraintTrees.Add(ReplaceVarWithConst(fMinusUpper, thetaNames, thetaValues, thetaNodes));
      }
      if (lowerBound > double.NegativeInfinity) {
        var lowerMinusF = Subtract(CreateConstant(lowerBound), (ISymbolicExpressionTree)f.Clone());
        constraintTrees.Add(ReplaceVarWithConst(lowerMinusF, thetaNames, thetaValues, thetaNodes));
      }
    }

    /// <summary>
    /// Returns, for every theta index, the shared constant node of that theta if it occurs in
    /// <paramref name="tree"/>, otherwise null (same order as <paramref name="allNodes"/>).
    /// </summary>
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
      // TODO better solution necessary
      var treeConstNodes = new HashSet<ConstantTreeNode>(tree.IterateNodesPostfix().OfType<ConstantTreeNode>());
      var paramNodes = new ISymbolicExpressionTreeNode[allNodes.Length];
      for (int i = 0; i < paramNodes.Length; i++) {
        paramNodes[i] = allNodes[i].SingleOrDefault(n => treeConstNodes.Contains(n));
      }
      return paramNodes;
    }

    #endregion

    #region helper
    /// <summary>
    /// Returns a copy of <paramref name="tree"/> in which every variable named like a theta is replaced by
    /// a constant node. The first replacement of theta_i creates the node (value thetaValues[i]) and records
    /// it in thetaNodes[i]; all later replacements — in this or any other tree — reuse that same instance.
    /// </summary>
    private static ISymbolicExpressionTree ReplaceVarWithConst(ISymbolicExpressionTree tree, List<string> thetaNames, List<double> thetaValues, List<ConstantTreeNode>[] thetaNodes) {
      var copy = (ISymbolicExpressionTree)tree.Clone();
      foreach (var varNode in copy.IterateNodesPostfix().OfType<VariableTreeNode>().ToList()) {
        var thetaIdx = thetaNames.IndexOf(varNode.VariableName);
        if (thetaIdx < 0) continue;

        ConstantTreeNode constNode;
        if (thetaNodes[thetaIdx].Count > 0) {
          // HACKY: REUSE CONSTANT TREE NODE IN SEVERAL TREES
          // this allows autodiff over thetas when thetas occur multiple times (e.g. in derived trees);
          // inserting the node re-parents it, so its Parent points into the last tree it was inserted into
          constNode = thetaNodes[thetaIdx][0];
        } else {
          constNode = (ConstantTreeNode)CreateConstant(thetaValues[thetaIdx]);
          thetaNodes[thetaIdx].Add(constNode);
        }

        var parent = varNode.Parent;
        var childIdx = parent.IndexOfSubtree(varNode);
        parent.RemoveSubtree(childIdx);
        parent.InsertSubtree(childIdx, constNode);
      }
      return copy;
    }

    /// <summary>
    /// Returns a copy of <paramref name="tree"/> in which every numeric constant is replaced by a variable
    /// named θ1, θ2, ... (postfix order). The names and original values are returned in the same order.
    /// Variable weights are not turned into parameters; only numeric constants are tuned.
    /// </summary>
    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
      thetaNames = new List<string>();
      thetaValues = new List<double>();
      var copy = (ISymbolicExpressionTree)tree.Clone();

      int n = 1;
      foreach (var constantTreeNode in copy.IterateNodesPostfix().OfType<ConstantTreeNode>().ToList()) {
        var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
        thetaVar.Weight = 1;
        thetaVar.VariableName = $"θ{n++}";

        thetaNames.Add(thetaVar.VariableName);
        thetaValues.Add(constantTreeNode.Value);

        var parent = constantTreeNode.Parent;
        if (parent != null) {
          var index = parent.IndexOfSubtree(constantTreeNode);
          parent.RemoveSubtree(index);
          parent.InsertSubtree(index, thetaVar);
        }
      }
      return copy;
    }

    /// <summary>Creates a new constant node with the given value.</summary>
    private static ISymbolicExpressionTreeNode CreateConstant(double value) {
      var constantNode = (ConstantTreeNode)new Constant().CreateTreeNode();
      constantNode.Value = value;
      return constantNode;
    }

    /// <summary>
    /// Rewrites <paramref name="t"/> in place to compute "t - b" (wrapping the first subtree of the
    /// start node) and returns it.
    /// </summary>
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTree t, ISymbolicExpressionTreeNode b) {
      var sub = MakeNode<Subtraction>(t.Root.GetSubtree(0).GetSubtree(0), b);
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }

    /// <summary>
    /// Rewrites <paramref name="t"/> in place to compute "b - t" (wrapping the first subtree of the
    /// start node) and returns it.
    /// </summary>
    private static ISymbolicExpressionTree Subtract(ISymbolicExpressionTreeNode b, ISymbolicExpressionTree t) {
      var sub = MakeNode<Subtraction>(b, t.Root.GetSubtree(0).GetSubtree(0));
      t.Root.GetSubtree(0).RemoveSubtree(0);
      t.Root.GetSubtree(0).InsertSubtree(0, sub);
      return t;
    }

    /// <summary>Creates a node of symbol type <typeparamref name="T"/> with the given subtrees in order.</summary>
    private static ISymbolicExpressionTreeNode MakeNode<T>(params ISymbolicExpressionTreeNode[] fs) where T : ISymbol, new() {
      var node = new T().CreateTreeNode();
      foreach (var f in fs) node.AddSubtree(f);
      return node;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.