1 | using System;
2 | using HeuristicLab.Optimization;
3 | using HEAL.Attic;
4 | using HeuristicLab.Common;
5 | using System.Threading;
6 | using HeuristicLab.Core;
7 | using HeuristicLab.Data;
8 | using HeuristicLab.Parameters;
9 | using System.Linq;
10 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
11 | using HeuristicLab.Analysis;
12 |
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Fits the numeric constants of a user-supplied model structure (an infix expression) to the
  /// data of a regression problem using a selectable solver. The solver state
  /// (<c>ConstrainedNLSInternal</c>) also evaluates constraints, which are published as results.
  /// </summary>
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";
    public static readonly string FuncToleranceRelParameterName = "FuncToleranceRel";
    public static readonly string FuncToleranceAbsParameterName = "FuncToleranceAbs";
    public static readonly string MaxTimeParameterName = "MaxTime";
    public static readonly string CheckGradientParameterName = "CheckGradient";

    // Shared by the constructor and the deserialization hook so both create an identical parameter.
    private const string CheckGradientParameterDescription = "Flag to indicate whether the gradient should be checked using numeric approximation";

    #region parameter properties
    /// <summary>Iteration limit handed to the solver state (0 = other/default stopping criterion).</summary>
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    /// <summary>Infix expression whose numeric constants are fitted.</summary>
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    /// <summary>Name of the solver algorithm; restricted to the list created in the constructor.</summary>
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; }
    }
    /// <summary>Relative function tolerance handed to the solver state (default 0).</summary>
    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceRelParameterName]; }
    }
    /// <summary>Absolute function tolerance handed to the solver state (default 0).</summary>
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceAbsParameterName]; }
    }
    /// <summary>Time limit handed to the solver state (default 10; unit presumably seconds — TODO confirm).</summary>
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MaxTimeParameterName]; }
    }
    /// <summary>Hidden flag that enables the solver state's gradient check.</summary>
    public IFixedValueParameter<BoolValue> CheckGradientParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckGradientParameterName]; }
    }
    #endregion

    #region convenience properties
    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }

    /// <summary>
    /// The currently selected solver. The setter accepts any <see cref="StringValue"/> whose text
    /// equals (ordinal) one of the valid solver names and selects the corresponding valid value.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    /// <exception cref="ArgumentException">The assigned name is not a valid solver.</exception>
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set {
        if (value == null) throw new ArgumentNullException(nameof(value));
        // The valid values are read-only instances owned by the parameter, so look the requested
        // solver up by name instead of requiring callers to pass one of those exact instances.
        var match = SolverParameter.ValidValues.FirstOrDefault(v => string.Equals(v.Value, value.Value, StringComparison.Ordinal));
        if (match == null) throw new ArgumentException($"Unknown solver '{value.Value}'.", nameof(value));
        SolverParameter.Value = match;
      }
    }

    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }
    public bool CheckGradient {
      get { return CheckGradientParameter.Value.Value; }
      set { CheckGradientParameter.Value.Value = value; }
    }

    public double FuncToleranceRel {
      get { return FuncToleranceRelParameter.Value.Value; }
      set { FuncToleranceRelParameter.Value.Value = value; }
    }
    public double FuncToleranceAbs {
      get { return FuncToleranceAbsParameter.Value.Value; }
      set { FuncToleranceAbsParameter.Value.Value = value; }
    }
    public double MaxTime {
      get { return MaxTimeParameter.Value.Value; }
      set { MaxTimeParameter.Value.Value = value; }
    }
    #endregion

    /// <summary>Creates the algorithm with a new <see cref="RegressionProblem"/> and default parameters.</summary>
    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      var validSolvers = new ItemSet<StringValue>(new[] {
        "MMA",
        "COBYLA",
        "CCSAQ",
        "ISRES",
        "DIRECT_G",
        "NLOPT_GD_STOGO",
        "NLOPT_LD_LBFGS",
        "NLOPT_LD_VAR1",
        "NLOPT_LD_VAR2",
        "NLOPT_GN_CRS2_LM",
        "NLOPT_GN_MLSL",
        "NLOPT_GD_MLSL",
        "NLOPT_LN_SBPLX",
        "NLOPT_LN_BOBYQA",
        "NLOPT_AUGLAG",
        "NLOPT_LD_SLSQP",
        "NLOPT_LD_CCSAQ",
        "NLOPT_GN_ESCH",
        "NLOPT_GN_AGS",
      }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceRelParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceAbsParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MaxTimeParameterName, new DoubleValue(10)));
      AddCheckGradientParameter();
    }

    /// <summary>Cloning constructor; this class holds no state beyond its parameters.</summary>
    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    /// <summary>Adds the <c>CheckGradient</c> parameter when it is missing from an older persisted file.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    public void AfterDeserializationHook() {
      if (!Parameters.ContainsKey(CheckGradientParameterName)) {
        AddCheckGradientParameter();
      }
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    /// <summary>Adds the hidden <c>CheckGradient</c> parameter (default false).</summary>
    private void AddCheckGradientParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckGradientParameterName, CheckGradientParameterDescription, new BoolValue(false)));
      CheckGradientParameter.Hidden = true;
    }

    /// <summary>
    /// Parses <see cref="ModelStructure"/>, publishes the result objects, runs the solver and
    /// stores the best solution found.
    /// </summary>
    /// <param name="cancellationToken">
    /// Checked on every function evaluation; when cancellation is requested the solver is asked to
    /// stop via <c>RequestStop</c>.
    /// </param>
    protected override void Run(CancellationToken cancellationToken) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);
      var problem = (IRegressionProblem)Problem;

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      var curError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Current error", curError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      var curConstraintValue = new DoubleValue(0);
      Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
      var curConstraintIdx = new IntValue(0);
      Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);

      var curConstraintRow = new DataRow("Constraint Value");
      var constraintsTable = new DataTable("Constraints");

      constraintsTable.Rows.Add(curConstraintRow);
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);
      if (CheckGradient) {
        state.CheckGradient = true;
      }

      // Publish each constraint as a result. HACK: round-trip through the infix formatter and
      // parser to remove parameter nodes which occur multiple times.
      var formatter = new InfixExpressionFormatter();
      int idx = 0;
      foreach (var constraintTree in state.constraintTrees) {
        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
        Results.AddOrUpdateResult($"Constraint {idx++}", reparsedTree);
      }

      // The solver reports progress through events (listener model). The handlers are detached in
      // `finally` so the subscription is scoped to this run, even if Optimize throws.
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;
      try {
        state.Optimize();
      } finally {
        state.FunctionEvaluated -= State_FunctionEvaluated;
        state.ConstraintEvaluated -= State_ConstraintEvaluated;
      }

      bestError.Value = state.BestError;
      curError.Value = state.CurError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      // NOTE(review): guards against a run that stops before a best tree / constraint values are
      // recorded (e.g. immediate cancellation); assumes the state leaves them null then — confirm.
      if (state.BestTree != null) {
        Results.AddOrUpdateResult("Best solution", CreateSolution((ISymbolicExpressionTree)state.BestTree.Clone(), problem.ProblemData));
      }
      if (state.BestConstraintValues != null) {
        Results.AddOrUpdateResult("Best solution constraint values", new DoubleArray(state.BestConstraintValues));
      }

      // Called by the solver after each objective evaluation: forwards cancellation and records progress.
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curError.Value = state.CurError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);
      }

      // Called by the solver after each constraint evaluation: records the index and value.
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        curConstraintIdx.Value = constraintIdx;
        curConstraintValue.Value = value;
        curConstraintRow.Values.Add(value);
      }
    }

    /// <summary>
    /// Wraps a fitted tree in a symbolic regression model evaluated by the linear interpreter and
    /// creates a solution on a clone of <paramref name="problemData"/>.
    /// </summary>
    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // Linear scaling (model.Scale) is deliberately not applied: the tree is used exactly as fitted.
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }
  }
}