Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
02/11/19 14:15:47 (5 years ago)
Author:
gkronber
Message:

#2925: made some adaptations while debugging parameter identification for dynamical models

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs

    r16399 r16597  
    274274        alglib.minlbfgscreate(Math.Min(theta.Length, 5), theta, out state);
    275275        alglib.minlbfgssetcond(state, 0.0, 0.0, 0.0, maxParameterOptIterations);
    276         //alglib.minlbfgssetgradientcheck(state, 1e-6);
     276        // alglib.minlbfgssetgradientcheck(state, 1e-4);
    277277        alglib.minlbfgsoptimize(state, EvaluateObjectiveAndGradient, null,
    278278          new object[] { trees, targetVars, problemData, targetValues, episodes.ToArray(), numericIntegrationSteps, latentVariables, odeSolver }); //TODO: create a type
     
    306306                          * NFEV countains number of function calculations
    307307         */
    308         if (report.terminationtype < 0) { nmse = 10E6; return; }
     308        if (report.terminationtype < 0) { nmse = 10.0; return; }
    309309      }
    310310
     
    314314      EvaluateObjectiveAndGradient(optTheta, ref nmse, grad,
    315315        new object[] { trees, targetVars, problemData, targetValues, episodes.ToArray(), numericIntegrationSteps, latentVariables, odeSolver });
    316       if (double.IsNaN(nmse) || double.IsInfinity(nmse)) { nmse = 10E6; return; } // return a large value (TODO: be consistent by using NMSE)
     316      if (double.IsNaN(nmse) || double.IsInfinity(nmse)) { nmse = 10.0; return; } // return a large value (TODO: be consistent by using NMSE)
    317317    }
    318318
     
    342342
    343343      if (predicted.Length != targetValues.GetLength(0)) {
    344         f = double.MaxValue;
     344        f = 10.0; // TODO
    345345        Array.Clear(grad, 0, grad.Length);
    346346        return;
     
    349349      // for normalized MSE = 1/variance(t) * MSE(t, pred)
    350350      // TODO: Perf. (by standardization of target variables before evaluation of all trees)     
    351       var invVar = Enumerable.Range(0, targetVariables.Length)
    352         .Select(c => Enumerable.Range(0, targetValues.GetLength(0)).Select(row => targetValues[row, c])) // column vectors
    353         .Select(vec => vec.Variance())
    354         .Select(v => 1.0 / v)
    355         .ToArray();
     351      // var invVar = Enumerable.Range(0, targetVariables.Length)
     352      //   .Select(c => Enumerable.Range(0, targetValues.GetLength(0)).Select(row => targetValues[row, c])) // column vectors
     353      //   .Select(vec => vec.StandardDeviation()) // TODO: variance of stddev
     354      //   .Select(v => 1.0 / v)
     355      //   .ToArray();
     356
     357      double[] invVar = Enumerable.Repeat(1.0, targetVariables.Length).ToArray();
     358
    356359
    357360      // objective function is NMSE
     
    370373          var res = (y - y_pred_f);
    371374          var ressq = res * res;
    372           f += ressq * invN * invVar[c];
    373           g += -2.0 * res * y_pred[c].Item2 * invN * invVar[c];
     375          f += ressq * invN * invVar[c] /* * Math.Exp(-0.2 * r) */ ;
     376          g += -2.0 * res * y_pred[c].Item2 * invN * invVar[c] /* * Math.Exp(-0.2 * r) */;
    374377        }
    375378        r++;
     
    396399      if (!results.ContainsKey("Solution")) {
    397400        results.Add(new Result("Solution", typeof(Solution)));
     401      }
     402      if (!results.ContainsKey("Squared error and gradient")) {
     403        results.Add(new Result("Squared error and gradient", typeof(DataTable)));
    398404      }
    399405
     
    477483            trainingDataTable.Rows.Add(actualValuesRow);
    478484            trainingDataTable.Rows.Add(predictedValuesRow);
     485
     486            for (int paramIdx = 0; paramIdx < optTheta.Length; paramIdx++) {
     487              var paramSensitivityRow = new DataRow($"∂{targetVar}/∂θ{paramIdx}", $"Sensitivities of parameter {paramIdx}", trainingPrediction.Select(arr => arr[colIdx].Item2[paramIdx]).ToArray());
     488              paramSensitivityRow.VisualProperties.SecondYAxis = true;
     489              trainingDataTable.Rows.Add(paramSensitivityRow);
     490            }
    479491            trainingList.Add(trainingDataTable);
    480492          } else {
     
    488500          }
    489501        }
     502
     503        var errorTable = new DataTable("Squared error and gradient");
     504        var seRow = new DataRow("Squared error");
     505        var gradientRows = Enumerable.Range(0, optTheta.Length).Select(i => new DataRow($"∂SE/∂θ{i}")).ToArray();
     506        errorTable.Rows.Add(seRow);
     507        foreach (var gRow in gradientRows) {
     508          gRow.VisualProperties.SecondYAxis = true;
     509          errorTable.Rows.Add(gRow);
     510        }
     511        var targetValues = targetVars.Select(v => problemData.Dataset.GetDoubleValues(v, trainingRows).ToArray()).ToArray();
     512        int r = 0;
     513        double invN = 1.0 / trainingRows.Count();
     514        foreach (var y_pred in trainingPrediction) {
     515          // calculate objective function gradient
     516          double f_i = 0.0;
     517          Vector g_i = Vector.CreateNew(new double[optTheta.Length]);
     518          for (int colIdx = 0; colIdx < targetVars.Length; colIdx++) {
     519            var y_pred_f = y_pred[colIdx].Item1;
     520            var y = targetValues[colIdx][r];
     521
     522            var res = (y - y_pred_f);
     523            var ressq = res * res;
     524            f_i += ressq * invN /* * Math.Exp(-0.2 * r) */;
     525            g_i = g_i - 2.0 * res * y_pred[colIdx].Item2 * invN /* * Math.Exp(-0.2 * r)*/;
     526          }
     527          seRow.Values.Add(f_i);
     528          for (int j = 0; j < g_i.Length; j++) gradientRows[j].Values.Add(g_i[j]);
     529          r++;
     530        }
     531        results["Squared error and gradient"].Value = errorTable;
     532
    490533        // TODO: DRY for training and test
    491534        var testList = new ItemList<DataTable>();
     
    659702            IntegrateHL(trees, calculatedVariables, variableValues, parameterValues, numericIntegrationSteps);
    660703          else if (odeSolver == "CVODES")
    661             IntegrateCVODES(trees, calculatedVariables, variableValues, parameterValues, t - prevT);
     704            throw new NotImplementedException();
     705          // IntegrateCVODES(trees, calculatedVariables, variableValues, parameterValues, t - prevT);
    662706          else throw new InvalidOperationException("Unknown ODE solver " + odeSolver);
    663707          prevT = t;
     
    687731    #region CVODES
    688732
    689 
     733    /*
    690734    /// <summary>
    691735    ///  Here we use CVODES to solve the ODE. Forward sensitivities are used to calculate the gradient for parameter optimization
     
    9811025        };
    9821026    }
     1027    */
    9831028    #endregion
    9841029
     
    10121057      }
    10131058
     1059      double[] deltaF = new double[calculatedVariables.Length];
     1060      Vector[] deltaG = new Vector[calculatedVariables.Length];
    10141061
    10151062      double h = 1.0 / numericIntegrationSteps;
    10161063      for (int step = 0; step < numericIntegrationSteps; step++) {
    1017         var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
     1064        //var deltaValues = new Dictionary<string, Tuple<double, Vector>>();
    10181065        for (int i = 0; i < trees.Length; i++) {
    10191066          var tree = trees[i];
     
    10211068
    10221069          // Root.GetSubtree(0).GetSubtree(0) skips programRoot and startSymbol
    1023           var res = InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), nodeValues);
    1024           deltaValues.Add(targetVarName, res);
     1070          double f; Vector g;
     1071          InterpretRec(tree.Root.GetSubtree(0).GetSubtree(0), nodeValues, out f, out g);
     1072          deltaF[i] = f;
     1073          deltaG[i] = g;
    10251074        }
    10261075
    10271076        // update variableValues for next step, trapezoid integration
    1028         foreach (var kvp in deltaValues) {
    1029           var oldVal = variableValues[kvp.Key];
     1077        for (int i = 0; i < trees.Length; i++) {
     1078          var varName = calculatedVariables[i];
     1079          var oldVal = variableValues[varName];
    10301080          var newVal = Tuple.Create(
    1031             oldVal.Item1 + h * kvp.Value.Item1,
    1032             oldVal.Item2 + h * kvp.Value.Item2
     1081            oldVal.Item1 + h * deltaF[i],
     1082            oldVal.Item2 + deltaG[i].Scale(h)
    10331083          );
    1034           variableValues[kvp.Key] = newVal;
    1035         }
    1036 
    1037 
     1084          variableValues[varName] = newVal;
     1085        }
     1086
     1087        // TODO perf
    10381088        foreach (var node in nodeValues.Keys.ToArray()) {
    10391089          if (node.SubtreeCount == 0 && !IsConstantNode(node)) {
     
    10461096    }
    10471097
    1048     private static Tuple<double, Vector> InterpretRec(
     1098    private static void InterpretRec(
    10491099      ISymbolicExpressionTreeNode node,
    1050       Dictionary<ISymbolicExpressionTreeNode, Tuple<double, Vector>> nodeValues      // contains value and gradient vector for a node (variables and constants only)
    1051         ) {
    1052 
     1100      Dictionary<ISymbolicExpressionTreeNode, Tuple<double, Vector>> nodeValues,      // contains value and gradient vector for a node (variables and constants only)
     1101      out double f,
     1102      out Vector g
     1103      ) {
     1104      double fl, fr;
     1105      Vector gl, gr;
    10531106      switch (node.Symbol.Name) {
    10541107        case "+": {
    1055             var l = InterpretRec(node.GetSubtree(0), nodeValues);
    1056             var r = InterpretRec(node.GetSubtree(1), nodeValues);
    1057 
    1058             return Tuple.Create(l.Item1 + r.Item1, l.Item2 + r.Item2);
     1108            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1109            InterpretRec(node.GetSubtree(1), nodeValues, out fr, out gr);
     1110            f = fl + fr;
     1111            g = Vector.AddTo(gl, gr);
     1112            break;
    10591113          }
    10601114        case "*": {
    1061             var l = InterpretRec(node.GetSubtree(0), nodeValues);
    1062             var r = InterpretRec(node.GetSubtree(1), nodeValues);
    1063 
    1064             return Tuple.Create(l.Item1 * r.Item1, l.Item2 * r.Item1 + l.Item1 * r.Item2);
     1115            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1116            InterpretRec(node.GetSubtree(1), nodeValues, out fr, out gr);
     1117            f = fl * fr;
     1118            g = Vector.AddTo(gl.Scale(fr), gr.Scale(fl)); // f'*g + f*g'
     1119            break;
    10651120          }
    10661121
    10671122        case "-": {
    1068             var l = InterpretRec(node.GetSubtree(0), nodeValues);
    1069             var r = InterpretRec(node.GetSubtree(1), nodeValues);
    1070 
    1071             return Tuple.Create(l.Item1 - r.Item1, l.Item2 - r.Item2);
     1123            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1124            InterpretRec(node.GetSubtree(1), nodeValues, out fr, out gr);
     1125            f = fl - fr;
     1126            g = Vector.Subtract(gl, gr);
     1127            break;
    10721128          }
    10731129        case "%": {
    1074             var l = InterpretRec(node.GetSubtree(0), nodeValues);
    1075             var r = InterpretRec(node.GetSubtree(1), nodeValues);
     1130            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1131            InterpretRec(node.GetSubtree(1), nodeValues, out fr, out gr);
    10761132
    10771133            // protected division
    1078             if (r.Item1.IsAlmost(0.0)) {
    1079               return Tuple.Create(0.0, Vector.Zero);
     1134            if (fr.IsAlmost(0.0)) {
     1135              f = 0;
     1136              g = Vector.Zero;
    10801137            } else {
    1081               return Tuple.Create(
    1082                 l.Item1 / r.Item1,
    1083                 l.Item1 * -1.0 / (r.Item1 * r.Item1) * r.Item2 + 1.0 / r.Item1 * l.Item2 // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
    1084                 );
     1138              f = fl / fr;
     1139              g = Vector.AddTo(gr.Scale(fl * -1.0 / (fr * fr)), gl.Scale(1.0 / fr)); // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
    10851140            }
     1141            break;
    10861142          }
    10871143        case "sin": {
    1088             var x = InterpretRec(node.GetSubtree(0), nodeValues);
    1089             return Tuple.Create(
    1090               Math.Sin(x.Item1),
    1091               Vector.Cos(x.Item2) * x.Item2
    1092             );
     1144            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1145            f = Math.Sin(fl);
     1146            g = gl.Scale(Math.Cos(fl));
     1147            break;
    10931148          }
    10941149        case "cos": {
    1095             var x = InterpretRec(node.GetSubtree(0), nodeValues);
    1096             return Tuple.Create(
    1097               Math.Cos(x.Item1),
    1098               -Vector.Sin(x.Item2) * x.Item2
    1099             );
     1150            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1151            f = Math.Cos(fl);
     1152            g = gl.Scale(-Math.Sin(fl));
     1153            break;
    11001154          }
    11011155        case "sqr": {
    1102             var x = InterpretRec(node.GetSubtree(0), nodeValues);
    1103             return Tuple.Create(
    1104               x.Item1 * x.Item1,
    1105               2.0 * x.Item1 * x.Item2
    1106             );
     1156            InterpretRec(node.GetSubtree(0), nodeValues, out fl, out gl);
     1157            f = fl * fl;
     1158            g = gl.Scale(2.0 * fl);
     1159            break;
    11071160          }
    11081161        default: {
    1109             return nodeValues[node];  // value and gradient for constants and variables must be set by the caller
     1162            var t = nodeValues[node];
     1163            f = t.Item1;
     1164            g = Vector.CreateNew(t.Item2);
     1165            break;
    11101166          }
    11111167      }
     
    12231279      var newVariablesList = new CheckedItemList<StringValue>(ProblemData.Dataset.VariableNames.Select(str => new StringValue(str).AsReadOnly()).ToArray()).AsReadOnly();
    12241280      var matchingItems = newVariablesList.Where(item => currentlySelectedVariables.Contains(item.Value)).ToArray();
    1225       foreach (var matchingItem in matchingItems) {
    1226         newVariablesList.SetItemCheckedState(matchingItem, true);
     1281      foreach (var item in newVariablesList) {
     1282        if (currentlySelectedVariables.Contains(item.Value)) {
     1283          newVariablesList.SetItemCheckedState(item, true);
     1284        } else {
     1285          newVariablesList.SetItemCheckedState(item, false);
     1286        }
    12271287      }
    12281288      TargetVariablesParameter.Value = newVariablesList;
     
    12441304      // whenever ProblemData is changed we create a new grammar with the necessary symbols
    12451305      var g = new SimpleSymbolicExpressionGrammar();
    1246       g.AddSymbols(FunctionSet.CheckedItems.OrderBy(i => i.Index).Select(i => i.Value.Value).ToArray(), 2, 2);
    1247 
    1248       // TODO
    1249       //g.AddSymbols(new[] {
    1250       //  "exp",
    1251       //  "log", // log( <expr> ) // TODO: init a theta to ensure the value is always positive
    1252       //  "exp_minus" // exp((-1) * <expr>
    1253       //}, 1, 1);
     1306      var unaryFunc = new string[] { "sin", "cos", "sqr" };
     1307      var binaryFunc = new string[] { "+", "-", "*", "%" };
     1308      foreach (var func in unaryFunc) {
     1309        if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 1, 1);
     1310      }
     1311      foreach (var func in binaryFunc) {
     1312        if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 2, 2);
     1313      }
    12541314
    12551315      foreach (var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value.Value)))
Note: See TracChangeset for help on using the changeset viewer.