Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLS.cs @ 17285

Last change on this file since 17285 was 17214, checked in by gkronber, 5 years ago

#2994: support for other NLOpt analysers, display trees for constraints for debugging

File size: 9.8 KB
Line 
1using System;
2using HeuristicLab.Optimization;
3using HEAL.Attic;
4using HeuristicLab.Common;
5using System.Threading;
6using HeuristicLab.Core;
7using HeuristicLab.Data;
8using HeuristicLab.Parameters;
9using System.Linq;
10using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
11using HeuristicLab.Analysis;
12
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
  /// <summary>
  /// Non-linear regression with non-linear constraints.
  /// </summary>
  /// <remarks>
  /// The model structure is given as an infix expression; only its numeric constants are tuned.
  /// The optimization is delegated to <c>ConstrainedNLSInternal</c> using the selected solver
  /// (NLopt algorithm names). The best tree found is published as a symbolic regression solution.
  /// </remarks>
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";
    public static readonly string FuncToleranceRelParameterName = "FuncToleranceRel";
    public static readonly string FuncToleranceAbsParameterName = "FuncToleranceAbs";
    public static readonly string MaxTimeParameterName = "MaxTime";
    public static readonly string CheckGradientParameterName = "CheckGradient";

    // Shared by the constructor and the deserialization hook so both create an identical parameter.
    private const string CheckGradientDescription = "Flag to indicate whether the gradient should be checked using numeric approximation";

    #region parameter properties
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceRelParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceAbsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MaxTimeParameterName]; }
    }
    public IFixedValueParameter<BoolValue> CheckGradientParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[CheckGradientParameterName]; }
    }
    #endregion

    #region convenience accessors
    /// <summary>Number of optimizer iterations (0 indicates another or the default stopping criterion).</summary>
    public int Iterations { get { return IterationsParameter.Value.Value; } set { IterationsParameter.Value.Value = value; } }

    /// <summary>
    /// The selected solver.
    /// </summary>
    /// <remarks>
    /// Assigning a value selects the entry of <see cref="SolverParameter"/>'s valid values whose name
    /// equals the assigned name (ordinal comparison).
    /// </remarks>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    /// <exception cref="ArgumentException">The assigned name is not one of the valid solvers.</exception>
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set {
        if (value == null) throw new ArgumentNullException(nameof(value));
        // The parameter only accepts instances contained in ValidValues, so look up the matching one by name.
        var match = SolverParameter.ValidValues.FirstOrDefault(v => string.Equals(v.Value, value.Value, StringComparison.Ordinal));
        if (match == null) throw new ArgumentException($"'{value.Value}' is not a valid solver.", nameof(value));
        SolverParameter.Value = match;
      }
    }

    /// <summary>Infix expression of the model; only its numeric constants are tuned.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>Whether the solver should check the gradient against a numeric approximation.</summary>
    public bool CheckGradient {
      get { return CheckGradientParameter.Value.Value; }
      set { CheckGradientParameter.Value.Value = value; }
    }

    /// <summary>Relative function tolerance passed to the solver.</summary>
    public double FuncToleranceRel { get { return FuncToleranceRelParameter.Value.Value; } set { FuncToleranceRelParameter.Value.Value = value; } }

    /// <summary>Absolute function tolerance passed to the solver.</summary>
    public double FuncToleranceAbs { get { return FuncToleranceAbsParameter.Value.Value; } set { FuncToleranceAbsParameter.Value.Value = value; } }

    /// <summary>Maximum run time passed to the solver (presumably seconds, as for NLopt's maxtime — confirm).</summary>
    public double MaxTime { get { return MaxTimeParameter.Value.Value; } set { MaxTimeParameter.Value.Value = value; } }
    #endregion

    /// <summary>
    /// Creates the algorithm with a default regression problem and the default parameter set.
    /// </summary>
    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      var validSolvers = new ItemSet<StringValue>(new[] {
        "MMA",
        "COBYLA",
        "CCSAQ",
        "ISRES",
        "DIRECT_G",
        "NLOPT_GN_DIRECT_L",
        "NLOPT_GN_DIRECT_L_RAND",
        "NLOPT_GN_ORIG_DIRECT",
        "NLOPT_GN_ORIG_DIRECT_L",
        "NLOPT_GD_STOGO",
        "NLOPT_GD_STOGO_RAND",
        "NLOPT_LD_LBFGS_NOCEDAL",
        "NLOPT_LD_LBFGS",
        "NLOPT_LN_PRAXIS",
        "NLOPT_LD_VAR1",
        "NLOPT_LD_VAR2",
        "NLOPT_LD_TNEWTON",
        "NLOPT_LD_TNEWTON_RESTART",
        "NLOPT_LD_TNEWTON_PRECOND",
        "NLOPT_LD_TNEWTON_PRECOND_RESTART",
        "NLOPT_GN_CRS2_LM",
        "NLOPT_GN_MLSL",
        "NLOPT_GD_MLSL",
        "NLOPT_GN_MLSL_LDS",
        "NLOPT_GD_MLSL_LDS",
        "NLOPT_LN_NEWUOA",
        "NLOPT_LN_NEWUOA_BOUND",
        "NLOPT_LN_NELDERMEAD",
        "NLOPT_LN_SBPLX",
        "NLOPT_LN_AUGLAG",
        "NLOPT_LD_AUGLAG",
        "NLOPT_LN_BOBYQA",
        "NLOPT_AUGLAG",
        "NLOPT_LD_SLSQP",
        "NLOPT_LD_CCSAQ",
        "NLOPT_GN_ESCH",
        "NLOPT_GN_AGS",
      }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceRelParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceAbsParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MaxTimeParameterName, new DoubleValue(10)));
      AddCheckGradientParameter();
    }

    /// <summary>Cloning constructor.</summary>
    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    /// <summary>
    /// Adds the CheckGradient parameter if a deserialized instance does not have it yet.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    public void AfterDeserializationHook() {
      if (!Parameters.ContainsKey(CheckGradientParameterName)) {
        AddCheckGradientParameter();
      }
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    // Adds the hidden CheckGradient flag, initialized to false.
    private void AddCheckGradientParameter() {
      Parameters.Add(new FixedValueParameter<BoolValue>(CheckGradientParameterName, CheckGradientDescription, new BoolValue(false)));
      CheckGradientParameter.Hidden = true;
    }

    /// <summary>
    /// Parses <see cref="ModelStructure"/> and optimizes its numeric constants on the problem data
    /// with the selected solver.
    /// </summary>
    /// <remarks>
    /// Progress and the best solution are written to <c>Results</c>.
    /// </remarks>
    /// <param name="cancellationToken">
    /// Checked on every function evaluation; when cancellation is requested the solver is asked to stop.
    /// </param>
    protected override void Run(CancellationToken cancellationToken) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);
      var problem = (IRegressionProblem)Problem;

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      var curError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Current error", curError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      var curConstraintValue = new DoubleValue(0);
      Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
      var curConstraintIdx = new IntValue(0);
      Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);

      var curConstraintRow = new DataRow("Constraint Value");
      var constraintsTable = new DataTable("Constraints");
      constraintsTable.Rows.Add(curConstraintRow);
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);
      if (CheckGradient) state.CheckGradient = true;

      // Publish the constraint trees (debugging aid). HACK: round-trip each tree through the
      // formatter and parser to remove parameter nodes that occur multiple times.
      var formatter = new InfixExpressionFormatter();
      int idx = 0;
      foreach (var constraintTree in state.constraintTrees) {
        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
        Results.AddOrUpdateResult($"Constraint {idx++}", reparsedTree);
      }

      // The solver reports its progress through events; the local handlers below update the results.
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;

      state.Optimize();

      bestError.Value = state.BestError;
      curError.Value = state.CurError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      // NOTE(review): BestTree / BestConstraintValues are assumed to possibly be null when no evaluation
      // completed (e.g. immediate cancellation); skip these results instead of throwing — confirm against ConstrainedNLSInternal.
      if (state.BestTree != null) {
        Results.AddOrUpdateResult("Best solution", CreateSolution((ISymbolicExpressionTree)state.BestTree.Clone(), problem.ProblemData));
      }
      if (state.BestConstraintValues != null) {
        Results.AddOrUpdateResult("Best solution constraint values", new DoubleArray(state.BestConstraintValues));
      }

      // Called by the solver after each objective evaluation: honors cancellation and records the error trace.
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curError.Value = state.CurError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);
      }

      // Called by the solver after each constraint evaluation: records which constraint was evaluated and its value.
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        curConstraintIdx.Value = constraintIdx;
        curConstraintValue.Value = value;
        curConstraintRow.Values.Add(value);
      }
    }

    /// <summary>
    /// Wraps <paramref name="tree"/> in a symbolic regression solution evaluated with the linear interpreter.
    /// </summary>
    /// <param name="tree">The expression tree to wrap.</param>
    /// <param name="problemData">The problem data; a clone of it is attached to the solution.</param>
    /// <returns>A regression solution built from the tree and the cloned problem data.</returns>
    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // NOTE(review): linear scaling (model.Scale(problemData)) is intentionally not applied; presumably because
      // rescaling would alter the optimized constants and could violate the constraints — confirm.
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }
  }
}
Note: See TracBrowser for help on using the repository browser.