[17197] | 1 | using System;
|
---|
| 2 | using HeuristicLab.Optimization;
|
---|
| 3 | using HEAL.Attic;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using System.Threading;
|
---|
| 6 | using HeuristicLab.Core;
|
---|
| 7 | using HeuristicLab.Data;
|
---|
| 8 | using HeuristicLab.Parameters;
|
---|
| 9 | using System.Linq;
|
---|
| 10 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 11 | using HeuristicLab.Analysis;
|
---|
[17325] | 12 | using System.Collections.Generic;
|
---|
[17197] | 13 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Fits the numeric constants of a user-specified model structure (an infix expression,
  /// see <see cref="ModelStructure"/>) to the data of a regression problem using the solver
  /// selected in <see cref="SolverParameter"/>. The optimization itself is delegated to
  /// <c>ConstrainedNLSInternal</c>; this class only wires up parameters and reports progress
  /// (errors, constraint values, parameter values) into <c>Results</c>.
  /// </summary>
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";

    // Names of parameters that are not exposed as public constants (values must stay stable
    // because they are the keys under which parameters are persisted and looked up).
    private const string FuncToleranceRelParameterName = "FuncToleranceRel";
    private const string FuncToleranceAbsParameterName = "FuncToleranceAbs";
    private const string MaxTimeParameterName = "MaxTime";
    private const string CheckGradientParameterName = "CheckGradient";

    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter { get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; } }

    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceRelParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceAbsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MaxTimeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CheckGradientParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckGradientParameterName]; }
    }

    /// <summary>Iteration limit passed to the solver (0 indicates other or default stopping criterion).</summary>
    public int Iterations { get { return IterationsParameter.Value.Value; } set { IterationsParameter.Value.Value = value; } }

    /// <summary>
    /// The currently selected solver name. Setting is not supported; select a solver via
    /// <see cref="SolverParameter"/> instead.
    /// </summary>
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set { throw new NotImplementedException(); }
    }

    /// <summary>Infix expression whose numeric constants are fitted to the data.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>If true, the solver state is asked to check the gradient numerically (hidden parameter).</summary>
    public bool CheckGradient {
      get { return CheckGradientParameter.Value.Value; }
      set { CheckGradientParameter.Value.Value = value; }
    }

    // NOTE(review): the three values below are passed straight to ConstrainedNLSInternal;
    // presumably relative/absolute function-value tolerances and a time limit in seconds
    // (cf. NLopt ftol_rel / ftol_abs / maxtime) — confirm there.
    public double FuncToleranceRel { get { return FuncToleranceRelParameter.Value.Value; } set { FuncToleranceRelParameter.Value.Value = value; } }
    public double FuncToleranceAbs { get { return FuncToleranceAbsParameter.Value.Value; } set { FuncToleranceAbsParameter.Value.Value = value; } }
    public double MaxTime { get { return MaxTimeParameter.Value.Value; } set { MaxTimeParameter.Value.Value = value; } }

    /// <summary>Creates the algorithm with an empty regression problem and default parameter values.</summary>
    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      var validSolvers = new ItemSet<StringValue>(new[] {
        "MMA",
        "COBYLA",
        "CCSAQ",
        "ISRES",
        "DIRECT_G",
        "NLOPT_GN_DIRECT_L",
        "NLOPT_GN_DIRECT_L_RAND",
        "NLOPT_GN_ORIG_DIRECT",
        "NLOPT_GN_ORIG_DIRECT_L",
        "NLOPT_GD_STOGO",
        "NLOPT_GD_STOGO_RAND",
        "NLOPT_LD_LBFGS_NOCEDAL",
        "NLOPT_LD_LBFGS",
        "NLOPT_LN_PRAXIS",
        "NLOPT_LD_VAR1",
        "NLOPT_LD_VAR2",
        "NLOPT_LD_TNEWTON",
        "NLOPT_LD_TNEWTON_RESTART",
        "NLOPT_LD_TNEWTON_PRECOND",
        "NLOPT_LD_TNEWTON_PRECOND_RESTART",
        "NLOPT_GN_CRS2_LM",
        "NLOPT_GN_MLSL",
        "NLOPT_GD_MLSL",
        "NLOPT_GN_MLSL_LDS",
        "NLOPT_GD_MLSL_LDS",
        "NLOPT_LN_NEWUOA",
        "NLOPT_LN_NEWUOA_BOUND",
        "NLOPT_LN_NELDERMEAD",
        "NLOPT_LN_SBPLX",
        "NLOPT_LN_AUGLAG",
        "NLOPT_LD_AUGLAG",
        "NLOPT_LN_BOBYQA",
        "NLOPT_AUGLAG",
        "NLOPT_LD_SLSQP",
        "NLOPT_LD_CCSAQ",
        "NLOPT_GN_ESCH",
        "NLOPT_GN_AGS",
      }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceRelParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceAbsParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MaxTimeParameterName, new DoubleValue(10)));
      AddCheckGradientParameter();
    }

    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    /// <summary>
    /// Backwards compatibility: instances persisted before the CheckGradient parameter
    /// existed do not contain it, so it is added with its default value.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    public void AfterDeserializationHook() {
      if (!Parameters.ContainsKey(CheckGradientParameterName)) {
        AddCheckGradientParameter();
      }
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    // Single place that creates the (hidden) CheckGradient parameter; shared by the
    // default constructor and the after-deserialization hook.
    private void AddCheckGradientParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckGradientParameterName, "Flag to indicate whether the gradient should be checked using numeric approximation", new BoolValue(false)));
      CheckGradientParameter.Hidden = true;
    }

    /// <summary>
    /// Parses <see cref="ModelStructure"/>, sets up the result entries, runs the solver and
    /// finally publishes the best solution and its constraint values.
    /// Cancellation is forwarded to the solver via <c>RequestStop</c> on the next function evaluation.
    /// </summary>
    protected override void Run(CancellationToken cancellationToken) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);
      var problem = (IRegressionProblem)Problem;

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      var curError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Current error", curError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      // One row per finite bound of each enabled constraint. Lower bounds are expressed as
      // "-expr < -lb" so that every row has the form "g(x) < c".
      // NOTE(review): State_ConstraintEvaluated indexes this list with the solver's constraint
      // index, which assumes ConstrainedNLSInternal enumerates constraints in the same order.
      var constraintRows = new List<IndexedDataRow<int>>(); // for access via index
      var constraintsTable = new IndexedDataTable<int>("Constraints");
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      foreach (var constraint in problem.ProblemData.IntervalConstraints.Constraints.Where(c => c.Enabled)) {
        if (constraint.Interval.LowerBound > double.NegativeInfinity) {
          var constraintRow = new IndexedDataRow<int>("-" + constraint.Expression + " < " + (-constraint.Interval.LowerBound));
          constraintRows.Add(constraintRow);
          constraintsTable.Rows.Add(constraintRow);
        }
        if (constraint.Interval.UpperBound < double.PositiveInfinity) {
          var constraintRow = new IndexedDataRow<int>(constraint.Expression + " < " + (constraint.Interval.UpperBound));
          constraintRows.Add(constraintRow);
          constraintsTable.Rows.Add(constraintRow);
        }
      }

      // Trace of the best parameter vector per evaluation; rows are created lazily on the first
      // evaluation because the number of parameters is only known from the solver state.
      var parameterRows = new List<IndexedDataRow<int>>(); // for access via index
      var parametersTable = new IndexedDataTable<int>("Parameters");
      Results.AddOrUpdateResult("Parameters", parametersTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);
      if (CheckGradient) state.CheckGradient = true;
      int idx = 0;
      var formatter = new InfixExpressionFormatter();
      var constraintDescriptions = state.ConstraintDescriptions.ToArray();
      foreach (var constraintTree in state.constraintTrees) {
        // HACK to remove parameter nodes which occurr multiple times
        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
        Results.AddOrUpdateResult($"{constraintDescriptions[idx++]}", reparsedTree);
      }

      // we use a listener model here to get state from the solver
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;

      state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParametersAndKeepLinearScaling);
      bestError.Value = state.BestError;
      curError.Value = state.CurError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      Results.AddOrUpdateResult("Best solution", CreateSolution((ISymbolicExpressionTree)state.BestTree.Clone(), problem.ProblemData));
      var bestConstraintValues = new DoubleArray(state.BestConstraintValues);
      bestConstraintValues.ElementNames = constraintDescriptions;
      Results.AddOrUpdateResult("Best solution constraint values", bestConstraintValues);

      // local function: invoked by the solver state after each objective evaluation
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curError.Value = state.CurError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);

        var bestSolution = state.BestSolution;
        // on the first call create the data rows
        if (parameterRows.Count == 0) {
          for (int i = 0; i < bestSolution.Length; i++) {
            var parameterRow = new IndexedDataRow<int>("p" + i);
            parameterRows.Add(parameterRow);
            parametersTable.Rows.Add(parameterRow);
          }
        }
        for (int i = 0; i < bestSolution.Length; i++) {
          parameterRows[i].Values.Add(Tuple.Create(functionEvaluations.Value, bestSolution[i]));
        }
      }

      // local function: invoked by the solver state with the value of constraint constraintIdx
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        constraintRows[constraintIdx].Values.Add(Tuple.Create(functionEvaluations.Value, value));
      }
    }

    /// <summary>
    /// Wraps <paramref name="tree"/> into a symbolic regression solution for (a clone of) the given problem data.
    /// </summary>
    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals ==> use SymbolicRegressionSolution.ctor instead
      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
      // NOTE: the solution has slightly different derivative values because simplification of derivatives can be done differently when parameter values are fixed.
      return sol;
    }
  }
}
|
---|