Free cookie consent management tool by TermsFeed Policy Generator

Changeset 16999


Ignore:
Timestamp:
05/31/19 13:55:51 (5 years ago)
Author:
gkronber
Message:

#2925: Added optimization of weights for variables and added an integration method which uses CVODES to integrate over the whole episode (without input variables)

Location:
branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/HeuristicLab.Problems.DynamicalSystemsModelling-3.3.csproj

    r16976 r16999  
    5858    <CodeAnalysisRuleSet>AllRules.ruleset</CodeAnalysisRuleSet>
    5959    <Prefer32Bit>false</Prefer32Bit>
     60    <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
    6061  </PropertyGroup>
    6162  <PropertyGroup Condition=" '$(Configuration)|$(Platform)' == 'Debug|x64' ">
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs

    r16976 r16999  
    229229      Parameters.Add(new FixedValueParameter<DoubleValue>("Numeric differences smoothing", "Determines the amount of smoothing for the numeric differences which are calculated for pre-tuning. Values from -8 to 8 are reasonable. Use very low value if the data contains no noise. Default: 2.", new DoubleValue(2.0)));
    230230
    231       var solversStr = new string[] { "HeuristicLab", "CVODES" };
     231      var solversStr = new string[] { "HeuristicLab", "CVODES", "CVODES (full)" };
    232232      var solvers = new ItemSet<StringValue>(
    233233        solversStr.Select(s => new StringValue(s).AsReadOnly())
     
    238238      InitAllParameters();
    239239
    240       // TODO: use training range as default training episode
    241240      // TODO: optimization of starting values for latent variables in CVODES solver
    242241      // TODO: allow to specify the name for the time variable in the dataset and allow variable step-sizes
     
    249248      var targetVars = TargetVariables.CheckedItems.OrderBy(i => i.Index).Select(i => i.Value.Value).ToArray();
    250249      var latentVariables = Enumerable.Range(1, NumberOfLatentVariables).Select(i => "λ" + i).ToArray(); // TODO: must coincide with the variables which are actually defined in the grammar and also for which we actually have trees
     250      if (latentVariables.Any()) throw new NotSupportedException("latent variables are not supported"); // not sure if everything still works in combination with latent variables
    251251      if (OptimizeParametersForEpisodes) {
    252252        throw new NotImplementedException();
     
    272272        double nmse = OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, trainingEpisodes, MaximumPretuningParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver, MaximumOdeParameterOptimizationIterations,
    273273          PretuningErrorWeight.Value.Value, OdeErrorWeight.Value.Value, NumericDifferencesSmoothing);
    274         // individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
    275274        return nmse;
    276275      }
     
    298297      var targetVariableTrees = trees.Take(targetVars.Length).ToArray();
    299298      var latentVariableTrees = trees.Skip(targetVars.Length).ToArray();
    300       var constantNodes = targetVariableTrees.Select(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray()).ToArray();
    301       var initialTheta = constantNodes.Select(nodes => nodes.Select(n => n.Value).ToArray()).ToArray();
     299      // var constantNodes = targetVariableTrees.Select(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray()).ToArray();
     300      // var initialTheta = constantNodes.Select(nodes => nodes.Select(n => n.Value).ToArray()).ToArray();
     301      var constantNodes = targetVariableTrees.Select(
     302        t => t.IterateNodesPrefix()
     303        .Where(n => n.SubtreeCount == 0) // select leaves
     304        .ToArray()).ToArray();
     305      var initialTheta = constantNodes.Select(
     306        a => a.Select(
     307          n => {
     308            if (n is VariableTreeNode varTreeNode) {
     309              return varTreeNode.Weight;
     310            } else if (n is ConstantTreeNode constTreeNode) {
     311              return constTreeNode.Value;
     312            } else throw new InvalidProgramException();
     313          }).ToArray()).ToArray();
    302314
    303315      // optimize parameters by fitting f(x,y) to calculated differences dy/dt(t)
     
    313325      pretunedParameters = pretunedParameters
    314326        .Concat(latentVariableTrees
    315         .SelectMany(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().Select(n => n.Value)))
     327        .SelectMany(t => t.IterateNodesPrefix()
     328        .Where(n => n.SubtreeCount == 0)
     329        .Select(n => {
     330          if (n is VariableTreeNode varTreeNode) {
     331            return varTreeNode.Weight;
     332          } else if (n is ConstantTreeNode constTreeNode) {
     333            return constTreeNode.Value;
     334          } else throw new InvalidProgramException();
     335        })))
    316336        .ToArray();
    317337
     
    332352      var paramIdx = 0;
    333353      for (var treeIdx = 0; treeIdx < constantNodes.Length; treeIdx++) {
    334         for (int i = 0; i < constantNodes[treeIdx].Length; i++)
    335           constantNodes[treeIdx][i].Value = optTheta[paramIdx++];
     354        for (int i = 0; i < constantNodes[treeIdx].Length; i++) {
     355          if (constantNodes[treeIdx][i] is VariableTreeNode varTreeNode) {
     356            varTreeNode.Weight = optTheta[paramIdx++];
     357          } else if (constantNodes[treeIdx][i] is ConstantTreeNode constTreeNode) {
     358            constTreeNode.Value = optTheta[paramIdx++];
     359          }
     360        }
    336361      }
    337362      return nmse;
     
    359384      if (latentVariables.Length > 0) {
    360385        var inputVariables = targetVars.Concat(latentTrees.SelectMany(t => t.IterateNodesPrefix().OfType<VariableTreeNode>().Select(n => n.VariableName))).Except(latentVariables).Distinct();
    361         var myState = new OptimizationData(latentTrees, targetVars, inputVariables.ToArray(), problemData, null, episodes.ToArray(), 10, latentVariables, "HeuristicLab");
     386        var myState = new OptimizationData(latentTrees, targetVars, inputVariables.ToArray(), problemData, null, episodes.ToArray(), 10, latentVariables, "NONE");
    362387
    363388        var fi = new double[myState.rows.Length * targetVars.Length];
     
    527552        foreach (var variable in variables) {
    528553          // in this problem we also allow fixed numeric parameters (represented as variables with the value as name)
    529           if (double.TryParse(variable, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
    530             nodeValueLookup.SetVariableValue(variable, value); // TODO: Perf we don't need to set this for each index
    531           } else {
    532             nodeValueLookup.SetVariableValue(variable, ds.GetDoubleValue(variable, rows[trainIdx])); // TODO: perf
    533           }
     554          // if (double.TryParse(variable, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
     555          //   nodeValueLookup.SetVariableValue(variable, value); // TODO: Perf we don't need to set this for each index
     556          // } else {
     557          nodeValueLookup.SetVariableValue(variable, ds.GetDoubleValue(variable, rows[trainIdx])); // TODO: perf
     558          // }
    534559        }
    535560        // interpret all trees
     
    561586        foreach (var variable in variables) {
    562587          // in this problem we also allow fixed numeric parameters (represented as variables with the value as name)
    563           if (double.TryParse(variable, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
    564             nodeValueLookup.SetVariableValue(variable, value); // TODO: Perf we don't need to set this for each index
    565           } else {
    566             nodeValueLookup.SetVariableValue(variable, ds.GetDoubleValue(variable, rows[trainIdx])); // TODO: perf
    567           }
     588          // if (double.TryParse(variable, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
     589          //   nodeValueLookup.SetVariableValue(variable, value); // TODO: Perf we don't need to set this for each index
     590          // } else {
     591          nodeValueLookup.SetVariableValue(variable, ds.GetDoubleValue(variable, rows[trainIdx])); // TODO: perf
     592          // }
    568593        }
    569594
     
    658683
    659684      var bestIndividualAndQuality = this.GetBestIndividual(individuals, qualities);
    660       var trees = bestIndividualAndQuality.Item1.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
     685      var trees = bestIndividualAndQuality.Item1.Values.Select(v => v.Value)
     686        .OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
    661687
    662688      results["SNMSE"].Value = new DoubleValue(bestIndividualAndQuality.Item2);
     
    966992        foreach (var varName in inputVariables) {
    967993          // in this problem we also allow fixed numeric parameters (represented as variables with the value as name)
    968           if (double.TryParse(varName, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
    969             nodeValues.SetVariableValue(varName, value, Vector.Zero);
    970           } else {
    971             var y0 = dataset.GetDoubleValue(varName, t0);
    972             nodeValues.SetVariableValue(varName, y0, Vector.Zero);
    973           }
     994          // if (double.TryParse(varName, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
     995          //   nodeValues.SetVariableValue(varName, value, Vector.Zero);
     996          // } else {
     997          var y0 = dataset.GetDoubleValue(varName, t0);
     998          nodeValues.SetVariableValue(varName, y0, Vector.Zero);
     999          //}
    9741000        }
    9751001        foreach (var varName in targetVariables) {
     
    10131039
    10141040        var prevT = t0; // TODO: here we should use a variable for t if it is available. Right now we assume equidistant measurements.
    1015         foreach (var t in rows.Skip(1)) {
    1016           if (odeSolver == "HeuristicLab")
    1017             IntegrateHL(trees, calculatedVariables, nodeValues, numericIntegrationSteps); // integrator updates nodeValues
    1018           else if (odeSolver == "CVODES")
    1019             IntegrateCVODES(trees, calculatedVariables, nodeValues);
    1020           else throw new InvalidOperationException("Unknown ODE solver " + odeSolver);
    1021           prevT = t;
    1022 
    1023           // update output for target variables (TODO: if we want to visualize the latent variables then we need to provide a separate output)
    1024           for (int i = 0; i < targetVariables.Length; i++) {
    1025             var targetVar = targetVariables[i];
    1026             var yt = nodeValues.GetVariableValue(targetVar);
    1027 
    1028             // fill up remaining rows with last valid value if there are invalid values
    1029             if (double.IsNaN(yt.Item1) || double.IsInfinity(yt.Item1)) {
    1030               for (; outputRowIdx < fi.Length; outputRowIdx++) {
    1031                 var prevIdx = outputRowIdx - targetVariables.Length;
    1032                 fi[outputRowIdx] = fi[prevIdx]; // current <- prev
    1033                 if (jac != null) for (int j = 0; j < jac.GetLength(1); j++) jac[outputRowIdx, j] = jac[prevIdx, j];
     1041        if (odeSolver == "CVODES (full)") {
     1042          IntegrateCVODES(trees, calculatedVariables, nodeValues, rows, fi, jac);
     1043        } else {
     1044
     1045          foreach (var t in rows.Skip(1)) {
     1046            if (odeSolver == "HeuristicLab")
     1047              IntegrateHL(trees, calculatedVariables, nodeValues, numericIntegrationSteps); // integrator updates nodeValues
     1048            else if (odeSolver == "CVODES")
     1049              IntegrateCVODES(trees, calculatedVariables, nodeValues);
     1050            else throw new InvalidOperationException("Unknown ODE solver " + odeSolver);
     1051            prevT = t;
     1052
     1053            // update output for target variables (TODO: if we want to visualize the latent variables then we need to provide a separate output)
     1054            for (int i = 0; i < targetVariables.Length; i++) {
     1055              var targetVar = targetVariables[i];
     1056              var yt = nodeValues.GetVariableValue(targetVar);
     1057
     1058              // fill up remaining rows with last valid value if there are invalid values
     1059              if (double.IsNaN(yt.Item1) || double.IsInfinity(yt.Item1)) {
     1060                for (; outputRowIdx < fi.Length; outputRowIdx++) {
     1061                  var prevIdx = outputRowIdx - targetVariables.Length;
     1062                  fi[outputRowIdx] = fi[prevIdx]; // current <- prev
     1063                  if (jac != null) for (int j = 0; j < jac.GetLength(1); j++) jac[outputRowIdx, j] = jac[prevIdx, j];
     1064                }
     1065                return;
     1066              };
     1067
     1068              fi[outputRowIdx] = yt.Item1;
     1069              var g = yt.Item2;
     1070              g.CopyTo(jac, outputRowIdx);
     1071              outputRowIdx++;
     1072            }
     1073            if (latentValues != null) {
     1074              foreach (var latentVariable in latentVariables) {
     1075                var lt = nodeValues.GetVariableValue(latentVariable).Item1;
     1076                latentValues[latentValueRowIdx, latentValueColIdx++] = lt;
    10341077              }
    1035               return;
    1036             };
    1037 
    1038             fi[outputRowIdx] = yt.Item1;
    1039             var g = yt.Item2;
    1040             g.CopyTo(jac, outputRowIdx);
    1041             outputRowIdx++;
    1042           }
    1043           if (latentValues != null) {
    1044             foreach (var latentVariable in latentVariables) {
    1045               var lt = nodeValues.GetVariableValue(latentVariable).Item1;
    1046               latentValues[latentValueRowIdx, latentValueColIdx++] = lt;
     1078              latentValueRowIdx++; latentValueColIdx = 0;
    10471079            }
    1048             latentValueRowIdx++; latentValueColIdx = 0;
    1049           }
    1050 
    1051           // update for next time step (only the inputs)
    1052           foreach (var varName in inputVariables) {
    1053             // in this problem we also allow fixed numeric parameters (represented as variables with the value as name)
    1054             if (double.TryParse(varName, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
    1055               // value is unchanged
    1056             } else {
     1080
     1081            // update for next time step (only the inputs)
     1082            foreach (var varName in inputVariables) {
     1083              // in this problem we also allow fixed numeric parameters (represented as variables with the value as name)
     1084              // if (double.TryParse(varName, NumberStyles.Float, CultureInfo.InvariantCulture, out double value)) {
     1085              //   // value is unchanged
     1086              // } else {
    10571087              nodeValues.SetVariableValue(varName, dataset.GetDoubleValue(varName, t), Vector.Zero);
     1088              // }
    10581089            }
    10591090          }
     
    11831214    }
    11841215
     1216    /// <summary>
     1217    ///  Here we use CVODES to solve the ODE. Forward sensitivities are used to calculate the gradient for parameter optimization
     1218    /// </summary>
     1219    /// <param name="trees">Each equation in the ODE represented as a tree</param>
     1220    /// <param name="calculatedVariables">The names of the calculated variables</param>
     1221    /// <param name="t">The time t up to which we need to integrate.</param>
     1222    private static void IntegrateCVODES(
     1223      ISymbolicExpressionTree[] trees, // f(y,p) in tree representation
     1224      string[] calculatedVariables, // names of elements of y
     1225      NodeValueLookup nodeValues,
     1226      IEnumerable<int> rows,
     1227      double[] fi,
     1228      double[,] jac
     1229      ) {
     1230
     1231      // the RHS of the ODE
     1232      // dy/dt = f(y_t,x_t,p)
     1233      CVODES.CVRhsFunc f = CreateOdeRhs(trees, calculatedVariables, nodeValues);
     1234
     1235      var calcSens = jac != null;
     1236      // the Jacobian ∂f/∂y
     1237      CVODES.CVDlsJacFunc jacF = CreateJac(trees, calculatedVariables, nodeValues);
     1238
     1239      // the RHS for the forward sensitivities (∂f/∂y)s_i(t) + ∂f/∂p_i
     1240      CVODES.CVSensRhsFn sensF = CreateSensitivityRhs(trees, calculatedVariables, nodeValues);
     1241
     1242      // setup solver
     1243      int numberOfEquations = trees.Length;
     1244      IntPtr y = IntPtr.Zero;
     1245      IntPtr cvode_mem = IntPtr.Zero;
     1246      IntPtr A = IntPtr.Zero;
     1247      IntPtr yS0 = IntPtr.Zero;
     1248      IntPtr linearSolver = IntPtr.Zero;
     1249      var ns = nodeValues.ParameterCount; // number of parameters
     1250
     1251      try {
     1252        y = CVODES.N_VNew_Serial(numberOfEquations);
     1253        // init y to current values of variables
     1254        // y must be initialized before calling CVodeInit
     1255        for (int i = 0; i < calculatedVariables.Length; i++) {
     1256          CVODES.NV_Set_Ith_S(y, i, nodeValues.GetVariableValue(calculatedVariables[i]).Item1);
     1257        }
     1258
     1259        cvode_mem = CVODES.CVodeCreate(CVODES.MultistepMethod.CV_ADAMS, CVODES.NonlinearSolverIteration.CV_FUNCTIONAL);
     1260
     1261        var flag = CVODES.CVodeInit(cvode_mem, f, rows.First(), y);
     1262        Assert(CVODES.CV_SUCCESS == flag);
     1263
     1264        flag = CVODES.CVodeSetErrHandlerFn(cvode_mem, errorFunction, IntPtr.Zero);
     1265        Assert(CVODES.CV_SUCCESS == flag);
     1266        double relTol = 1.0e-2;
     1267        double absTol = 1.0;
     1268        flag = CVODES.CVodeSStolerances(cvode_mem, relTol, absTol);  // TODO: probably need to adjust absTol per variable
     1269        Assert(CVODES.CV_SUCCESS == flag);
     1270
     1271        A = CVODES.SUNDenseMatrix(numberOfEquations, numberOfEquations);
     1272        Assert(A != IntPtr.Zero);
     1273
     1274        linearSolver = CVODES.SUNDenseLinearSolver(y, A);
     1275        Assert(linearSolver != IntPtr.Zero);
     1276
     1277        flag = CVODES.CVDlsSetLinearSolver(cvode_mem, linearSolver, A);
     1278        Assert(CVODES.CV_SUCCESS == flag);
     1279
     1280        flag = CVODES.CVDlsSetJacFn(cvode_mem, jacF);
     1281        Assert(CVODES.CV_SUCCESS == flag);
     1282
     1283        if (calcSens) {
     1284
     1285          yS0 = CVODES.N_VCloneVectorArray_Serial(ns, y); // clone the output vector for each parameter
     1286          unsafe {
     1287            // set to initial sensitivities supplied by caller
     1288            for (int pIdx = 0; pIdx < ns; pIdx++) {
     1289              var yS0_i = *((IntPtr*)yS0.ToPointer() + pIdx);
     1290              for (var varIdx = 0; varIdx < calculatedVariables.Length; varIdx++) {
     1291                CVODES.NV_Set_Ith_S(yS0_i, varIdx, nodeValues.GetVariableValue(calculatedVariables[varIdx]).Item2[pIdx]); // TODO: perf
     1292              }
     1293            }
     1294          }
     1295
     1296          flag = CVODES.CVodeSensInit(cvode_mem, ns, CVODES.CV_SIMULTANEOUS, sensF, yS0);
     1297          Assert(CVODES.CV_SUCCESS == flag);
     1298
     1299          flag = CVODES.CVodeSensEEtolerances(cvode_mem);
     1300          Assert(CVODES.CV_SUCCESS == flag);
     1301        }
     1302        // integrate
     1303        int outputIdx = calculatedVariables.Length; // values at t0 do not need to be set.
     1304        foreach (var tout in rows.Skip(1)) {
     1305          double tret = 0;
     1306          flag = CVODES.CVode(cvode_mem, tout, y, ref tret, CVODES.CV_NORMAL);
     1307          if (flag == CVODES.CV_SUCCESS) {
     1308            // Assert(1.0 == tout);
     1309            if (calcSens) {
     1310              // get sensitivities
     1311              flag = CVODES.CVodeGetSens(cvode_mem, ref tret, yS0);
     1312              Assert(CVODES.CV_SUCCESS == flag);
     1313            }
     1314            // update variableValues based on integration results
     1315            for (int varIdx = 0; varIdx < calculatedVariables.Length; varIdx++) {
     1316              var yi = CVODES.NV_Get_Ith_S(y, varIdx);
     1317              fi[outputIdx] = yi;
     1318              if (calcSens) {
     1319                // var gArr = new double[ns];
     1320                for (var pIdx = 0; pIdx < ns; pIdx++) {
     1321                  unsafe {
     1322                    var yS0_pi = *((IntPtr*)yS0.ToPointer() + pIdx);
     1323                    jac[outputIdx, pIdx] = CVODES.NV_Get_Ith_S(yS0_pi, varIdx);
     1324                  }
     1325                }
     1326              }
     1327              outputIdx++;
     1328            }
     1329
     1330          } else {
     1331            // fill up remaining values
     1332            while (outputIdx < fi.Length) {
     1333              fi[outputIdx] = fi[outputIdx - calculatedVariables.Length];
     1334              if (calcSens) {
     1335                for (var pIdx = 0; pIdx < ns; pIdx++) {
     1336                  jac[outputIdx, pIdx] = jac[outputIdx - calculatedVariables.Length, pIdx];
     1337                }
     1338              }
     1339              outputIdx++;
     1340            }
     1341            return;
     1342          }
     1343        }
     1344
     1345        // cleanup all allocated objects
     1346      } finally {
     1347        if (y != IntPtr.Zero) CVODES.N_VDestroy_Serial(y);
     1348        if (cvode_mem != IntPtr.Zero) CVODES.CVodeFree(ref cvode_mem);
     1349        if (linearSolver != IntPtr.Zero) CVODES.SUNLinSolFree(linearSolver);
     1350        if (A != IntPtr.Zero) CVODES.SUNMatDestroy(A);
     1351        if (yS0 != IntPtr.Zero) CVODES.N_VDestroyVectorArray_Serial(yS0, ns);
     1352      }
     1353    }
     1354
    11851355    private static void errorFunction(int errorCode, IntPtr module, IntPtr function, IntPtr msg, IntPtr ehdata) {
    11861356      var moduleStr = Marshal.PtrToStringAnsi(module);
     
    11881358      var msgStr = Marshal.PtrToStringAnsi(msg);
    11891359      string type = errorCode == 0 ? "Warning" : "Error";
    1190       throw new InvalidProgramException($"{type}: {msgStr} Module: {moduleStr} Function: {functionStr}");
     1360      // throw new InvalidProgramException($"{type}: {msgStr} Module: {moduleStr} Function: {functionStr}");
    11911361    }
    11921362
     
    12571427          InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), nodeValues, out double z, out Vector dz);
    12581428          for (int j = 0; j < calculatedVariables.Length; j++) {
    1259             CVODES.SUNDenseMatrix_Set(Jac, i, j, dz[j]);
     1429            CVODES.SUNDenseMatrix_Set(Jac, i, j, dz[j]);  //TODO: must set as in SensitivityRhs!
    12601430          }
    12611431        }
     
    13751545    // TODO: use an existing interpreter implementation instead
    13761546    private static double InterpretRec(ISymbolicExpressionTreeNode node, NodeValueLookup nodeValues) {
    1377       if (node is ConstantTreeNode) {
    1378         return ((ConstantTreeNode)node).Value;
    1379       } else if (node is VariableTreeNode) {
    1380         return nodeValues.NodeValue(node);
     1547      if (node is ConstantTreeNode constTreeNode) {
     1548        return nodeValues.ConstantNodeValue(constTreeNode);
     1549      } else if (node is VariableTreeNode varTreeNode) {
     1550        return nodeValues.VariableNodeValue(varTreeNode);
    13811551      } else if (node.Symbol is Addition) {
    13821552        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     
    14761646      double f, g;
    14771647      Vector df, dg;
    1478       if (node.Symbol is Constant || node.Symbol is Variable) {
    1479         z = nodeValues.NodeValue(node);
    1480         dz = Vector.CreateNew(nodeValues.NodeGradient(node)); // original gradient vectors are never changed by evaluation
     1648      if (node is ConstantTreeNode constTreeNode) {
     1649        var val = nodeValues.ConstantNodeValueAndGradient(constTreeNode);
     1650        z = val.Item1;
     1651        dz = val.Item2;
     1652      } else if (node is VariableTreeNode varTreeNode) {
     1653        var val = nodeValues.VariableNodeValueAndGradient(varTreeNode);
     1654        z = val.Item1;
     1655        dz = val.Item2;
    14811656      } else if (node.Symbol is Addition) {
    14821657        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     
    17751950    }
    17761951
    1777     private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
    1778       return n is ConstantTreeNode;
    1779     }
    1780     private static double GetConstantValue(ISymbolicExpressionTreeNode n) {
    1781       return ((ConstantTreeNode)n).Value;
    1782     }
    1783     private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
    1784       return n.Symbol.Name[0] == 'λ';
    1785     }
    1786     private static bool IsVariableNode(ISymbolicExpressionTreeNode n) {
    1787       return (n.SubtreeCount == 0) && !IsConstantNode(n) && !IsLatentVariableNode(n);
    1788     }
    1789     private static string GetVariableName(ISymbolicExpressionTreeNode n) {
    1790       return ((VariableTreeNode)n).VariableName;
    1791     }
     1952    // private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
     1953    //   return n.Symbol.Name[0] == 'λ';
     1954    // }
    17921955
    17931956    private void UpdateTargetVariables() {
     
    18932056      // configure initialization of variables
    18942057      var varSy = (Variable)grammar.GetSymbol("Variable");
    1895       // fix variable weights to 1.0
    1896       varSy.WeightMu = 1.0;
    1897       varSy.WeightSigma = 0.0;
     2058      // init variables to a small value and allow manipulation
     2059      varSy.WeightMu = 0.0;
     2060      varSy.WeightSigma = 1e-1;
    18982061      varSy.WeightManipulatorMu = 0.0;
    1899       varSy.WeightManipulatorSigma = 0.0;
    1900       varSy.MultiplicativeWeightManipulatorSigma = 0.0;
     2062      varSy.WeightManipulatorSigma = 1.0;
     2063      varSy.MultiplicativeWeightManipulatorSigma = 1.0;
    19012064
    19022065      foreach (var f in FunctionSet) {
     
    19752138    public class NodeValueLookup {
    19762139      private readonly Dictionary<ISymbolicExpressionTreeNode, Tuple<double, Vector>> node2val = new Dictionary<ISymbolicExpressionTreeNode, Tuple<double, Vector>>();
    1977       private readonly Dictionary<string, List<ISymbolicExpressionTreeNode>> name2nodes = new Dictionary<string, List<ISymbolicExpressionTreeNode>>();
    1978       private readonly ConstantTreeNode[] constantNodes;
     2140      private readonly ISymbolicExpressionTreeNode[] leafNodes;
    19792141      private readonly Vector[] constantGradientVectors;
    1980 
    1981 
    1982       public double NodeValue(ISymbolicExpressionTreeNode node) => node2val[node].Item1;
    1983       public Vector NodeGradient(ISymbolicExpressionTreeNode node) => node2val[node].Item2;
     2142      private readonly Dictionary<string, Tuple<double, Vector>> variableValues = new Dictionary<string, Tuple<double, Vector>>();
     2143
     2144      // accessors for current values of constant and variable nodes. For variable nodes we also need to account for the variable weight
     2145      public double ConstantNodeValue(ConstantTreeNode node) => node2val[node].Item1;
     2146      public Tuple<double, Vector> ConstantNodeValueAndGradient(ConstantTreeNode node) { var v = node2val[node]; return Tuple.Create(v.Item1, Vector.CreateNew(v.Item2)); }
     2147      public double VariableNodeValue(VariableTreeNode node) => variableValues[node.VariableName].Item1 * node2val[node].Item1;
     2148      public Tuple<double, Vector> VariableNodeValueAndGradient(VariableTreeNode node) {
     2149        // (f*g)' = (f'*g)+(g'*f)       
     2150        var g = node2val[node];
     2151        var f = variableValues[node.VariableName];
     2152
     2153        return Tuple.Create(
     2154          g.Item1 * f.Item1,
     2155          Vector.CreateNew(f.Item2).Scale(g.Item1).Add(Vector.CreateNew(g.Item2).Scale(f.Item1)));
     2156      }
    19842157
    19852158      public NodeValueLookup(ISymbolicExpressionTree[] trees, bool variableGradient = false) {
    1986         this.constantNodes = trees.SelectMany(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>()).ToArray();
     2159        this.leafNodes = trees.SelectMany(t => t.IterateNodesPrefix().Where(n => n.SubtreeCount==0)).ToArray();
    19872160        if (!variableGradient) {
    1988           constantGradientVectors = new Vector[constantNodes.Length];
    1989           for (int paramIdx = 0; paramIdx < constantNodes.Length; paramIdx++) {
    1990             constantGradientVectors[paramIdx] = Vector.CreateIndicator(length: constantNodes.Length, idx: paramIdx);
    1991 
    1992             var node = constantNodes[paramIdx];
    1993             node2val[node] = Tuple.Create(node.Value, constantGradientVectors[paramIdx]);
    1994           }
    1995 
    1996           foreach (var tree in trees) {
    1997             foreach (var node in tree.IterateNodesPrefix().Where(IsVariableNode)) {
    1998               var varName = GetVariableName(node);
    1999               if (!name2nodes.TryGetValue(varName, out List<ISymbolicExpressionTreeNode> nodes)) {
    2000                 nodes = new List<ISymbolicExpressionTreeNode>();
    2001                 name2nodes.Add(varName, nodes);
    2002               }
    2003               nodes.Add(node);
    2004               SetVariableValue(varName, 0.0);  // this value is updated in the prediction loop
    2005             }
    2006           }
    2007         }
    2008         else {
     2161          constantGradientVectors = new Vector[leafNodes.Length];
     2162          for (int paramIdx = 0; paramIdx < leafNodes.Length; paramIdx++) {
     2163            constantGradientVectors[paramIdx] = Vector.CreateIndicator(length: leafNodes.Length, idx: paramIdx);
     2164
     2165            var node = leafNodes[paramIdx];
     2166            if (node is ConstantTreeNode constTreeNode) {
     2167              node2val[node] = Tuple.Create(constTreeNode.Value, constantGradientVectors[paramIdx]);
     2168            } else if (node is VariableTreeNode varTreeNode) {
     2169              node2val[node] = Tuple.Create(varTreeNode.Weight, constantGradientVectors[paramIdx]);
     2170            } else throw new InvalidProgramException();
     2171          }
     2172        } else {
    20092173          // variable gradient means we want to calculate the gradient over the target variables instead of parameters
    2010           for (int paramIdx = 0; paramIdx < constantNodes.Length; paramIdx++) {
    2011             var node = constantNodes[paramIdx];
    2012             node2val[node] = Tuple.Create(node.Value, Vector.Zero);
    2013           }
    2014 
    2015           foreach (var tree in trees) {
    2016             foreach (var node in tree.IterateNodesPrefix().Where(IsVariableNode)) {
    2017               var varName = GetVariableName(node);
    2018               if (!name2nodes.TryGetValue(varName, out List<ISymbolicExpressionTreeNode> nodes)) {
    2019                 nodes = new List<ISymbolicExpressionTreeNode>();
    2020                 name2nodes.Add(varName, nodes);
    2021               }
    2022               nodes.Add(node);
    2023               SetVariableValue(varName, 0.0);  // this value is updated in the prediction loop
    2024             }
    2025           }
    2026         }
    2027       }
    2028 
    2029       public int ParameterCount => constantNodes.Length;
     2174          for (int paramIdx = 0; paramIdx < leafNodes.Length; paramIdx++) {
     2175            var node = leafNodes[paramIdx];
     2176            if (node is ConstantTreeNode constTreeNode) {
     2177              node2val[node] = Tuple.Create(constTreeNode.Value, Vector.Zero);
     2178            } else if (node is VariableTreeNode varTreeNode) {
     2179              node2val[node] = Tuple.Create(varTreeNode.Weight, Vector.Zero);
     2180            } else throw new InvalidProgramException();
     2181          }
     2182        }
     2183      }
     2184
     2185      public int ParameterCount => leafNodes.Length;
    20302186
    20312187      public void SetVariableValue(string variableName, double val) {
    20322188        SetVariableValue(variableName, val, Vector.Zero);
    20332189      }
     2190      /// <summary>
     2191      /// returns the current value for variable variableName
     2192      /// </summary>
     2193      /// <param name="variableName"></param>
     2194      /// <returns></returns>
    20342195      public Tuple<double, Vector> GetVariableValue(string variableName) {
    2035         return node2val[name2nodes[variableName].First()];
    2036       }
     2196        return variableValues[variableName];
     2197      }
     2198
     2199      /// <summary>
     2200      /// sets the current value for variable variableName
     2201      /// </summary>
     2202      /// <param name="variableName"></param>
     2203      /// <param name="val"></param>
     2204      /// <param name="dVal"></param>
    20372205      public void SetVariableValue(string variableName, double val, Vector dVal) {
    2038         if (name2nodes.TryGetValue(variableName, out List<ISymbolicExpressionTreeNode> nodes)) {
    2039           nodes.ForEach(n => node2val[n] = Tuple.Create(val, dVal));
    2040         } else {
    2041           var fakeNode = new VariableTreeNode(new Variable());
    2042           fakeNode.Weight = 1.0;
    2043           fakeNode.VariableName = variableName;
    2044           var newNodeList = new List<ISymbolicExpressionTreeNode>();
    2045           newNodeList.Add(fakeNode);
    2046           name2nodes.Add(variableName, newNodeList);
    2047           node2val[fakeNode] = Tuple.Create(val, dVal);
    2048         }
     2206        variableValues[variableName] = Tuple.Create(val, dVal);
     2207        // if (name2nodes.TryGetValue(variableName, out List<ISymbolicExpressionTreeNode> nodes)) {
     2208        //   nodes.ForEach(n => node2val[n] = Tuple.Create(val, dVal));
     2209        // } else {
     2210        //   var fakeNode = new VariableTreeNode(new Variable());
     2211        //   fakeNode.Weight = 1.0;
     2212        //   fakeNode.VariableName = variableName;
     2213        //   var newNodeList = new List<ISymbolicExpressionTreeNode>();
     2214        //   newNodeList.Add(fakeNode);
     2215        //   name2nodes.Add(variableName, newNodeList);
     2216        //   node2val[fakeNode] = Tuple.Create(val, dVal);
     2217        // }
    20492218      }
    20502219
    20512220      internal void UpdateParamValues(double[] x) {
    20522221        for (int i = 0; i < x.Length; i++) {
    2053           constantNodes[i].Value = x[i];
    2054           node2val[constantNodes[i]] = Tuple.Create(x[i], constantGradientVectors[i]);
     2222          node2val[leafNodes[i]] = Tuple.Create(x[i], constantGradientVectors[i]);
    20552223        }
    20562224      }
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/ProblemInstanceProvider.cs

    r16976 r16999  
    293293        TargetVariables = new[] { "x1", "x2", "v1", "v2" },
    294294        InputVariables = new string[] { },
    295         TrainingEpisodes = new IntRange[] { new IntRange(0, 820) },
     295        TrainingEpisodes = new IntRange[] { new IntRange(0, 200) },
    296296        TestEpisodes = new IntRange[] { },
    297297        FileName = "double_linear_h_1_equidistant.txt",
     
    306306        TargetVariables = new[] { "x1", "x2", "v1", "v2" },
    307307        InputVariables = new string[] { },
    308         TrainingEpisodes = new IntRange[] { new IntRange(0, 500) },
     308        TrainingEpisodes = new IntRange[] { new IntRange(0, 150) },
    309309        TestEpisodes = new IntRange[] { },
    310310        FileName = "real_double_linear_h_1_equidistant.txt",
     
    332332        TargetVariables = new[] { "theta1", "theta2", "omega1", "omega2" },
    333333        InputVariables = new string[] { },
    334         TrainingEpisodes = new IntRange[] { new IntRange(0, 886) },
     334        TrainingEpisodes = new IntRange[] { new IntRange(0, 200) },
    335335        TestEpisodes = new IntRange[] {new IntRange(886, 1731) },
    336336        FileName = "real_double_pend_h_1_equidistant.txt",
Note: See TracChangeset for help on using the changeset viewer.