Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLS.cs @ 17714

Last change on this file since 17714 was 17325, checked in by gkronber, 5 years ago

#2994: worked on ConstrainedNLS

File size: 11.4 KB
Line 
1using System;
2using HeuristicLab.Optimization;
3using HEAL.Attic;
4using HeuristicLab.Common;
5using System.Threading;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Parameters;
9using System.Linq;
10using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
11using HeuristicLab.Analysis;
12using System.Collections.Generic;
13
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Fits the numeric constants of a fixed model structure (given as an infix expression) to the data
  /// of a regression problem, taking the problem data's enabled interval constraints into account.
  /// The optimization itself is delegated to <c>ConstrainedNLSInternal</c>; this class only wires up
  /// parameters and publishes progress (errors, constraint values, parameter trajectories) as results.
  /// </summary>
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";
    public static readonly string FuncToleranceRelParameterName = "FuncToleranceRel";
    public static readonly string FuncToleranceAbsParameterName = "FuncToleranceAbs";
    public static readonly string MaxTimeParameterName = "MaxTime";
    public static readonly string CheckGradientParameterName = "CheckGradient";

    #region parameter properties
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceRelParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceAbsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MaxTimeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CheckGradientParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckGradientParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Iteration limit handed to the solver (0 indicates other or default stopping criterion).</summary>
    public int Iterations { get { return IterationsParameter.Value.Value; } set { IterationsParameter.Value.Value = value; } }

    /// <summary>
    /// The selected solver name. Setting is not supported; choose one of
    /// <see cref="SolverParameter"/>'s valid values instead.
    /// </summary>
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set { throw new NotImplementedException(); }
    }

    /// <summary>Infix expression of the model; only its numeric constants are tuned.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>If true, the gradient is checked against a numeric approximation.</summary>
    public bool CheckGradient {
      get { return CheckGradientParameter.Value.Value; }
      set { CheckGradientParameter.Value.Value = value; }
    }

    // The following three values are forwarded unchanged to ConstrainedNLSInternal.
    // NOTE(review): presumably solver stopping criteria (relative/absolute function tolerance,
    // time limit); the unit of MaxTime is not visible here — TODO confirm.
    public double FuncToleranceRel { get { return FuncToleranceRelParameter.Value.Value; } set { FuncToleranceRelParameter.Value.Value = value; } }
    public double FuncToleranceAbs { get { return FuncToleranceAbsParameter.Value.Value; } set { FuncToleranceAbsParameter.Value.Value = value; } }
    public double MaxTime { get { return MaxTimeParameter.Value.Value; } set { MaxTimeParameter.Value.Value = value; } }
    #endregion

    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      // NOTE(review): names appear to be NLopt algorithm identifiers; the mapping to the
      // actual solver happens inside ConstrainedNLSInternal — verify there.
      var validSolvers = new ItemSet<StringValue>(new[] {
        "MMA",
        "COBYLA",
        "CCSAQ",
        "ISRES",
        "DIRECT_G",
        "NLOPT_GN_DIRECT_L",
        "NLOPT_GN_DIRECT_L_RAND",
        "NLOPT_GN_ORIG_DIRECT",
        "NLOPT_GN_ORIG_DIRECT_L",
        "NLOPT_GD_STOGO",
        "NLOPT_GD_STOGO_RAND",
        "NLOPT_LD_LBFGS_NOCEDAL",
        "NLOPT_LD_LBFGS",
        "NLOPT_LN_PRAXIS",
        "NLOPT_LD_VAR1",
        "NLOPT_LD_VAR2",
        "NLOPT_LD_TNEWTON",
        "NLOPT_LD_TNEWTON_RESTART",
        "NLOPT_LD_TNEWTON_PRECOND",
        "NLOPT_LD_TNEWTON_PRECOND_RESTART",
        "NLOPT_GN_CRS2_LM",
        "NLOPT_GN_MLSL",
        "NLOPT_GD_MLSL",
        "NLOPT_GN_MLSL_LDS",
        "NLOPT_GD_MLSL_LDS",
        "NLOPT_LN_NEWUOA",
        "NLOPT_LN_NEWUOA_BOUND",
        "NLOPT_LN_NELDERMEAD",
        "NLOPT_LN_SBPLX",
        "NLOPT_LN_AUGLAG",
        "NLOPT_LD_AUGLAG",
        "NLOPT_LN_BOBYQA",
        "NLOPT_AUGLAG",
        "NLOPT_LD_SLSQP",
        "NLOPT_LD_CCSAQ",
        "NLOPT_GN_ESCH",
        "NLOPT_GN_AGS",
      }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceRelParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceAbsParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MaxTimeParameterName, new DoubleValue(10)));
      AddCheckGradientParameter();
    }

    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    /// <summary>Adds the CheckGradient parameter to files that were stored before it existed.</summary>
    [StorableHook(HookType.AfterDeserialization)]
    public void AfterDeserializationHook() {
      if (!Parameters.ContainsKey(CheckGradientParameterName)) {
        AddCheckGradientParameter();
      }
    }

    // Creates the (hidden) CheckGradient parameter; shared by the ctor and the deserialization hook.
    private void AddCheckGradientParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckGradientParameterName, "Flag to indicate whether the gradient should be checked using numeric approximation", new BoolValue(false)));
      CheckGradientParameter.Hidden = true;
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    /// <summary>
    /// Parses <see cref="ModelStructure"/>, runs the constrained optimization and publishes progress
    /// and the best solution found as results. Cancellation is forwarded to the solver via
    /// <c>RequestStop</c> on the next function evaluation.
    /// </summary>
    /// <exception cref="InvalidOperationException">If <c>Problem</c> is not a regression problem.</exception>
    protected override void Run(CancellationToken cancellationToken) {
      if (!(Problem is IRegressionProblem problem))
        throw new InvalidOperationException("ConstrainedNLS requires a regression problem.");

      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      var curError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Current error", curError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      // One row per finite bound of each enabled constraint. The list keeps the rows addressable
      // by the index that ConstraintEvaluated reports.
      var constraintRows = new List<IndexedDataRow<int>>();
      var constraintsTable = new IndexedDataTable<int>("Constraints");
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      foreach (var constraint in problem.ProblemData.IntervalConstraints.Constraints.Where(c => c.Enabled)) {
        if (constraint.Interval.LowerBound > double.NegativeInfinity) {
          var constraintRow = new IndexedDataRow<int>("-" + constraint.Expression + " < " + (-constraint.Interval.LowerBound));
          constraintRows.Add(constraintRow);
          constraintsTable.Rows.Add(constraintRow);
        }
        if (constraint.Interval.UpperBound < double.PositiveInfinity) {
          var constraintRow = new IndexedDataRow<int>(constraint.Expression + " < " + (constraint.Interval.UpperBound));
          constraintRows.Add(constraintRow);
          constraintsTable.Rows.Add(constraintRow);
        }
      }

      // Trajectory of the best parameter vector. The rows are created on the first evaluation
      // (sized from state.BestSolution) and kept in a list to avoid name-based lookups.
      var parameterRows = new List<IndexedDataRow<int>>();
      var parametersTable = new IndexedDataTable<int>("Parameters");
      Results.AddOrUpdateResult("Parameters", parametersTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);
      if (CheckGradient) state.CheckGradient = true;

      var formatter = new InfixExpressionFormatter();
      var constraintDescriptions = state.ConstraintDescriptions.ToArray();
      int idx = 0;
      foreach (var constraintTree in state.constraintTrees) {
        // HACK: format and re-parse to remove parameter nodes which occur multiple times
        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
        Results.AddOrUpdateResult($"{constraintDescriptions[idx++]}", reparsedTree);
      }

      // we use a listener model here to get state from the solver
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;

      state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParametersAndKeepLinearScaling);

      // Record the final state. Keep the "Current error" value in sync with the last row entry.
      bestError.Value = state.BestError;
      curError.Value = state.CurError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      Results.AddOrUpdateResult("Best solution", CreateSolution((ISymbolicExpressionTree)state.BestTree.Clone(), problem.ProblemData));
      var bestConstraintValues = new DoubleArray(state.BestConstraintValues);
      bestConstraintValues.ElementNames = constraintDescriptions;
      Results.AddOrUpdateResult("Best solution constraint values", bestConstraintValues);

      // local function: invoked by the solver state after each objective function evaluation
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curError.Value = state.CurError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);

        var bestSolution = state.BestSolution;
        // on the first call create one data row per parameter
        if (parameterRows.Count == 0) {
          for (int i = 0; i < bestSolution.Length; i++) {
            var parameterRow = new IndexedDataRow<int>("p" + i);
            parameterRows.Add(parameterRow);
            parametersTable.Rows.Add(parameterRow);
          }
        }
        for (int i = 0; i < bestSolution.Length; i++) {
          parameterRows[i].Values.Add(Tuple.Create(functionEvaluations.Value, bestSolution[i]));
        }
      }

      // local function: invoked by the solver state with the index and value of an evaluated constraint
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        constraintRows[constraintIdx].Values.Add(Tuple.Create(functionEvaluations.Value, value));
      }
    }

    /// <summary>
    /// Wraps <paramref name="tree"/> into a symbolic regression solution on a clone of
    /// <paramref name="problemData"/>.
    /// </summary>
    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals ==> use SymbolicRegressionSolution.ctor instead
      // NOTE: the solution has slightly different derivative values because simplification of derivatives can be done differently when parameter values are fixed.
      return new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    }
  }
}
Note: See TracBrowser for help on using the repository browser.