Changeset 17176
- Timestamp:
- 07/28/19 19:48:12 (5 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedConstantOptimizationEvaluator.cs
r17136 r17176 223 223 if (!updateVariableWeights) throw new NotSupportedException("not updating variable weights is not supported"); 224 224 if (!updateConstantsInTree) throw new NotSupportedException("not updating tree parameters is not supported"); 225 if ( applyLinearScaling) throw new NotSupportedException("linear scaling is not supported");225 if (!applyLinearScaling) throw new NotSupportedException("application without linear scaling is not supported"); 226 226 227 227 // we always update constants, so we don't need to calculate initial quality … … 234 234 var dataIntervals = problemData.VariableRanges.GetIntervals(); 235 235 236 // buffers 237 var target = problemData.TargetVariableTrainingValues.ToArray(); 238 var targetStDev = target.StandardDeviationPop(); 239 var targetVariance = targetStDev * targetStDev; 240 var targetMean = target.Average(); 241 var pred = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, problemData.TrainingIndices).ToArray(); 242 var predStDev = pred.StandardDeviationPop(); 243 var predMean = pred.Average(); 244 245 var scalingFactor = targetStDev / predStDev; 246 var offset = targetMean - predMean * scalingFactor; 247 248 ISymbolicExpressionTree scaledTree = null; 249 if (applyLinearScaling) scaledTree = CopyAndScaleTree(tree, scalingFactor, offset); 250 236 251 // convert constants to variables named theta... 237 var treeForDerivation = ReplaceConstWithVar( tree, out List<string> thetaNames, out List<double> thetaValues); // copies the tree252 var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out List<double> thetaValues); // copies the tree 238 253 239 254 // create trees for relevant derivatives … … 288 303 } 289 304 290 // buffers for calculate_jacobian291 var target = problemData.TargetVariableTrainingValues.ToArray();292 var targetVariance = target.VariancePop();293 305 var fi_eval = new double[target.Length]; 294 306 var jac_eval = new double[target.Length, thetaValues.Count]; … … 337 349 alglib.minnscreate(thetaValues.Count, thetaValues.ToArray(), out state); 338 350 alglib.minnssetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray()); 339 alglib.minnssetcond(state, 1E-7, maxIterations);351 alglib.minnssetcond(state, 0, maxIterations); 340 352 var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray(); // scale is set to unit scale 341 353 alglib.minnssetscale(state, s); … … 352 364 353 365 if (rep.terminationtype > 0) { 366 // update parameters in tree 367 var pIdx = 0; 368 // here we lose the two last parameters (for linear scaling) 369 foreach (var node in tree.IterateNodesPostfix()) { 370 if (node is ConstantTreeNode constTreeNode) { 371 constTreeNode.Value = xOpt[pIdx++]; 372 } else if (node is VariableTreeNode varTreeNode) { 373 varTreeNode.Weight = xOpt[pIdx++]; 374 } 375 } 376 // note: we keep the optimized constants even when the tree is worse. 377 // assert that we lose the last two parameters 378 if (pIdx != xOpt.Length - 2) throw new InvalidProgramException(); 379 } 380 if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated 381 } catch (ArithmeticException) { 382 return targetVariance; 383 } catch (alglib.alglibexception) { 384 // eval MSE of original tree 385 return targetVariance; 386 } 387 } else if (solver.Contains("minnlc")) { 388 alglib.minnlcstate state; 389 alglib.minnlcreport rep; 390 alglib.optguardreport optGuardRep; 391 try { 392 alglib.minnlccreate(thetaValues.Count, thetaValues.ToArray(), out state); 393 alglib.minnlcsetalgoslp(state); // SLP is more robust but slower 394 alglib.minnlcsetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray()); 395 alglib.minnlcsetcond(state, 0, maxIterations); 396 var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray(); // scale is set to unit scale 397 alglib.minnlcsetscale(state, s); 398 399 // set non-linear constraints: 0 equality constraints, constraintTrees inequality constraints 400 alglib.minnlcsetnlc(state, 0, constraintTrees.Count); 401 alglib.minnlcoptguardsmoothness(state, 1); 402 403 alglib.minnlcoptimize(state, calculate_jacobian, null, null); 404 alglib.minnlcresults(state, out double[] xOpt, out rep); 405 alglib.minnlcoptguardresults(state, out optGuardRep); 406 if (optGuardRep.nonc0suspected) throw new InvalidProgramException("optGuardRep.nonc0suspected"); 407 if (optGuardRep.nonc1suspected) { 408 alglib.minnlcoptguardnonc1test1results(state, out alglib.optguardnonc1test1report strrep, out alglib.optguardnonc1test1report lngrep); 409 throw new InvalidProgramException("optGuardRep.nonc1suspected"); 410 } 411 412 // counter.FunctionEvaluations += rep.nfev; TODO 413 counter.GradientEvaluations += rep.nfev; 414 415 if (rep.terminationtype != -8) { 354 416 // update parameters in tree 355 417 var pIdx = 0; … … 362 424 } 363 425 // note: we keep the optimized constants even when the tree is worse. 364 } 365 if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated 366 } catch (ArithmeticException) { 367 return targetVariance; 368 } catch (alglib.alglibexception) { 369 // eval MSE of original tree 370 return targetVariance; 371 } 372 } else if (solver.Contains("minnlc")) { 373 alglib.minnlcstate state; 374 alglib.minnlcreport rep; 375 alglib.optguardreport optGuardRep; 376 try { 377 alglib.minnlccreate(thetaValues.Count, thetaValues.ToArray(), out state); 378 alglib.minnlcsetalgoslp(state); // SLP is more robust but slower 379 alglib.minnlcsetbc(state, thetaValues.Select(_ => -10000.0).ToArray(), thetaValues.Select(_ => +10000.0).ToArray()); 380 alglib.minnlcsetcond(state, 1E-7, maxIterations); 381 var s = Enumerable.Repeat(1d, thetaValues.Count).ToArray(); // scale is set to unit scale 382 alglib.minnlcsetscale(state, s); 383 384 // set non-linear constraints: 0 equality constraints, constraintTrees inequality constraints 385 alglib.minnlcsetnlc(state, 0, constraintTrees.Count); 386 alglib.minnlcoptguardsmoothness(state, 1); 387 388 alglib.minnlcoptimize(state, calculate_jacobian, null, null); 389 alglib.minnlcresults(state, out double[] xOpt, out rep); 390 alglib.minnlcoptguardresults(state, out optGuardRep); 391 if (optGuardRep.nonc0suspected) throw new InvalidProgramException("optGuardRep.nonc0suspected"); 392 if (optGuardRep.nonc1suspected) throw new InvalidProgramException("optGuardRep.nonc1suspected"); 393 394 // counter.FunctionEvaluations += rep.nfev; TODO 395 counter.GradientEvaluations += rep.nfev; 396 397 if (rep.terminationtype != -8) { 398 // update parameters in tree 399 var pIdx = 0; 400 foreach (var node in tree.IterateNodesPostfix()) { 401 if (node is ConstantTreeNode constTreeNode) { 402 constTreeNode.Value = xOpt[pIdx++]; 403 } else if (node is VariableTreeNode varTreeNode) { 404 varTreeNode.Weight = xOpt[pIdx++]; 405 } 406 } 407 408 // note: we keep the optimized constants even when the tree is worse. 426 // assert that we lose the last two parameters 427 if (pIdx != xOpt.Length - 2) throw new InvalidProgramException(); 428 409 429 } 410 430 if (Math.Abs(rep.nlcerr) > 0.01) return targetVariance; // constraints are violated … … 421 441 422 442 // evaluate tree with updated constants 423 var residualVariance = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, tree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling: false);443 var residualVariance = SymbolicRegressionSingleObjectiveMeanSquaredErrorEvaluator.Calculate(interpreter, scaledTree, lowerEstimationLimit, upperEstimationLimit, problemData, rows, applyLinearScaling: false); 424 444 return Math.Min(residualVariance, targetVariance); 445 } 446 447 private static ISymbolicExpressionTree CopyAndScaleTree(ISymbolicExpressionTree tree, double scalingFactor, double offset) { 448 var m = (ISymbolicExpressionTree)tree.Clone(); 449 450 var add = MakeNode<Addition>(MakeNode<Multiplication>(m.Root.GetSubtree(0).GetSubtree(0), CreateConstant(scalingFactor)), CreateConstant(offset)); 451 m.Root.GetSubtree(0).RemoveSubtree(0); 452 m.Root.GetSubtree(0).AddSubtree(add); 453 return m; 425 454 } 426 455
Note: See TracChangeset
for help on using the changeset viewer.