
source: branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLS.cs @ 17204

Last change on this file: r17204, checked in by gkronber, 5 years ago

#2994: added parameter for gradient checks and experimented with preconditioning

File size: 8.6 KB
using System;
using HeuristicLab.Optimization;
using HEAL.Attic;
using HeuristicLab.Common;
using System.Threading;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using System.Linq;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Analysis;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions {
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";

    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["FuncToleranceRel"]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["FuncToleranceAbs"]; }
    }
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters["MaxTime"]; }
    }
    public IFixedValueParameter<BoolValue> CheckGradientParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters["CheckGradient"]; }
    }

    public int Iterations {
      get { return IterationsParameter.Value.Value; }
      set { IterationsParameter.Value.Value = value; }
    }
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set { throw new NotImplementedException(); }
    }
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }
    public bool CheckGradient {
      get { return CheckGradientParameter.Value.Value; }
      set { CheckGradientParameter.Value.Value = value; }
    }
    public double FuncToleranceRel {
      get { return FuncToleranceRelParameter.Value.Value; }
      set { FuncToleranceRelParameter.Value.Value = value; }
    }
    public double FuncToleranceAbs {
      get { return FuncToleranceAbsParameter.Value.Value; }
      set { FuncToleranceAbsParameter.Value.Value = value; }
    }
    public double MaxTime {
      get { return MaxTimeParameter.Value.Value; }
      set { MaxTimeParameter.Value.Value = value; }
    }

    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      var validSolvers = new ItemSet<StringValue>(new[] { "MMA", "COBYLA", "CCSAQ", "ISRES" }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>("FuncToleranceRel", new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>("FuncToleranceAbs", new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>("MaxTime", new DoubleValue(10)));
      Parameters.Add(new FixedValueParameter<BoolValue>("CheckGradient", "Flag to indicate whether the gradient should be checked using numeric approximation", new BoolValue(false)));

      CheckGradientParameter.Hidden = true;
    }

    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    [StorableHook(HookType.AfterDeserialization)]
    public void AfterDeserializationHook() {
      if (!Parameters.ContainsKey("CheckGradient")) {
        Parameters.Add(new FixedValueParameter<BoolValue>("CheckGradient", "Flag to indicate whether the gradient should be checked using numeric approximation", new BoolValue(false)));

        CheckGradientParameter.Hidden = true;
      }
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    protected override void Run(CancellationToken cancellationToken) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);
      var problem = (IRegressionProblem)Problem;

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      var curError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Current error", curError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      var curConstraintValue = new DoubleValue(0);
      Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
      var curConstraintIdx = new IntValue(0);
      Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);

      var curConstraintRow = new DataRow("Constraint Value");
      var constraintsTable = new DataTable("Constraints");

      constraintsTable.Rows.Add(curConstraintRow);
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);
      if (CheckGradient) state.CheckGradient = true;
      int idx = 0;
      foreach (var constraintTree in state.constraintTrees) {
        Results.AddOrUpdateResult($"Constraint {idx++}", constraintTree);
      }

      // we use a listener model here to get state from the solver
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;

      state.Optimize();
      bestError.Value = state.BestError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      Results.AddOrUpdateResult("Best solution", CreateSolution((ISymbolicExpressionTree)state.BestTree.Clone(), problem.ProblemData));
      Results.AddOrUpdateResult("Best solution constraint values", new DoubleArray(state.BestConstraintValues));

      // local function
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curError.Value = state.CurError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);
      }

      // local function
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        curConstraintIdx.Value = constraintIdx;
        curConstraintValue.Value = value;
        curConstraintRow.Values.Add(value);
      }
    }

    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // model.Scale(problemData);
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }
  }
}
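For reference, the sketch below shows one way this algorithm could be driven from code. It is an illustrative, untested example and not part of the file above: the toy data set, the ConstrainedNLSDemo class, and the use of the standard HeuristicLab plumbing (Dataset, RegressionProblemData, a settable RegressionProblem.ProblemData, and the IExecutable Start()/Stopped API) are assumptions; only the ConstrainedNLS members themselves (ModelStructure, Iterations, CheckGradient, MaxTime, SolverParameter, Results) come directly from the listing.

// Hypothetical usage sketch -- not part of the repository file above.
// Assumes the standard HeuristicLab API: Dataset, RegressionProblemData,
// a settable RegressionProblem.ProblemData, and IExecutable.Start()/Stopped.
using System;
using System.Linq;
using System.Threading;
using HeuristicLab.Problems.DataAnalysis;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions;

public static class ConstrainedNLSDemo {
  public static void Main() {
    // Toy data: y = 2 * x^2 sampled on [0, 10).
    var values = new double[100, 2];
    for (int i = 0; i < 100; i++) {
      double x = i / 10.0;
      values[i, 0] = x;
      values[i, 1] = 2.0 * x * x;
    }
    var dataset = new Dataset(new[] { "x", "y" }, values);
    var problemData = new RegressionProblemData(dataset, new[] { "x" }, "y");

    var alg = new ConstrainedNLS {
      ModelStructure = "1.0 * x*x + 0.0", // only the numeric constants are tuned
      Iterations = 100,
      CheckGradient = true,               // numeric check of the gradients (parameter added in r17204)
      MaxTime = 10
    };
    ((RegressionProblem)alg.Problem).ProblemData = problemData;

    // Select one of the registered solvers: "MMA", "COBYLA", "CCSAQ", "ISRES".
    alg.SolverParameter.Value = alg.SolverParameter.ValidValues
      .Single(v => v.Value == "MMA");

    // Start() returns immediately; wait for the Stopped event before reading results.
    using (var done = new ManualResetEventSlim(false)) {
      alg.Stopped += (sender, e) => done.Set();
      alg.Start();
      done.Wait();
    }
    foreach (var result in alg.Results)
      Console.WriteLine(result); // e.g. "Best error", "Best solution", ...
  }
}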