
Timestamp: 10/09/19 11:13:11
Author: gkronber
Message: #2994: worked on ConstrainedNLS

File: 1 edited
Legend:

  +  added in r17325
  -  removed (present only in r17311)
     unchanged context (no prefix)
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs

    Diff r17311 → r17325
using HeuristicLab.Common;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
+using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions;

namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     
    private double[] bestConstraintValues;
    public double[] BestConstraintValues => bestConstraintValues;
+
+    private bool disposed = false;


     
      }

+      // all trees are linearly scaled (to improve GP performance)
      #region linear scaling
      var predStDev = pred.StandardDeviationPop();
     

      // convert constants to variables named theta...
-      var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
+      var treeForDerivation = ReplaceAndExtractParameters(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree

      // create trees for relevant derivatives
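The comment added above documents that every tree is wrapped in linear-scaling nodes before the NLS run, so the unscaled predictions only have to match the target up to an affine transformation. For reference, a minimal stand-alone sketch of such a least-squares fit of offset and slope (illustrative helper only; HeuristicLab computes these quantities via its statistics extensions such as StandardDeviationPop):

    // Minimal linear-scaling sketch: fit target ≈ offset + slope * pred by ordinary least squares.
    // Plain double arrays are assumed; requires using System.Linq; all names are illustrative.
    static (double offset, double slope) FitLinearScaling(double[] pred, double[] target) {
      double predMean = pred.Average();
      double targetMean = target.Average();
      double cov = 0.0, varPred = 0.0;
      for (int i = 0; i < pred.Length; i++) {
        cov += (pred[i] - predMean) * (target[i] - targetMean);
        varPred += (pred[i] - predMean) * (pred[i] - predMean);
      }
      double slope = varPred > 0 ? cov / varPred : 1.0;  // guard against constant predictions
      double offset = targetMean - slope * predMean;
      return (offset, slope);
    }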
     

    ~ConstrainedNLSInternal() {
-      Dispose();
-    }
-
-
-    internal void Optimize() {
+      Dispose(false);
+    }
+
+
+    public enum OptimizationMode { ReadOnly, UpdateParameters, UpdateParametersAndKeepLinearScaling };
+
+    internal void Optimize(OptimizationMode mode) {
      if (invalidProblem) return;
      var x = thetaValues.ToArray();  /* initial guess */
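Optimize() now takes an OptimizationMode, so callers state explicitly whether the tuned parameters are written back into the original expression and whether the linear-scaling nodes are kept. A hypothetical call site illustrating the three modes (the construction of the ConstrainedNLSInternal instance is omitted; names outside the changeset are made up):

    // Hypothetical usage sketch; assumes an already constructed ConstrainedNLSInternal.
    void RunConstrainedNLS(ConstrainedNLSInternal nls) {
      // ReadOnly:                             optimize and report, leave the original expression untouched
      // UpdateParameters:                     write tuned parameters back, drop the scaling nodes
      // UpdateParametersAndKeepLinearScaling: write back and keep the scaling nodes in the tree
      nls.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParameters);
      var constraintValues = nls.BestConstraintValues; // property added earlier in this changeset
    }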
     
        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
        return;
-      } else if (minf <= bestError) {
+      } else /*if ( minf <= bestError ) */{
        bestSolution = x;
        bestError = minf;
     

        // update parameters in tree
-        var pIdx = 0;
-        // here we lose the two last parameters (for linear scaling)
-        foreach (var node in scaledTree.IterateNodesPostfix()) {
-          if (node is ConstantTreeNode constTreeNode) {
-            constTreeNode.Value = x[pIdx++];
-          } else if (node is VariableTreeNode varTreeNode) {
-            varTreeNode.Weight = x[pIdx++];
-          }
-        }
-        if (pIdx != x.Length) throw new InvalidProgramException();
-      }
-      bestTree = scaledTree;
-    }
+        UpdateParametersInTree(scaledTree, x);
+
+        if (mode == OptimizationMode.UpdateParameters) {
+          // update original expression (when called from evaluator we want to write back optimized parameters)
+          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
+          expr.Root.GetSubtree(0).InsertSubtree(0,
+            scaledTree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).GetSubtree(0) // insert the optimized sub-tree (without scaling nodes)
+            );
+        } else if (mode == OptimizationMode.UpdateParametersAndKeepLinearScaling) {
+          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
+          expr.Root.GetSubtree(0).InsertSubtree(0, scaledTree.Root.GetSubtree(0).GetSubtree(0)); // insert the optimized sub-tree (including scaling nodes)
+        }
+      }
+      bestTree = expr;
+    }
+

    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
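The write-back branches above rely on the usual HeuristicLab layout Root → Start → expression, with the linear-scaling nodes sitting directly below the start symbol; the chains of GetSubtree(0) calls descend through that structure. A hypothetical helper that spells out the two variants (the exact scaling-node layout is an assumption inferred from the comments in the hunk):

    // Hypothetical helper mirroring the two branches above. Assumed layout of the scaled tree:
    //   Root -> Start -> (+ offset (* slope originalExpression))
    // so the unscaled expression sits three GetSubtree(0) steps below the start node,
    // beneath the addition and multiplication nodes used for scaling.
    static ISymbolicExpressionTreeNode ExtractOptimizedSubtree(ISymbolicExpressionTree scaledTree, bool keepScaling) {
      var start = scaledTree.Root.GetSubtree(0);               // start symbol node
      return keepScaling
        ? start.GetSubtree(0)                                  // including scaling nodes
        : start.GetSubtree(0).GetSubtree(0).GetSubtree(0);     // without scaling nodes
    }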
     
      }

-      // UpdateBestSolution(sse / target.Length, curX);
+      UpdateBestSolution(sse / target.Length, curX);
      RaiseFunctionEvaluated();

     
      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
      if (double.IsNaN(interval.UpperBound)) {
-        if(grad!=null)Array.Clear(grad, 0, grad.Length);
+        if (grad != null) Array.Clear(grad, 0, grad.Length);
        return double.MaxValue;
      } else return interval.UpperBound;
     
    }

-    private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
-      if (nodes.Length != constants.Length) throw new InvalidOperationException();
-      for (int i = 0; i < nodes.Length; i++) {
-        if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
-        else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
-      }
-    }

    private NLOpt.nlopt_algorithm GetSolver(string solver) {
     
    }

+    // determines the nodes over which we can calculate the partial derivative
+    // this is different from the vector of all parameters because not every tree contains all parameters
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
      // TODO better solution necessary
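The new comment explains why parameter nodes are collected per tree: a derivative or constraint tree usually references only a subset of the theta parameters, so its sparse gradient must be scattered back into the full parameter vector. A hypothetical sketch of such a mapping (signature and names are invented for illustration; the actual GetParameterNodes works on the pre-collected allNodes lists):

    // Hypothetical illustration: for one tree, map each theta parameter to the node that
    // represents it, leaving null entries for parameters the tree does not contain.
    static ISymbolicExpressionTreeNode[] MapThetaToNodes(ISymbolicExpressionTree tree, IList<string> thetaNames) {
      var nodes = new ISymbolicExpressionTreeNode[thetaNames.Count];
      foreach (var node in tree.IterateNodesPrefix()) {
        if (node is VariableTreeNode varNode) {
          var idx = thetaNames.IndexOf(varNode.VariableName); // -1 for ordinary data variables
          if (idx >= 0) nodes[idx] = node;
        }
      }
      return nodes;
    }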
     
    }

-    private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
+
+
+
+    private void UpdateParametersInTree(ISymbolicExpressionTree scaledTree, double[] x) {
+      var pIdx = 0;
+      // here we lose the two last parameters (for linear scaling)
+      foreach (var node in scaledTree.IterateNodesPostfix()) {
+        if (node is ConstantTreeNode constTreeNode) {
+          constTreeNode.Value = x[pIdx++];
+        } else if (node is VariableTreeNode varTreeNode) {
+          if (varTreeNode.Weight != 1.0) // see ReplaceAndExtractParameters
+            varTreeNode.Weight = x[pIdx++];
+        }
+      }
+      if (pIdx != x.Length) throw new InvalidProgramException();
+    }
+
+    private static ISymbolicExpressionTree ReplaceAndExtractParameters(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
      thetaNames = new List<string>();
      thetaValues = new List<double>();
     
        }
        if (node is VariableTreeNode varTreeNode) {
+          if (varTreeNode.Weight == 1) continue; // NOTE: here we assume that we do not tune variable weights when they are originally exactly 1 because we assume that the tree has been parsed and the tree explicitly has the structure w * var
+
          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
          thetaVar.Weight = 1;
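ReplaceAndExtractParameters and the new UpdateParametersInTree have to agree on what counts as a parameter: both walk the tree in postfix order and both skip variable nodes whose weight is exactly 1, otherwise the optimized vector x would no longer line up with the tree nodes. A small illustrative consistency check for that convention (not part of the changeset; assumes the usings already present in this file):

    // Illustrative round-trip check of the convention above: extraction and write-back must
    // visit exactly the same nodes (variables with weight == 1 are treated as fixed in both).
    static void CheckParameterRoundTrip(ISymbolicExpressionTree tree) {
      var extracted = new List<double>();
      foreach (var node in tree.IterateNodesPostfix()) {
        if (node is ConstantTreeNode constNode) extracted.Add(constNode.Value);
        else if (node is VariableTreeNode varNode && varNode.Weight != 1.0) extracted.Add(varNode.Weight);
      }
      var pIdx = 0;  // mirrors the pIdx != x.Length check in UpdateParametersInTree
      foreach (var node in tree.IterateNodesPostfix()) {
        if (node is ConstantTreeNode constNode) constNode.Value = extracted[pIdx++];
        else if (node is VariableTreeNode varNode && varNode.Weight != 1.0) varNode.Weight = extracted[pIdx++];
      }
      if (pIdx != extracted.Count) throw new InvalidOperationException("parameter count mismatch");
    }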
     

    public void Dispose() {
+      Dispose(true);
+      GC.SuppressFinalize(this);
+    }
+
+    protected virtual void Dispose(bool disposing) {
+      if (disposed)
+        return;
+
+      if (disposing) {
+        // Free any other managed objects here.
+      }
+
+      // Free any unmanaged objects here.
      if (nlopt != IntPtr.Zero) {
        NLOpt.nlopt_destroy(nlopt);
     
          }
      }
+
+      disposed = true;
    }
    #endregion
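Together with the disposed flag added at the top of the class and the finalizer now calling Dispose(false), this is the standard IDisposable pattern: Dispose() releases managed and unmanaged state and suppresses finalization, while the finalizer only releases the unmanaged NLOpt handle. A condensed stand-alone version of the pattern for reference (placeholder handle instead of the NLOpt pointer):

    using System;

    // Condensed dispose pattern for reference; 'handle' stands in for the nlopt pointer.
    sealed class NativeHandleOwner : IDisposable {
      private IntPtr handle;          // unmanaged resource
      private bool disposed = false;

      public void Dispose() {
        Dispose(true);
        GC.SuppressFinalize(this);    // finalizer no longer needed after explicit disposal
      }

      ~NativeHandleOwner() {
        Dispose(false);               // only unmanaged cleanup is safe in the finalizer
      }

      private void Dispose(bool disposing) {
        if (disposed) return;
        if (disposing) {
          // free managed objects here
        }
        if (handle != IntPtr.Zero) {
          // free the unmanaged resource (ConstrainedNLSInternal calls NLOpt.nlopt_destroy here)
          handle = IntPtr.Zero;
        }
        disposed = true;
      }
    }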