Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp: 10/09/19 11:13:11 (5 years ago)
Author: gkronber
Message:

#2994: worked on ConstrainedNLS

Location:
branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLS.cs

    r17311 r17325  
    1010using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    1111using HeuristicLab.Analysis;
     12using System.Collections.Generic;
    1213
    1314namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     
    157158      Results.AddOrUpdateResult("Qualities", qualitiesTable);
    158159
    159       var curConstraintValue = new DoubleValue(0);
    160       Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
    161       var curConstraintIdx = new IntValue(0);
    162       Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);
    163 
    164       var curConstraintRow = new DataRow("Constraint Value");
    165       var constraintsTable = new DataTable("Constraints");
    166 
    167       constraintsTable.Rows.Add(curConstraintRow);
     160      var constraintRows = new List<IndexedDataRow<int>>(); // for access via index
     161      var constraintsTable = new IndexedDataTable<int>("Constraints");
    168162      Results.AddOrUpdateResult("Constraints", constraintsTable);
     163      foreach (var constraint in problem.ProblemData.IntervalConstraints.Constraints.Where(c => c.Enabled)) {
     164        if (constraint.Interval.LowerBound > double.NegativeInfinity) {
     165          var constraintRow = new IndexedDataRow<int>("-" + constraint.Expression + " < " + (-constraint.Interval.LowerBound));
     166          constraintRows.Add(constraintRow);
     167          constraintsTable.Rows.Add(constraintRow);
     168        }
     169        if (constraint.Interval.UpperBound < double.PositiveInfinity) {
     170          var constraintRow = new IndexedDataRow<int>(constraint.Expression + " < " + (constraint.Interval.UpperBound));
     171          constraintRows.Add(constraintRow);
     172          constraintsTable.Rows.Add(constraintRow);
     173        }
     174      }
     175
     176      var parametersTable = new IndexedDataTable<int>("Parameters");
    169177
    170178      #endregion
     
    175183      var formatter = new InfixExpressionFormatter();
    176184      var constraintDescriptions = state.ConstraintDescriptions.ToArray();
    177       foreach(var constraintTree in state.constraintTrees) {
     185      foreach (var constraintTree in state.constraintTrees) {
    178186        // HACK to remove parameter nodes which occur multiple times
    179187        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
     
    186194      state.ConstraintEvaluated += State_ConstraintEvaluated;
    187195
    188       state.Optimize();
     196      state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParametersAndKeepLinearScaling);
    189197      bestError.Value = state.BestError;
    190198      curQualityRow.Values.Add(state.CurError);
     
    205213        curQualityRow.Values.Add(state.CurError);
    206214        bestQualityRow.Values.Add(bestError.Value);
     215
     216        // on the first call create the data rows
     217        if(!parametersTable.Rows.Any()) {
     218          for(int i=0;i<state.BestSolution.Length;i++) {
     219            parametersTable.Rows.Add(new IndexedDataRow<int>("p" + i));
     220          }
     221        }
     222        for (int i = 0; i < state.BestSolution.Length; i++) {
     223          parametersTable.Rows["p" + i].Values.Add(Tuple.Create(functionEvaluations.Value, state.BestSolution[i])); // TODO: remove access via string
     224        }
    207225      }
    208226
    209227      // local function
    210228      void State_ConstraintEvaluated(int constraintIdx, double value) {
    211         curConstraintIdx.Value = constraintIdx;
    212         curConstraintValue.Value = value;
    213         curConstraintRow.Values.Add(value);
     229        constraintRows[constraintIdx].Values.Add(Tuple.Create(functionEvaluations.Value, value));
    214230      }
    215231    }
     
    217233    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
    218234      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    219       // model.Scale(problemData);
     235      // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals ==> use SymbolicRegressionSolution.ctor instead
     236      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
     237      // NOTE: the solution has slightly different derivative values because simplification of derivatives can be done differently when parameter values are fixed.
     238
    220239      // var sol = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    221       // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals
    222 
    223       var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    224 
    225       // ==> set variable ranges to same range as in original problemData
    226       // foreach(var interval in problemData.VariableRanges.GetIntervals()) {
    227       //   sol.ProblemData.VariableRanges.SetInterval
    228       // }
     240
    229241      return sol;
    230242    }
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs

    r17311 r17325  
    66using HeuristicLab.Common;
    77using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     8using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions;
    89
    910namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     
    4041    private double[] bestConstraintValues;
    4142    public double[] BestConstraintValues => bestConstraintValues;
     43
     44    private bool disposed = false;
    4245
    4346
     
    120123      }
    121124
     125      // all trees are linearly scaled (to improve GP performance)
    122126      #region linear scaling
    123127      var predStDev = pred.StandardDeviationPop();
     
    134138
    135139      // convert constants to variables named theta...
    136       var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
     140      var treeForDerivation = ReplaceAndExtractParameters(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
    137141
    138142      // create trees for relevant derivatives
     
    220224
    221225    ~ConstrainedNLSInternal() {
    222       Dispose();
    223     }
    224 
    225 
    226     internal void Optimize() {
     226      Dispose(false);
     227    }
     228
     229
     230    public enum OptimizationMode { ReadOnly, UpdateParameters, UpdateParametersAndKeepLinearScaling };
     231
     232    internal void Optimize(OptimizationMode mode) {
    227233      if (invalidProblem) return;
    228234      var x = thetaValues.ToArray();  /* initial guess */
     
    233239        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
    234240        return;
    235       } else if (minf <= bestError) {
     241      } else /*if ( minf <= bestError ) */{
    236242        bestSolution = x;
    237243        bestError = minf;
     
    245251
    246252        // update parameters in tree
    247         var pIdx = 0;
    248         // here we lose the two last parameters (for linear scaling)
    249         foreach (var node in scaledTree.IterateNodesPostfix()) {
    250           if (node is ConstantTreeNode constTreeNode) {
    251             constTreeNode.Value = x[pIdx++];
    252           } else if (node is VariableTreeNode varTreeNode) {
    253             varTreeNode.Weight = x[pIdx++];
    254           }
    255         }
    256         if (pIdx != x.Length) throw new InvalidProgramException();
    257       }
    258       bestTree = scaledTree;
    259     }
     253        UpdateParametersInTree(scaledTree, x);
     254
     255        if (mode == OptimizationMode.UpdateParameters) {
     256          // update original expression (when called from evaluator we want to write back optimized parameters)
     257          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
     258          expr.Root.GetSubtree(0).InsertSubtree(0,
     259            scaledTree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).GetSubtree(0) // insert the optimized sub-tree (without scaling nodes)
     260            );
     261        } else if (mode == OptimizationMode.UpdateParametersAndKeepLinearScaling) {
     262          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
     263          expr.Root.GetSubtree(0).InsertSubtree(0, scaledTree.Root.GetSubtree(0).GetSubtree(0)); // insert the optimized sub-tree (including scaling nodes)
     264        }
     265      }
     266      bestTree = expr;
     267    }
     268
    260269
    261270    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
     
    314323      }
    315324
    316       // UpdateBestSolution(sse / target.Length, curX);
     325      UpdateBestSolution(sse / target.Length, curX);
    317326      RaiseFunctionEvaluated();
    318327
     
    427436      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
    428437      if (double.IsNaN(interval.UpperBound)) {
    429         if(grad!=null)Array.Clear(grad, 0, grad.Length);
     438        if (grad != null) Array.Clear(grad, 0, grad.Length);
    430439        return double.MaxValue;
    431440      } else return interval.UpperBound;
     
    463472    }
    464473
    465     private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
    466       if (nodes.Length != constants.Length) throw new InvalidOperationException();
    467       for (int i = 0; i < nodes.Length; i++) {
    468         if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
    469         else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
    470       }
    471     }
    472474
    473475    private NLOpt.nlopt_algorithm GetSolver(string solver) {
     
    514516    }
    515517
     518    // determines the nodes over which we can calculate the partial derivative
     519    // this is different from the vector of all parameters because not every tree contains all parameters
    516520    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
    517521      // TODO better solution necessary
     
    553557    }
    554558
    555     private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
     559
     560
     561
     562    private void UpdateParametersInTree(ISymbolicExpressionTree scaledTree, double[] x) {
     563      var pIdx = 0;
     564      // here we lose the two last parameters (for linear scaling)
     565      foreach (var node in scaledTree.IterateNodesPostfix()) {
     566        if (node is ConstantTreeNode constTreeNode) {
     567          constTreeNode.Value = x[pIdx++];
     568        } else if (node is VariableTreeNode varTreeNode) {
     569          if (varTreeNode.Weight != 1.0) // see ReplaceAndExtractParameters
     570            varTreeNode.Weight = x[pIdx++];
     571        }
     572      }
     573      if (pIdx != x.Length) throw new InvalidProgramException();
     574    }
     575
     576    private static ISymbolicExpressionTree ReplaceAndExtractParameters(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
    556577      thetaNames = new List<string>();
    557578      thetaValues = new List<double>();
     
    578599        }
    579600        if (node is VariableTreeNode varTreeNode) {
     601          if (varTreeNode.Weight == 1) continue; // NOTE: here we assume that we do not tune variable weights when they are originally exactly 1 because we assume that the tree has been parsed and the tree explicitly has the structure w * var
     602
    580603          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    581604          thetaVar.Weight = 1;
     
    626649
    627650    public void Dispose() {
     651      Dispose(true);
     652      GC.SuppressFinalize(this);
     653    }
     654
     655    protected virtual void Dispose(bool disposing) {
     656      if (disposed)
     657        return;
     658
     659      if (disposing) {
     660        // Free any other managed objects here.
     661      }
     662
     663      // Free any unmanaged objects here.
    628664      if (nlopt != IntPtr.Zero) {
    629665        NLOpt.nlopt_destroy(nlopt);
     
    637673          }
    638674      }
     675
     676      disposed = true;
    639677    }
    640678    #endregion
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/NLOptEvaluator.cs

    r17215 r17325  
    199199    }
    200200
    201     private static void GetParameterNodes(ISymbolicExpressionTree tree, out List<ISymbolicExpressionTreeNode> thetaNodes, out List<double> thetaValues) {
    202       thetaNodes = new List<ISymbolicExpressionTreeNode>();
    203       thetaValues = new List<double>();
    204 
    205       var nodes = tree.IterateNodesPrefix().ToArray();
    206       for (int i = 0; i < nodes.Length; ++i) {
    207         var node = nodes[i];
    208         if (node is VariableTreeNode variableTreeNode) {
    209           thetaValues.Add(variableTreeNode.Weight);
    210           thetaNodes.Add(node);
    211         } else if (node is ConstantTreeNode constantTreeNode) {
    212           thetaNodes.Add(node);
    213           thetaValues.Add(constantTreeNode.Value);
    214         }
    215       }
    216     }
    217 
    218201
    219202    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     
    230213
    231214      using (var state = new ConstrainedNLSInternal(solver, tree, maxIterations, problemData, 0, 0, 0)) {
    232         state.Optimize();
     215        state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParameters);
    233216        return state.BestError;
    234217      }
Note: See TracChangeset for help on using the changeset viewer.