Changeset 17325


Ignore:
Timestamp:
10/09/19 11:13:11 (6 days ago)
Author:
gkronber
Message:

#2994: worked on ConstrainedNLS

Location:
branches/2994-AutoDiffForIntervals
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLS.cs

    r17311 r17325  
    1010using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    1111using HeuristicLab.Analysis;
     12using System.Collections.Generic;
    1213
    1314namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     
    157158      Results.AddOrUpdateResult("Qualities", qualitiesTable);
    158159
    159       var curConstraintValue = new DoubleValue(0);
    160       Results.AddOrUpdateResult("Current Constraint Value", curConstraintValue);
    161       var curConstraintIdx = new IntValue(0);
    162       Results.AddOrUpdateResult("Current Constraint Index", curConstraintIdx);
    163 
    164       var curConstraintRow = new DataRow("Constraint Value");
    165       var constraintsTable = new DataTable("Constraints");
    166 
    167       constraintsTable.Rows.Add(curConstraintRow);
     160      var constraintRows = new List<IndexedDataRow<int>>(); // for access via index
     161      var constraintsTable = new IndexedDataTable<int>("Constraints");
    168162      Results.AddOrUpdateResult("Constraints", constraintsTable);
     163      foreach (var constraint in problem.ProblemData.IntervalConstraints.Constraints.Where(c => c.Enabled)) {
     164        if (constraint.Interval.LowerBound > double.NegativeInfinity) {
     165          var constraintRow = new IndexedDataRow<int>("-" + constraint.Expression + " < " + (-constraint.Interval.LowerBound));
     166          constraintRows.Add(constraintRow);
     167          constraintsTable.Rows.Add(constraintRow);
     168        }
     169        if (constraint.Interval.UpperBound < double.PositiveInfinity) {
     170          var constraintRow = new IndexedDataRow<int>(constraint.Expression + " < " + (constraint.Interval.UpperBound));
     171          constraintRows.Add(constraintRow);
     172          constraintsTable.Rows.Add(constraintRow);
     173        }
     174      }
     175
     176      var parametersTable = new IndexedDataTable<int>("Parameters");
    169177
    170178      #endregion
     
    175183      var formatter = new InfixExpressionFormatter();
    176184      var constraintDescriptions = state.ConstraintDescriptions.ToArray();
    177       foreach(var constraintTree in state.constraintTrees) {
     185      foreach (var constraintTree in state.constraintTrees) {
    178186        // HACK to remove parameter nodes which occurr multiple times
    179187        var reparsedTree = parser.Parse(formatter.Format(constraintTree));
     
    186194      state.ConstraintEvaluated += State_ConstraintEvaluated;
    187195
    188       state.Optimize();
     196      state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParametersAndKeepLinearScaling);
    189197      bestError.Value = state.BestError;
    190198      curQualityRow.Values.Add(state.CurError);
     
    205213        curQualityRow.Values.Add(state.CurError);
    206214        bestQualityRow.Values.Add(bestError.Value);
     215
     216        // on the first call create the data rows
     217        if(!parametersTable.Rows.Any()) {
     218          for(int i=0;i<state.BestSolution.Length;i++) {
     219            parametersTable.Rows.Add(new IndexedDataRow<int>("p" + i));
     220          }
     221        }
     222        for (int i = 0; i < state.BestSolution.Length; i++) {
     223          parametersTable.Rows["p" + i].Values.Add(Tuple.Create(functionEvaluations.Value, state.BestSolution[i])); // TODO: remove access via string
     224        }
    207225      }
    208226
    209227      // local function
    210228      void State_ConstraintEvaluated(int constraintIdx, double value) {
    211         curConstraintIdx.Value = constraintIdx;
    212         curConstraintValue.Value = value;
    213         curConstraintRow.Values.Add(value);
     229        constraintRows[constraintIdx].Values.Add(Tuple.Create(functionEvaluations.Value, value));
    214230      }
    215231    }
     
    217233    private static ISymbolicRegressionSolution CreateSolution(ISymbolicExpressionTree tree, IRegressionProblemData problemData) {
    218234      var model = new SymbolicRegressionModel(problemData.TargetVariable, tree, new SymbolicDataAnalysisExpressionTreeLinearInterpreter());
    219       // model.Scale(problemData);
     235      // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals ==> use SymbolicRegressionSolution.ctor instead
     236      var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
     237      // NOTE: the solution has slightly different derivative values because simplification of derivatives can be done differently when parameter values are fixed.
     238
    220239      // var sol = model.CreateRegressionSolution((IRegressionProblemData)problemData.Clone());
    221       // model.CreateRegressionSolution produces a new ProblemData and recalculates intervals
    222 
    223       var sol = new SymbolicRegressionSolution(model, (IRegressionProblemData)problemData.Clone());
    224 
    225       // ==> set variable ranges to same range as in original problemData
    226       // foreach(var interval in problemData.VariableRanges.GetIntervals()) {
    227       //   sol.ProblemData.VariableRanges.SetInterval
    228       // }
     240
    229241      return sol;
    230242    }
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/ConstrainedNLSInternal.cs

    r17311 r17325  
    66using HeuristicLab.Common;
    77using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     8using HeuristicLab.Problems.DataAnalysis.Symbolic.Regression.Extensions;
    89
    910namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression {
     
    4041    private double[] bestConstraintValues;
    4142    public double[] BestConstraintValues => bestConstraintValues;
     43
     44    private bool disposed = false;
    4245
    4346
     
    120123      }
    121124
     125      // all trees are linearly scaled (to improve GP performance)
    122126      #region linear scaling
    123127      var predStDev = pred.StandardDeviationPop();
     
    134138
    135139      // convert constants to variables named theta...
    136       var treeForDerivation = ReplaceConstWithVar(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
     140      var treeForDerivation = ReplaceAndExtractParameters(scaledTree, out List<string> thetaNames, out thetaValues); // copies the tree
    137141
    138142      // create trees for relevant derivatives
     
    220224
    221225    ~ConstrainedNLSInternal() {
    222       Dispose();
    223     }
    224 
    225 
    226     internal void Optimize() {
     226      Dispose(false);
     227    }
     228
     229
     230    public enum OptimizationMode { ReadOnly, UpdateParameters, UpdateParametersAndKeepLinearScaling };
     231
     232    internal void Optimize(OptimizationMode mode) {
    227233      if (invalidProblem) return;
    228234      var x = thetaValues.ToArray();  /* initial guess */
     
    233239        // throw new InvalidOperationException($"NLOpt failed {res} {NLOpt.nlopt_get_errmsg(nlopt)}");
    234240        return;
    235       } else if (minf <= bestError) {
     241      } else /*if ( minf <= bestError ) */{
    236242        bestSolution = x;
    237243        bestError = minf;
     
    245251
    246252        // update parameters in tree
    247         var pIdx = 0;
    248         // here we lose the two last parameters (for linear scaling)
    249         foreach (var node in scaledTree.IterateNodesPostfix()) {
    250           if (node is ConstantTreeNode constTreeNode) {
    251             constTreeNode.Value = x[pIdx++];
    252           } else if (node is VariableTreeNode varTreeNode) {
    253             varTreeNode.Weight = x[pIdx++];
    254           }
    255         }
    256         if (pIdx != x.Length) throw new InvalidProgramException();
    257       }
    258       bestTree = scaledTree;
    259     }
     253        UpdateParametersInTree(scaledTree, x);
     254
     255        if (mode == OptimizationMode.UpdateParameters) {
     256          // update original expression (when called from evaluator we want to write back optimized parameters)
     257          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
     258          expr.Root.GetSubtree(0).InsertSubtree(0,
     259            scaledTree.Root.GetSubtree(0).GetSubtree(0).GetSubtree(0).GetSubtree(0) // insert the optimized sub-tree (without scaling nodes)
     260            );
     261        } else if (mode == OptimizationMode.UpdateParametersAndKeepLinearScaling) {
     262          expr.Root.GetSubtree(0).RemoveSubtree(0); // delete old tree
     263          expr.Root.GetSubtree(0).InsertSubtree(0, scaledTree.Root.GetSubtree(0).GetSubtree(0)); // insert the optimized sub-tree (including scaling nodes)
     264        }
     265      }
     266      bestTree = expr;
     267    }
     268
    260269
    261270    double CalculateObjective(uint dim, double[] curX, double[] grad, IntPtr data) {
     
    314323      }
    315324
    316       // UpdateBestSolution(sse / target.Length, curX);
     325      UpdateBestSolution(sse / target.Length, curX);
    317326      RaiseFunctionEvaluated();
    318327
     
    427436      UpdateConstraintViolations(constraintData.Idx, interval.UpperBound);
    428437      if (double.IsNaN(interval.UpperBound)) {
    429         if(grad!=null)Array.Clear(grad, 0, grad.Length);
     438        if (grad != null) Array.Clear(grad, 0, grad.Length);
    430439        return double.MaxValue;
    431440      } else return interval.UpperBound;
     
    463472    }
    464473
    465     private static void UpdateConstants(ISymbolicExpressionTreeNode[] nodes, double[] constants) {
    466       if (nodes.Length != constants.Length) throw new InvalidOperationException();
    467       for (int i = 0; i < nodes.Length; i++) {
    468         if (nodes[i] is VariableTreeNode varNode) varNode.Weight = constants[i];
    469         else if (nodes[i] is ConstantTreeNode constNode) constNode.Value = constants[i];
    470       }
    471     }
    472474
    473475    private NLOpt.nlopt_algorithm GetSolver(string solver) {
     
    514516    }
    515517
     518    // determines the nodes over which we can calculate the partial derivative
     519    // this is different from the vector of all parameters because not every tree contains all parameters
    516520    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree, List<ConstantTreeNode>[] allNodes) {
    517521      // TODO better solution necessary
     
    553557    }
    554558
    555     private static ISymbolicExpressionTree ReplaceConstWithVar(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
     559
     560
     561
     562    private void UpdateParametersInTree(ISymbolicExpressionTree scaledTree, double[] x) {
     563      var pIdx = 0;
     564      // here we lose the two last parameters (for linear scaling)
     565      foreach (var node in scaledTree.IterateNodesPostfix()) {
     566        if (node is ConstantTreeNode constTreeNode) {
     567          constTreeNode.Value = x[pIdx++];
     568        } else if (node is VariableTreeNode varTreeNode) {
     569          if (varTreeNode.Weight != 1.0) // see ReplaceAndExtractParameters
     570            varTreeNode.Weight = x[pIdx++];
     571        }
     572      }
     573      if (pIdx != x.Length) throw new InvalidProgramException();
     574    }
     575
     576    private static ISymbolicExpressionTree ReplaceAndExtractParameters(ISymbolicExpressionTree tree, out List<string> thetaNames, out List<double> thetaValues) {
    556577      thetaNames = new List<string>();
    557578      thetaValues = new List<double>();
     
    578599        }
    579600        if (node is VariableTreeNode varTreeNode) {
     601          if (varTreeNode.Weight == 1) continue; // NOTE: here we assume that we do not tune variable weights when they are originally exactly 1 because we assume that the tree has been parsed and the tree explicitly has the structure w * var
     602
    580603          var thetaVar = (VariableTreeNode)new Problems.DataAnalysis.Symbolic.Variable().CreateTreeNode();
    581604          thetaVar.Weight = 1;
     
    626649
    627650    public void Dispose() {
     651      Dispose(true);
     652      GC.SuppressFinalize(this);
     653    }
     654
     655    protected virtual void Dispose(bool disposing) {
     656      if (disposed)
     657        return;
     658
     659      if (disposing) {
     660        // Free any other managed objects here.
     661      }
     662
     663      // Free any unmanaged objects here.
    628664      if (nlopt != IntPtr.Zero) {
    629665        NLOpt.nlopt_destroy(nlopt);
     
    637673          }
    638674      }
     675
     676      disposed = true;
    639677    }
    640678    #endregion
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Extensions/NLOptEvaluator.cs

    r17215 r17325  
    199199    }
    200200
    201     private static void GetParameterNodes(ISymbolicExpressionTree tree, out List<ISymbolicExpressionTreeNode> thetaNodes, out List<double> thetaValues) {
    202       thetaNodes = new List<ISymbolicExpressionTreeNode>();
    203       thetaValues = new List<double>();
    204 
    205       var nodes = tree.IterateNodesPrefix().ToArray();
    206       for (int i = 0; i < nodes.Length; ++i) {
    207         var node = nodes[i];
    208         if (node is VariableTreeNode variableTreeNode) {
    209           thetaValues.Add(variableTreeNode.Weight);
    210           thetaNodes.Add(node);
    211         } else if (node is ConstantTreeNode constantTreeNode) {
    212           thetaNodes.Add(node);
    213           thetaValues.Add(constantTreeNode.Value);
    214         }
    215       }
    216     }
    217 
    218201
    219202    public static double OptimizeConstants(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     
    230213
    231214      using (var state = new ConstrainedNLSInternal(solver, tree, maxIterations, problemData, 0, 0, 0)) {
    232         state.Optimize();
     215        state.Optimize(ConstrainedNLSInternal.OptimizationMode.UpdateParameters);
    233216        return state.BestError;
    234217      }
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/IntervalEvaluatorAutoDiffTest.cs

    r17319 r17325  
    6161      var eval = new IntervalEvaluator();
    6262      var parser = new InfixExpressionParser();
     63      var intervals = new Dictionary<string, Interval>() {
     64        { "x", new Interval(1, 2) },
     65        { "unit", new Interval(0, 1) },
     66        { "neg", new Interval(-1, 0) },
     67      };
    6368      var t = parser.Parse("sqr(x)");
    6469      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
    65       var intervals = new Dictionary<string, Interval>() {
    66         { "x", new Interval(1, 2) },
    67         { "y", new Interval(0, 1) }
    68       };
    69       var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
    70       // TODO
    71       // Assert.AreEqual(XXX, r.LowerBound);
    72       // Assert.AreEqual(XXX, r.UpperBound);
    73       //
    74       // Assert.AreEqual(XXX, lg[0]); // x
    75       // Assert.AreEqual(XXX, ug[0]);
    76       //
    77       // for  { "x", new Interval(1, 2) },
    78       //   { "y", new Interval(0, 1) },
    79       //
    80       // 0 <> -2,50012500572888E-05 for y in SQR(LOG('y'))
    81       // 0 <> 2, 49987500573946E-05 for x in SQR(LOG('x'))
     70      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     71      Assert.AreEqual(1, r.LowerBound);
     72      Assert.AreEqual(4, r.UpperBound);
     73
     74      Assert.AreEqual(2.0, lg[0]); // x
     75      Assert.AreEqual(8.0, ug[0]);
     76
     77      t = parser.Parse("sqr(log(unit))");
     78      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     79      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
     80      Assert.AreEqual(0.0, r.LowerBound);
     81      Assert.AreEqual(double.PositiveInfinity, r.UpperBound);
     82
     83      Assert.AreEqual(0.0, lg[0]); // x
     84      Assert.AreEqual(double.NaN, ug[0]);
     85
    8286    }
    8387
     
    126130      var eval = new IntervalEvaluator();
    127131      var parser = new InfixExpressionParser();
     132      var intervals = new Dictionary<string, Interval>() {
     133        { "x", new Interval(3, 4) },
     134        { "z", new Interval(1, 2) }
     135      };
    128136      var t = parser.Parse("cos(x)");
    129137      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
    130       var intervals = new Dictionary<string, Interval>() {
    131         { "x", new Interval(3, 4) },
    132       };
    133138      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
    134139      Assert.AreEqual(-1, r.LowerBound); //  3..4 crosses pi and cos(pi) == -1
     
    137142      Assert.AreEqual(0, lg[0]); // x
    138143      Assert.AreEqual(-4 * Math.Sin(4), ug[0]);
     144
     145      t = parser.Parse("LOG(COS('z'))");
     146      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     147      r = eval.Evaluate(t, intervals, paramNodes, out  lg, out  ug);
     148      Assert.AreEqual(double.NaN, r.LowerBound);
     149      Assert.AreEqual(Math.Log(Math.Cos(1)), r.UpperBound);
     150
     151      Assert.AreEqual(-2 * Math.Sin(2) / Math.Cos(2), lg[0], 1e-5); // x
     152      Assert.AreEqual(-1 * Math.Sin(1) / Math.Cos(1), ug[0], 1e-5);
     153     
    139154    }
    140155
     
    182197      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
    183198
    184       Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-6); // z          --> lim x -> 0 (sqrt(x)) = 0
     199      Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-5); // --> lim x -> 0 (sqrt(x)) = 0
    185200      Assert.AreEqual(0.5, ug[0], 1e-5);
     201
     202      t = parser.Parse("sqrt(y - z)"); // 1..2 - 0..1
     203      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     204      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
     205      Assert.AreEqual(0, r.LowerBound);
     206      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);
     207
     208      Assert.AreEqual(double.PositiveInfinity, lg[0], 1e-5); // y
     209      Assert.AreEqual(1/ Math.Sqrt(2)  , ug[0], 1e-5);
     210      Assert.AreEqual(double.NegativeInfinity, lg[1], 1e-5); // z
     211      Assert.AreEqual(0.0   , ug[1], 1e-5);
    186212    }
    187213
Note: See TracChangeset for help on using the changeset viewer.