[17197] | 1 | using System;
|
---|
| 2 | using HeuristicLab.Optimization;
|
---|
| 3 | using HEAL.Attic;
|
---|
| 4 | using HeuristicLab.Common;
|
---|
| 5 | using System.Threading;
|
---|
| 6 | using HeuristicLab.Core;
|
---|
| 7 | using HeuristicLab.Data;
|
---|
| 8 | using HeuristicLab.Parameters;
|
---|
| 9 | using System.Linq;
|
---|
| 10 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 11 | using HeuristicLab.Analysis;
|
---|
| 12 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions {
  /// <summary>
  /// Basic algorithm that parses a fixed model structure given as an infix expression,
  /// tunes its numeric constants on the regression problem's data with a selectable solver,
  /// and publishes progress (evaluations, errors, constraint values) and the resulting
  /// model as results.
  /// </summary>
  [StorableType("676B237C-DD9C-4F24-B64F-D44B0FA1F6A6")]
  [Creatable(CreatableAttribute.Categories.DataAnalysisRegression, Priority = 120)]
  [Item(Name = "ConstrainedNLS", Description = "Non-linear Regression with non-linear constraints")]
  public class ConstrainedNLS : BasicAlgorithm {
    public static readonly string IterationsParameterName = "Iterations";
    public static readonly string SolverParameterName = "Solver";
    public static readonly string ModelStructureParameterName = "Model structure";
    // Names of the tolerance / time-limit parameters, defined once like the names above
    // so the constructor and the accessors cannot drift apart.
    public static readonly string FuncToleranceRelParameterName = "FuncToleranceRel";
    public static readonly string FuncToleranceAbsParameterName = "FuncToleranceAbs";
    public static readonly string MaxTimeParameterName = "MaxTime";

    #region parameter properties
    public IFixedValueParameter<IntValue> IterationsParameter {
      get { return (IFixedValueParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IFixedValueParameter<StringValue> ModelStructureParameter {
      get { return (IFixedValueParameter<StringValue>)Parameters[ModelStructureParameterName]; }
    }
    public IConstrainedValueParameter<StringValue> SolverParameter {
      get { return (IConstrainedValueParameter<StringValue>)Parameters[SolverParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceRelParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceRelParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> FuncToleranceAbsParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[FuncToleranceAbsParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> MaxTimeParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[MaxTimeParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>Iteration limit handed to the solver; 0 indicates another or the default stopping criterion.</summary>
    public int Iterations { get { return IterationsParameter.Value.Value; } set { IterationsParameter.Value.Value = value; } }

    /// <summary>The selected solver name (one of "MMA", "COBYLA", "CCSAQ", "ISRES").</summary>
    /// <remarks>Assignment is not supported; select a value through <see cref="SolverParameter"/> instead.</remarks>
    public StringValue Solver {
      get { return SolverParameter.Value; }
      set { throw new NotImplementedException(); }
    }

    /// <summary>Infix expression of the model; only its numeric constants are tuned.</summary>
    public string ModelStructure {
      get { return ModelStructureParameter.Value.Value; }
      set { ModelStructureParameter.Value.Value = value; }
    }

    /// <summary>Relative function-value tolerance passed to the solver.</summary>
    public double FuncToleranceRel { get { return FuncToleranceRelParameter.Value.Value; } set { FuncToleranceRelParameter.Value.Value = value; } }
    /// <summary>Absolute function-value tolerance passed to the solver.</summary>
    public double FuncToleranceAbs { get { return FuncToleranceAbsParameter.Value.Value; } set { FuncToleranceAbsParameter.Value.Value = value; } }
    /// <summary>Time limit passed to the solver (NOTE(review): unit defined by ConstrainedNLSInternal, presumably seconds — confirm).</summary>
    public double MaxTime { get { return MaxTimeParameter.Value.Value; } set { MaxTimeParameter.Value.Value = value; } }
    #endregion

    /// <summary>
    /// Run casts <c>Problem</c> to <see cref="IRegressionProblem"/>. Declaring the required
    /// problem type lets the framework reject incompatible problems when they are assigned,
    /// instead of failing with an InvalidCastException when the algorithm is started.
    /// </summary>
    public override Type ProblemType => typeof(IRegressionProblem);

    public ConstrainedNLS() : base() {
      Problem = new RegressionProblem();

      Parameters.Add(new FixedValueParameter<StringValue>(ModelStructureParameterName, "The function for which the parameters must be fit (only numeric constants are tuned).", new StringValue("1.0 * x*x + 0.0")));
      Parameters.Add(new FixedValueParameter<IntValue>(IterationsParameterName, "Determines how many iterations should be calculated while optimizing the constant of a symbolic expression tree (0 indicates other or default stopping criterion).", new IntValue(10)));
      // The solver choices are read-only so they cannot be edited through the UI.
      var validSolvers = new ItemSet<StringValue>(new[] { "MMA", "COBYLA", "CCSAQ", "ISRES" }.Select(s => new StringValue(s).AsReadOnly()));
      Parameters.Add(new ConstrainedValueParameter<StringValue>(SolverParameterName, "The solver algorithm", validSolvers, validSolvers.First()));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceRelParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(FuncToleranceAbsParameterName, new DoubleValue(0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(MaxTimeParameterName, new DoubleValue(10)));
    }

    public ConstrainedNLS(ConstrainedNLS original, Cloner cloner) : base(original, cloner) {
    }

    [StorableConstructor]
    protected ConstrainedNLS(StorableConstructorFlag _) : base(_) {
    }

    public override bool SupportsPause => false;

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ConstrainedNLS(this, cloner);
    }

    /// <summary>
    /// Parses <see cref="ModelStructure"/>, runs the constant optimization, and records
    /// progress and the best solution in <c>Results</c>.
    /// </summary>
    /// <param name="cancellationToken">Checked on every function evaluation; when cancellation
    /// is requested the solver is asked to stop via <c>RequestStop</c>.</param>
    protected override void Run(CancellationToken cancellationToken) {
      var parser = new InfixExpressionParser();
      var tree = parser.Parse(ModelStructure);
      var problem = (IRegressionProblem)Problem;

      #region prepare results
      var functionEvaluations = new IntValue(0);
      Results.AddOrUpdateResult("Evaluations", functionEvaluations);
      var bestError = new DoubleValue(double.MaxValue);
      Results.AddOrUpdateResult("Best error", bestError);
      Results.AddOrUpdateResult("Tree", tree);
      var qualitiesTable = new DataTable("Qualities");
      var curQualityRow = new DataRow("Current Quality");
      var bestQualityRow = new DataRow("Best Quality");
      qualitiesTable.Rows.Add(bestQualityRow);
      qualitiesTable.Rows.Add(curQualityRow);
      Results.AddOrUpdateResult("Qualities", qualitiesTable);

      var curConstraintValue = new DoubleValue(0);
      Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
      var curConstraintIdx = new IntValue(0);
      Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);

      var curConstraintRow = new DataRow("Constraint Value");
      var constraintsTable = new DataTable("Constraints");
      constraintsTable.Rows.Add(curConstraintRow);
      Results.AddOrUpdateResult("Constraints", constraintsTable);
      #endregion

      var state = new ConstrainedNLSInternal(Solver.Value, tree, Iterations, problem.ProblemData, FuncToleranceRel, FuncToleranceAbs, MaxTime);

      // The solver reports its progress through these callbacks; they update the result
      // objects created above in place.
      state.FunctionEvaluated += State_FunctionEvaluated;
      state.ConstraintEvaluated += State_ConstraintEvaluated;

      state.Optimize();

      // Record the final state once more after the solver returns
      // (NOTE(review): presumably so the tables reflect the final errors even if no callback fired).
      bestError.Value = state.BestError;
      curQualityRow.Values.Add(state.CurError);
      bestQualityRow.Values.Add(bestError.Value);

      Results.AddOrUpdateResult("Best solution", CreateSolution(state.BestTree, problem.ProblemData));

      // local function: invoked by the solver after each objective evaluation
      void State_FunctionEvaluated() {
        if (cancellationToken.IsCancellationRequested) state.RequestStop();
        functionEvaluations.Value++;
        bestError.Value = state.BestError;
        curQualityRow.Values.Add(state.CurError);
        bestQualityRow.Values.Add(bestError.Value);
      }

      // local function: invoked by the solver after each constraint evaluation
      void State_ConstraintEvaluated(int constraintIdx, double value) {
        curConstraintIdx.Value = constraintIdx;
        curConstraintValue.Value = value;
        curConstraintRow.Values.Add(value);
      }
    }

    /// <summary>
    /// Wraps the tree in a symbolic regression model (linear interpreter) and creates a
    /// solution on a clone of the problem data, so the solution does not share that object
    /// with the problem.
    /// </summary>
    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
      // NOTE(review): linear scaling of the model is disabled here; confirm this is intentional.
      // model.Scale(problemData);
      return model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    }
  }
}
|
---|