Changeset 17176


Ignore:
Timestamp:
07/28/19 19:48:12 (4 weeks ago)
Author:
gkronber
Message:

#2994 linear scaling for const opt with constraints

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedConstantOptimizationEvaluator.cs

    r17136 r17176  
    223223      if (!updateVariableWeights) throw new NotSupportedException("not updating variable weights is not supported");
    224224      if (!updateConstantsInTree) throw new NotSupportedException("not updating tree parameters is not supported");
    225       if (applyLinearScaling) throw new NotSupportedException("linear scaling is not supported");
     225      if (!applyLinearScaling) throw new NotSupportedException("application without linear scaling is not supported");
    226226
    227227      // we always update constants, so we don't need to calculate initial quality
     
    234234      var dataIntervals = problemData.VariableRanges.GetIntervals();
    235235
     236      // buffers
     237      var target = problemData.TargetVariableTrainingValues.ToArray();
     238      var targetStDev = target.StandardDeviationPop();
     239      var targetVariance = targetStDev * targetStDev;
     240      var targetMean = target.Average();
     241      var pred = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TrainingIndices).ToArray();
     242      var predStDev = pred.StandardDeviationPop();
     243      var predMean = pred.Average();
     244
     245      var scalingFactor = targetStDev / predStDev;
     246      var offset = targetMean - predMean * scalingFactor;
     247
     248      ISymbolicExpressionTree scaledTree = null;
     249      if (applyLinearScaling) scaledTree = CopyAndScaleTree(tree, scalingFactor, offset);
     250
    236251      // convert constants to variables named theta...
    237       var treeForDerivation = ReplaceConstWithVar(tree, out List<string> thetaNames, out List<double> thetaValues); // copies the tree
     252      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out List<double> thetaValues); // copies the tree
    238253
    239254      // create trees for relevant derivatives
     
    288303      }
    289304
    290       // buffers for calculate_jacobian
    291       var target = problemData.TargetVariableTrainingValues.ToArray();
    292       var targetVariance = target.VariancePop();
    293305      var fi_eval = new double[target.Length];
    294306      var jac_eval = new double[target.Length, thetaValues.Count];
     
    337349          alglib.minnscreate(thetaValues.Count, thetaValues.ToArray(), out state);
    338350          alglib.minnssetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray());
    339           alglib.minnssetcond(state, 1E-7, maxIterations);
     351          alglib.minnssetcond(state, 0, maxIterations);
    340352          var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray();  // scale is set to unit scale
    341353          alglib.minnssetscale(state, s);
     
    352364
    353365          if (rep.terminationtype > 0) {
     366            // update parameters in tree
     367            var pIdx = 0;
     368            // here we lose the two last parameters (for linear scaling)
     369            foreach (var node in tree.IterateNodesPostfix()) {
     370              if (node is ConstantTreeNode constTreeNode) {
     371                constTreeNode.Value = xOpt[pIdx++];
     372              } else if (node is VariableTreeNode varTreeNode) {
     373                varTreeNode.Weight = xOpt[pIdx++];
     374              }
     375            }
     376            // note: we keep the optimized constants even when the tree is worse.
     377            // assert that we lose the last two parameters
     378            if (pIdx != xOpt.Length - 2) throw new InvalidProgramException();
     379          }
     380          if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated
     381        } catch (ArithmeticException) {
     382          return targetVariance;
     383        } catch (alglib.alglibexception) {
     384          // eval MSE of original tree
     385          return targetVariance;
     386        }
     387      } else if (solver.Contains("minnlc")) {
     388        alglib.minnlcstate state;
     389        alglib.minnlcreport rep;
     390        alglib.optguardreport optGuardRep;
     391        try {
     392          alglib.minnlccreate(thetaValues.Count, thetaValues.ToArray(), out state);
     393          alglib.minnlcsetalgoslp(state);        // SLP is more robust but slower
     394          alglib.minnlcsetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray());
     395          alglib.minnlcsetcond(state, 0, maxIterations);
     396          var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray();  // scale is set to unit scale
     397          alglib.minnlcsetscale(state, s);
     398
     399          // set non-linear constraints: 0 equality constraints, constraintTrees inequality constraints
     400          alglib.minnlcsetnlc(state, 0, constraintTrees.Count);
     401          alglib.minnlcoptguardsmoothness(state, 1);
     402
     403          alglib.minnlcoptimize(state, calculate_jacobian, null, null);
     404          alglib.minnlcresults(state, out double[] xOpt, out rep);
     405          alglib.minnlcoptguardresults(state, out optGuardRep);
     406          if (optGuardRep.nonc0suspected) throw new InvalidProgramException("optGuardRep.nonc0suspected");
     407          if (optGuardRep.nonc1suspected) {
     408            alglib.minnlcoptguardnonc1test1results(state, out alglib.optguardnonc1test1report strrep, out alglib.optguardnonc1test1report lngrep);
     409            throw new InvalidProgramException("optGuardRep.nonc1suspected");
     410          }
     411
     412          // counter.FunctionEvaluations += rep.nfev; TODO
     413          counter.GradientEvaluations += rep.nfev;
     414
     415          if (rep.terminationtype != -8) {
    354416            // update parameters in tree
    355417            var pIdx = 0;
     
    362424            }
    363425            // note: we keep the optimized constants even when the tree is worse.
    364           }
    365           if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated
    366         } catch (ArithmeticException) {
    367           return targetVariance;
    368         } catch (alglib.alglibexception) {
    369           // eval MSE of original tree
    370           return targetVariance;
    371         }
    372       } else if (solver.Contains("minnlc")) {
    373         alglib.minnlcstate state;
    374         alglib.minnlcreport rep;
    375         alglib.optguardreport optGuardRep;
    376         try {
    377           alglib.minnlccreate(thetaValues.Count, thetaValues.ToArray(), out state);
    378           alglib.minnlcsetalgoslp(state);        // SLP is more robust but slower
    379           alglib.minnlcsetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray());
    380           alglib.minnlcsetcond(state, 1E-7, maxIterations);
    381           var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray();  // scale is set to unit scale
    382           alglib.minnlcsetscale(state, s);
    383 
    384           // set non-linear constraints: 0 equality constraints, constraintTrees inequality constraints
    385           alglib.minnlcsetnlc(state, 0, constraintTrees.Count);
    386           alglib.minnlcoptguardsmoothness(state, 1);
    387 
    388           alglib.minnlcoptimize(state, calculate_jacobian, null, null);
    389           alglib.minnlcresults(state, out double[] xOpt, out rep);
    390           alglib.minnlcoptguardresults(state, out optGuardRep);
    391           if (optGuardRep.nonc0suspected) throw new InvalidProgramException("optGuardRep.nonc0suspected");
    392           if (optGuardRep.nonc1suspected) throw new InvalidProgramException("optGuardRep.nonc1suspected");
    393 
    394           // counter.FunctionEvaluations += rep.nfev; TODO
    395           counter.GradientEvaluations += rep.nfev;
    396 
    397           if (rep.terminationtype != -8) {
    398             // update parameters in tree
    399             var pIdx = 0;
    400             foreach (var node in tree.IterateNodesPostfix()) {
    401               if (node is ConstantTreeNode constTreeNode) {
    402                 constTreeNode.Value = xOpt[pIdx++];
    403               } else if (node is VariableTreeNode varTreeNode) {
    404                 varTreeNode.Weight = xOpt[pIdx++];
    405               }
    406             }
    407 
    408             // note: we keep the optimized constants even when the tree is worse.
     426            // assert that we lose the last two parameters
     427            if (pIdx != xOpt.Length - 2) throw new InvalidProgramException();
     428
    409429          }
    410430          if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated
     
    421441
    422442      // evaluate tree with updated constants
    423       var residualVariance = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling: false);
     443      var residualVariance = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, scaledTree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling: false);
    424444      return Math.Min(residualVariance, targetVariance);
     445    }
     446
     447    private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) {
     448      var m = (ISymbolicExpressionTree)tree.Clone();
     449
     450      var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset));
     451      m.Root.GetSubtree(0).RemoveSubtree(0);
     452      m.Root.GetSubtree(0).AddSubtree(add);
     453      return m;
    425454    }
    426455
Note: See TracChangeset for help on using the changeset viewer.