Changeset 16602


Ignore:
Timestamp:
02/13/19 13:43:03 (2 months ago)
Author:
gkronber
Message:

#2925: write back optimized constants to trees

Location:
branches/2925_AutoDiffForDynamicalModels
Files:
4 edited

Legend:

Unmodified
Added
Removed
  • branches/2925_AutoDiffForDynamicalModels/AutoDiffForDynamicalModelsTest/TestOdeIdentification.cs

    r16601 r16602  
    1414      var dynProb = new Problem();
    1515      var parser = new HeuristicLab.Problems.Instances.DataAnalysis.TableFileParser();
    16       var fileName = @"C:\reps\HEAL\EuroCAST - Kronberger\DataGeneration\test.csv";
     16      // var fileName = @"C:\reps\HEAL\EuroCAST - Kronberger\DataGeneration\test.csv";
     17      var fileName = @"D:\heal\documents\trunk\Publications\2019\EuroCAST\Kronberger\DataGeneration\test.csv";
    1718      parser.Parse(fileName, true);
    1819      var prov = new HeuristicLab.Problems.Instances.DataAnalysis.RegressionCSVInstanceProvider();
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/OdeParameterIdentification.cs

    r16597 r16602  
    177177    public void CreateSolution(Problem problem, string[] modelStructure, int maxIterations, IRandom rand) {
    178178      var parser = new InfixExpressionParser();
    179       var trees = modelStructure.Select(expr => Convert(parser.Parse(expr))).ToArray();
     179      var trees = modelStructure.Select(expr => parser.Parse(expr)).ToArray();
    180180      var names = problem.Encoding.Encodings.Select(enc => enc.Name).ToArray();
    181181      if (trees.Length != names.Length) throw new ArgumentException("The number of expressions must match the number of target variables exactly");
     
    190190    }
    191191
    192     private ISymbolicExpressionTree Convert(ISymbolicExpressionTree tree) {
    193       return new SymbolicExpressionTree(Convert(tree.Root));
    194     }
     192    // private ISymbolicExpressionTree Convert(ISymbolicExpressionTree tree) {
     193    //   return new SymbolicExpressionTree(Convert(tree.Root));
     194    // }
    195195
    196196
    197197    // for translation from symbolic expressions to simple symbols
    198     private static Dictionary<Type, string> sym2str = new Dictionary<Type, string>() {
    199       {typeof(Addition), "+" },
    200       {typeof(Subtraction), "-" },
    201       {typeof(Multiplication), "*" },
    202       {typeof(Sine), "sin" },
    203       {typeof(Cosine), "cos" },
    204       {typeof(Square), "sqr" },
    205     };
    206 
    207     private ISymbolicExpressionTreeNode Convert(ISymbolicExpressionTreeNode node) {
    208       if (sym2str.ContainsKey(node.Symbol.GetType())) {
    209         var children = node.Subtrees.Select(st => Convert(st)).ToArray();
    210         return Make(sym2str[node.Symbol.GetType()], children);
    211       } else if (node.Symbol is ProgramRootSymbol) {
    212         var child = Convert(node.GetSubtree(0));
    213         node.RemoveSubtree(0);
    214         node.AddSubtree(child);
    215         return node;
    216       } else if (node.Symbol is StartSymbol) {
    217         var child = Convert(node.GetSubtree(0));
    218         node.RemoveSubtree(0);
    219         node.AddSubtree(child);
    220         return node;
    221       } else if (node.Symbol is Division) {
    222         var children = node.Subtrees.Select(st => Convert(st)).ToArray();
    223         if (children.Length == 1) {
    224           return Make("%", new[] { new SimpleSymbol("θ", 0).CreateTreeNode(), children[0] });
    225         } else if (children.Length != 2) throw new ArgumentException("Division is not supported for multiple arguments");
    226         else return Make("%", children);
    227       } else if (node.Symbol is Constant) {
    228         return new SimpleSymbol("θ", 0).CreateTreeNode();
    229       } else if (node.Symbol is DataAnalysis.Symbolic.Variable) {
    230         var varNode = node as VariableTreeNode;
    231         if (!varNode.Weight.IsAlmost(1.0)) throw new ArgumentException("Variable weights are not supported");
    232         return new SimpleSymbol(varNode.VariableName, 0).CreateTreeNode();
    233       } else throw new ArgumentException("Unsupported symbol: " + node.Symbol.Name);
    234     }
    235 
    236     private ISymbolicExpressionTreeNode Make(string op, ISymbolicExpressionTreeNode[] children) {
    237       if (children.Length == 1) {
    238         var s = new SimpleSymbol(op, 1).CreateTreeNode();
    239         s.AddSubtree(children.First());
    240         return s;
    241       } else {
    242         var s = new SimpleSymbol(op, 2).CreateTreeNode();
    243         var c0 = children[0];
    244         var c1 = children[1];
    245         s.AddSubtree(c0);
    246         s.AddSubtree(c1);
    247         for (int i = 2; i < children.Length; i++) {
    248           var sn = new SimpleSymbol(op, 2).CreateTreeNode();
    249           sn.AddSubtree(s);
    250           sn.AddSubtree(children[i]);
    251           s = sn;
    252         }
    253         return s;
    254       }
    255     }
     198    // private static Dictionary<Type, string> sym2str = new Dictionary<Type, string>() {
     199    //   {typeof(Addition), "+" },
     200    //   {typeof(Subtraction), "-" },
     201    //   {typeof(Multiplication), "*" },
     202    //   {typeof(Sine), "sin" },
     203    //   {typeof(Cosine), "cos" },
     204    //   {typeof(Square), "sqr" },
     205    // };
     206
     207    // private ISymbolicExpressionTreeNode Convert(ISymbolicExpressionTreeNode node) {
     208    //   if (sym2str.ContainsKey(node.Symbol.GetType())) {
     209    //     var children = node.Subtrees.Select(st => Convert(st)).ToArray();
     210    //     return Make(sym2str[node.Symbol.GetType()], children);
     211    //   } else if (node.Symbol is ProgramRootSymbol) {
     212    //     var child = Convert(node.GetSubtree(0));
     213    //     node.RemoveSubtree(0);
     214    //     node.AddSubtree(child);
     215    //     return node;
     216    //   } else if (node.Symbol is StartSymbol) {
     217    //     var child = Convert(node.GetSubtree(0));
     218    //     node.RemoveSubtree(0);
     219    //     node.AddSubtree(child);
     220    //     return node;
     221    //   } else if (node.Symbol is Division) {
     222    //     var children = node.Subtrees.Select(st => Convert(st)).ToArray();
     223    //     if (children.Length == 1) {
     224    //       return Make("%", new[] { new SimpleSymbol("θ", 0).CreateTreeNode(), children[0] });
     225    //     } else if (children.Length != 2) throw new ArgumentException("Division is not supported for multiple arguments");
     226    //     else return Make("%", children);
     227    //   } else if (node.Symbol is Constant) {
     228    //     return new SimpleSymbol("θ", 0).CreateTreeNode();
     229    //   } else if (node.Symbol is DataAnalysis.Symbolic.Variable) {
     230    //     var varNode = node as VariableTreeNode;
     231    //     if (!varNode.Weight.IsAlmost(1.0)) throw new ArgumentException("Variable weights are not supported");
     232    //     return new SimpleSymbol(varNode.VariableName, 0).CreateTreeNode();
     233    //   } else throw new ArgumentException("Unsupported symbol: " + node.Symbol.Name);
     234    // }
     235
     236    // private ISymbolicExpressionTreeNode Make(string op, ISymbolicExpressionTreeNode[] children) {
     237    //   if (children.Length == 1) {
     238    //     var s = new SimpleSymbol(op, 1).CreateTreeNode();
     239    //     s.AddSubtree(children.First());
     240    //     return s;
     241    //   } else {
     242    //     var s = new SimpleSymbol(op, 2).CreateTreeNode();
     243    //     var c0 = children[0];
     244    //     var c1 = children[1];
     245    //     s.AddSubtree(c0);
     246    //     s.AddSubtree(c1);
     247    //     for (int i = 2; i < children.Length; i++) {
     248    //       var sn = new SimpleSymbol(op, 2).CreateTreeNode();
     249    //       sn.AddSubtree(s);
     250    //       sn.AddSubtree(children[i]);
     251    //       s = sn;
     252    //     }
     253    //     return s;
     254    //   }
     255    // }
    256256    #endregion
    257257  }
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs

    r16601 r16602  
    7171      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumLengthParameterName]; }
    7272    }
     73
    7374    public IFixedValueParameter<IntValue> MaximumParameterOptimizationIterationsParameter {
    7475      get { return (IFixedValueParameter<IntValue>)Parameters[MaximumParameterOptimizationIterationsParameterName]; }
     
    196197    public override double Evaluate(Individual individual, IRandom random) {
    197198      var trees = individual.Values.Select(v => v.Value).OfType<ISymbolicExpressionTree>().ToArray(); // extract all trees from individual
    198                                                                                                       // write back optimized parameters to tree nodes instead of the separate OptTheta variable
    199                                                                                                       // retreive optimized parameters from nodes?
    200199
    201200      var problemData = ProblemData;
     
    208207        int totalSize = 0;
    209208        foreach (var episode in TrainingEpisodes) {
    210           double[] optTheta;
    211           double nmse;
    212           OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { episode }, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver, out optTheta, out nmse);
    213           individual["OptTheta_" + eIdx] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
     209          // double[] optTheta;
     210          double nmse = OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { episode }, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver);
     211          // individual["OptTheta_" + eIdx] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
    214212          eIdx++;
    215213          totalNMSE += nmse * episode.Size;
     
    218216        return totalNMSE / totalSize;
    219217      } else {
    220         double[] optTheta;
    221         double nmse;
    222         OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, TrainingEpisodes, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver, out optTheta, out nmse);
    223         individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
     218        // double[] optTheta;
     219        double nmse = OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, TrainingEpisodes, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver);
     220        // individual["OptTheta"] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
    224221        return nmse;
    225222      }
    226223    }
    227224
    228     public static void OptimizeForEpisodes(
     225    public static double OptimizeForEpisodes(
    229226      ISymbolicExpressionTree[] trees,
    230227      IRegressionProblemData problemData,
     
    235232      int maxParameterOptIterations,
    236233      int numericIntegrationSteps,
    237       string odeSolver,
    238       out double[] optTheta,
    239       out double nmse) {
    240 
     234      string odeSolver) {
     235
     236      var constantNodes = trees.Select(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray()).ToArray();
     237      var initialTheta = constantNodes.Select(nodes => nodes.Select(n => n.Value).ToArray()).ToArray();
    241238
    242239      // optimize parameters by fitting f(x,y) to calculated differences dy/dt(t)
    243       nmse = PreTuneParameters(trees, problemData, targetVars, latentVariables, random, episodes, maxParameterOptIterations, out optTheta);
     240      double nmse = PreTuneParameters(trees, problemData, targetVars, latentVariables, random, episodes, maxParameterOptIterations,
     241        initialTheta, out double[] pretunedParameters);
    244242
    245243      // optimize parameters using integration of f(x,y) to calculate y(t)
    246       nmse = OptimizeParameters(trees, problemData, targetVars, latentVariables, episodes, maxParameterOptIterations, optTheta, numericIntegrationSteps, odeSolver, out optTheta);
    247 
    248       if (double.IsNaN(nmse) || double.IsInfinity(nmse)) nmse = 100 * trees.Length * episodes.Sum(ep => ep.Size);
     244      nmse = OptimizeParameters(trees, problemData, targetVars, latentVariables, episodes, maxParameterOptIterations, pretunedParameters, numericIntegrationSteps, odeSolver,
     245        out double[] optTheta);
     246
     247      if (double.IsNaN(nmse) ||
     248        double.IsInfinity(nmse) ||
     249        nmse > 100 * trees.Length * episodes.Sum(ep => ep.Size))
     250        return 100 * trees.Length * episodes.Sum(ep => ep.Size);
     251
     252      // update tree nodes with optimized values
     253      var paramIdx = 0;
     254      for (var treeIdx = 0; treeIdx < constantNodes.Length; treeIdx++) {
     255        for (int i = 0; i < constantNodes[treeIdx].Length; i++)
     256          constantNodes[treeIdx][i].Value = optTheta[paramIdx++];
     257      }
     258      return nmse;
    249259    }
    250260
     
    257267      IEnumerable<IntRange> episodes,
    258268      int maxParameterOptIterations,
     269      double[][] initialTheta,
    259270      out double[] optTheta) {
    260271      var thetas = new List<double>();
    261272      double nmse = 0.0;
     273      var maxTreeNmse = 100 * episodes.Sum(ep => ep.Size);
     274
    262275      // NOTE: the order of values in parameter matches prefix order of constant nodes in trees
    263276      for (int treeIdx = 0; treeIdx < trees.Length; treeIdx++) {
     
    272285        var paramCount = myState.nodeValueLookup.ParameterCount;
    273286
    274         // init params randomly
    275         // theta contains parameter values for trees and then the initial values for latent variables (a separate vector for each episode)
    276         // inital values for latent variables are also optimized
    277         var theta = new double[paramCount + latentVariables.Length * episodes.Count()];
    278         for (int i = 0; i < theta.Length; i++)
    279           theta[i] = random.NextDouble() * 2.0e-1 - 1.0e-1;
    280 
    281287        optTheta = new double[0];
    282         if (theta.Length > 0) {
    283           alglib.minlmstate state;
    284           alglib.minlmreport report;
    285           alglib.minlmcreatevj(targetValuesDiff.Length, theta, out state);
    286           alglib.minlmsetcond(state, 0.0, 0.0, 0.0, maxParameterOptIterations);
    287           // alglib.minlmsetgradientcheck(state, 1.0e-3);
    288           alglib.minlmoptimize(state, EvaluateObjectiveVector, EvaluateObjectiveVectorAndJacobian, null, myState);
    289 
    290           alglib.minlmresults(state, out optTheta, out report);
    291 
    292 
    293           if (report.terminationtype < 0) { throw new InvalidOperationException("there was a problem in the optimizer"); }
    294 
     288        if (initialTheta[treeIdx].Length > 0) {
     289          try {
     290            alglib.minlmstate state;
     291            alglib.minlmreport report;
     292            var p = new double[initialTheta[treeIdx].Length];
     293            var lowerBounds = Enumerable.Repeat(-100.0, p.Length).ToArray();
     294            var upperBounds = Enumerable.Repeat(100.0, p.Length).ToArray();
     295            Array.Copy(initialTheta[treeIdx], p, p.Length);
     296            alglib.minlmcreatevj(targetValuesDiff.Length, p, out state);
     297            alglib.minlmsetcond(state, 0.0, 0.0, 0.0, maxParameterOptIterations);
     298            alglib.minlmsetbc(state, lowerBounds, upperBounds);
     299            // alglib.minlmsetgradientcheck(state, 1.0e-3);
     300            alglib.minlmoptimize(state, EvaluateObjectiveVector, EvaluateObjectiveVectorAndJacobian, null, myState);
     301
     302            alglib.minlmresults(state, out optTheta, out report);
     303            if (report.terminationtype < 0) { optTheta = initialTheta[treeIdx]; }
     304          } catch (alglib.alglibexception) {
     305            optTheta = initialTheta[treeIdx];
     306          }
     307        }
     308        var tree_nmse = EvaluateMSE(optTheta, myState);
     309        if (double.IsNaN(tree_nmse) || double.IsInfinity(tree_nmse) || tree_nmse > maxTreeNmse) {
     310          nmse += maxTreeNmse;
     311          thetas.AddRange(initialTheta[treeIdx]);
     312        } else {
     313          nmse += tree_nmse;
    295314          thetas.AddRange(optTheta);
    296315        }
    297         nmse += EvaluateMSE(optTheta, myState);
    298316      } // foreach tree
    299317      optTheta = thetas.ToArray();
    300318
    301       var maxNmse = 100 * trees.Length * episodes.Sum(ep => ep.Size);
    302       if (double.IsNaN(nmse) || double.IsInfinity(nmse) || nmse > maxNmse) nmse = maxNmse;
    303319      return nmse;
    304320    }
     
    321337
    322338      if (initialTheta.Length > 0) {
    323 
     339        var lowerBounds = Enumerable.Repeat(-100.0, initialTheta.Length).ToArray();
     340        var upperBounds = Enumerable.Repeat(100.0, initialTheta.Length).ToArray();
    324341        try {
    325342          alglib.minlmstate state;
    326343          alglib.minlmreport report;
    327344          alglib.minlmcreatevj(rowsForDataExtraction.Length * trees.Length, initialTheta, out state);
     345          alglib.minlmsetbc(state, lowerBounds, upperBounds);
    328346          alglib.minlmsetcond(state, 0.0, 0.0, 0.0, maxParameterOptIterations);
    329347          // alglib.minlmsetgradientcheck(state, 1.0e-3);
     
    490508
    491509      if (OptimizeParametersForEpisodes) {
     510        throw new NotSupportedException();
    492511        var eIdx = 0;
    493512        var trainingPredictions = new List<Tuple<double, Vector>[][]>();
     
    527546        results["Models"].Value = models;
    528547      } else {
    529         var optTheta = ((DoubleArray)bestIndividualAndQuality.Item1["OptTheta"]).ToArray(); // see evaluate
     548        var optTheta = Problem.ExtractParametersFromTrees(trees);
    530549        var optimizationData = new OptimizationData(trees, targetVars, problemData, null, TrainingEpisodes.ToArray(), NumericIntegrationSteps, latentVariables, OdeSolver);
    531550        var trainingPrediction = Integrate(optimizationData, optTheta).ToArray();
     
    629648        for (int idx = 0; idx < trees.Length; idx++) {
    630649          var tree = trees[idx];
    631           optimizedTrees.Add(new SymbolicExpressionTree(FixParameters(tree.Root, optTheta.ToArray(), ref nextParIdx)));
     650          // optimizedTrees.Add(new SymbolicExpressionTree(FixParameters(tree.Root, optTheta.ToArray(), ref nextParIdx)));
     651          optimizedTrees.Add(tree);
    632652        }
    633653        var ds = problemData.Dataset;
     
    660680          var tree = trees[idx];
    661681
    662           // when we reference HeuristicLab.Problems.DataAnalysis.Symbolic we can translate symbols
    663           var shownTree = new SymbolicExpressionTree(TranslateTreeNode(tree.Root, optTheta.ToArray(),
    664             ref nextParIdx));
    665 
    666 
    667682          var origTreeVar = new HeuristicLab.Core.Variable(varName + "(original)");
    668683          origTreeVar.Value = (ISymbolicExpressionTree)tree.Clone();
    669684          models.Add(origTreeVar);
    670685          var simplifiedTreeVar = new HeuristicLab.Core.Variable(varName + "(simplified)");
    671           simplifiedTreeVar.Value = TreeSimplifier.Simplify(shownTree);
     686          simplifiedTreeVar.Value = TreeSimplifier.Simplify(tree);
    672687          models.Add(simplifiedTreeVar);
    673688
     
    678693      }
    679694    }
     695
     696
     697    public static double[] ExtractParametersFromTrees(ISymbolicExpressionTree[] trees) {
     698      return trees
     699        .SelectMany(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().Select(n => n.Value))
     700        .ToArray();
     701    }
     702
    680703
    681704
     
    11041127    }
    11051128
     1129    // TODO: use an existing interpreter implementation instead
    11061130    private static double InterpretRec(ISymbolicExpressionTreeNode node, NodeValueLookup nodeValues) {
    1107       switch (node.Symbol.Name) {
    1108         case "+": {
    1109             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1110             var g = InterpretRec(node.GetSubtree(1), nodeValues);
    1111             return f + g;
    1112           }
    1113         case "*": {
    1114             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1115             var g = InterpretRec(node.GetSubtree(1), nodeValues);
    1116             return f * g;
    1117           }
    1118 
    1119         case "-": {
    1120             if (node.SubtreeCount == 1) {
    1121               var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1122               return -f;
    1123             } else {
    1124               var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1125               var g = InterpretRec(node.GetSubtree(1), nodeValues);
    1126 
    1127               return f - g;
    1128             }
    1129           }
    1130         case "%": {
    1131             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1132             var g = InterpretRec(node.GetSubtree(1), nodeValues);
    1133 
    1134             // protected division
    1135             if (g.IsAlmost(0.0)) {
    1136               return 0;
    1137             } else {
    1138               return f / g;
    1139             }
    1140           }
    1141         case "sin": {
    1142             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1143             return Math.Sin(f);
    1144           }
    1145         case "cos": {
    1146             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1147             return Math.Cos(f);
    1148           }
    1149         case "sqr": {
    1150             var f = InterpretRec(node.GetSubtree(0), nodeValues);
    1151             return f * f;
    1152           }
    1153         default: {
    1154             return nodeValues.NodeValue(node);
    1155           }
    1156       }
     1131      if (node.Symbol is Constant || node.Symbol is Variable) {
     1132        return nodeValues.NodeValue(node);
     1133      } else if (node.Symbol is Addition) {
     1134        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1135        var g = InterpretRec(node.GetSubtree(1), nodeValues);
     1136        return f + g;
     1137      } else if (node.Symbol is Multiplication) {
     1138        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1139        var g = InterpretRec(node.GetSubtree(1), nodeValues);
     1140        return f * g;
     1141      } else if (node.Symbol is Subtraction) {
     1142        if (node.SubtreeCount == 1) {
     1143          var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1144          return -f;
     1145        } else {
     1146          var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1147          var g = InterpretRec(node.GetSubtree(1), nodeValues);
     1148
     1149          return f - g;
     1150        }
     1151      } else if (node.Symbol is Division) {
     1152        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1153        var g = InterpretRec(node.GetSubtree(1), nodeValues);
     1154
     1155        // protected division
     1156        if (g.IsAlmost(0.0)) {
     1157          return 0;
     1158        } else {
     1159          return f / g;
     1160        }
     1161      } else if (node.Symbol is Sine) {
     1162        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1163        return Math.Sin(f);
     1164      } else if (node.Symbol is Cosine) {
     1165        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1166        return Math.Cos(f);
     1167      } else if (node.Symbol is Square) {
     1168        var f = InterpretRec(node.GetSubtree(0), nodeValues);
     1169        return f * f;
     1170      } else throw new NotSupportedException("unsupported symbol");
    11571171    }
    11581172
     
    11651179      double f, g;
    11661180      Vector df, dg;
    1167       switch (node.Symbol.Name) {
    1168         case "+": {
    1169             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1170             InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
    1171             z = f + g;
    1172             dz = df + dg; // Vector.AddTo(gl, gr);
    1173             break;
    1174           }
    1175         case "*": {
    1176             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1177             InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
    1178             z = f * g;
    1179             dz = df * g + f * dg;  // Vector.AddTo(gl.Scale(fr), gr.Scale(fl)); // f'*g + f*g'
    1180             break;
    1181           }
    1182 
    1183         case "-": {
    1184             if (node.SubtreeCount == 1) {
    1185               InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1186               z = -f;
    1187               dz = -df;
    1188             } else {
    1189               InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1190               InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
    1191 
    1192               z = f - g;
    1193               dz = df - dg;
    1194             }
    1195             break;
    1196           }
    1197         case "%": {
    1198             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1199             InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
    1200 
    1201             // protected division
    1202             if (g.IsAlmost(0.0)) {
    1203               z = 0;
    1204               dz = Vector.Zero;
    1205             } else {
    1206               z = f / g;
    1207               dz = -f / (g * g) * dg + df / g; // Vector.AddTo(dg.Scale(f * -1.0 / (g * g)), df.Scale(1.0 / g)); // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
    1208             }
    1209             break;
    1210           }
    1211         case "sin": {
    1212             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1213             z = Math.Sin(f);
    1214             dz = Math.Cos(f) * df;
    1215             break;
    1216           }
    1217         case "cos": {
    1218             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1219             z = Math.Cos(f);
    1220             dz = -Math.Sin(f) * df;
    1221             break;
    1222           }
    1223         case "sqr": {
    1224             InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
    1225             z = f * f;
    1226             dz = 2.0 * f * df;
    1227             break;
    1228           }
    1229         default: {
    1230             z = nodeValues.NodeValue(node);
    1231             dz = Vector.CreateNew(nodeValues.NodeGradient(node));
    1232             break;
    1233           }
    1234       }
    1235     }
     1181      if (node.Symbol is Constant || node.Symbol is Variable) {
     1182        z = nodeValues.NodeValue(node);
     1183        dz = Vector.CreateNew(nodeValues.NodeGradient(node));
     1184      } else if (node.Symbol is Addition) {
     1185        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1186        InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
     1187        z = f + g;
     1188        dz = df + dg; // Vector.AddTo(gl, gr);
     1189
     1190      } else if (node.Symbol is Multiplication) {
     1191        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1192        InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
     1193        z = f * g;
     1194        dz = df * g + f * dg;  // Vector.AddTo(gl.Scale(fr), gr.Scale(fl)); // f'*g + f*g'
     1195
     1196      } else if (node.Symbol is Subtraction) {
     1197        if (node.SubtreeCount == 1) {
     1198          InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1199          z = -f;
     1200          dz = -df;
     1201        } else {
     1202          InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1203          InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
     1204
     1205          z = f - g;
     1206          dz = df - dg;
     1207        }
     1208
     1209      } else if (node.Symbol is Division) {
     1210        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1211        InterpretRec(node.GetSubtree(1), nodeValues, out g, out dg);
     1212
     1213        // protected division
     1214        if (g.IsAlmost(0.0)) {
     1215          z = 0;
     1216          dz = Vector.Zero;
     1217        } else {
     1218          z = f / g;
     1219          dz = -f / (g * g) * dg + df / g; // Vector.AddTo(dg.Scale(f * -1.0 / (g * g)), df.Scale(1.0 / g)); // (f/g)' = f * (1/g)' + 1/g * f' = f * -1/g² * g' + 1/g * f'
     1220        }
     1221
     1222      } else if (node.Symbol is Sine) {
     1223        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1224        z = Math.Sin(f);
     1225        dz = Math.Cos(f) * df;
     1226
     1227      } else if (node.Symbol is Cosine) {
     1228        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1229        z = Math.Cos(f);
     1230        dz = -Math.Sin(f) * df;
     1231      } else if (node.Symbol is Square) {
     1232        InterpretRec(node.GetSubtree(0), nodeValues, out f, out df);
     1233        z = f * f;
     1234        dz = 2.0 * f * df;
     1235      } else {
     1236        throw new NotSupportedException("unsupported symbol");
     1237      }
     1238    }
     1239
    12361240    #endregion
    12371241
     
    13121316
    13131317    private void InitAllParameters() {
    1314       UpdateTargetVariables(); // implicitly updates the grammar and the encoding     
     1318      UpdateTargetVariables(); // implicitly updates the grammar and the encoding
    13151319    }
    13161320
    13171321    private ReadOnlyCheckedItemList<StringValue> CreateFunctionSet() {
    13181322      var l = new CheckedItemList<StringValue>();
    1319       l.Add(new StringValue("+").AsReadOnly());
    1320       l.Add(new StringValue("*").AsReadOnly());
    1321       l.Add(new StringValue("%").AsReadOnly());
    1322       l.Add(new StringValue("-").AsReadOnly());
    1323       l.Add(new StringValue("sin").AsReadOnly());
    1324       l.Add(new StringValue("cos").AsReadOnly());
    1325       l.Add(new StringValue("sqr").AsReadOnly());
     1323      l.Add(new StringValue("Addition").AsReadOnly());
     1324      l.Add(new StringValue("Multiplication").AsReadOnly());
     1325      l.Add(new StringValue("Division").AsReadOnly());
     1326      l.Add(new StringValue("Subtraction").AsReadOnly());
     1327      l.Add(new StringValue("Sine").AsReadOnly());
     1328      l.Add(new StringValue("Cosine").AsReadOnly());
     1329      l.Add(new StringValue("Square").AsReadOnly());
    13261330      return l.AsReadOnly();
    13271331    }
    13281332
    13291333    private static bool IsConstantNode(ISymbolicExpressionTreeNode n) {
    1330       return n.Symbol.Name[0] == 'θ';
     1334      // return n.Symbol.Name[0] == 'θ';
     1335      return n is ConstantTreeNode;
    13311336    }
    13321337    private static double GetConstantValue(ISymbolicExpressionTreeNode n) {
    1333       return 0.0; // TODO: needs to be updated when we write back values to the tree
     1338      return ((ConstantTreeNode)n).Value;
    13341339    }
    13351340    private static bool IsLatentVariableNode(ISymbolicExpressionTreeNode n) {
     
    13401345    }
    13411346    private static string GetVariableName(ISymbolicExpressionTreeNode n) {
    1342       return n.Symbol.Name;
    1343     }
    1344 
     1347      return ((VariableTreeNode)n).VariableName;
     1348    }
    13451349
    13461350    private void UpdateTargetVariables() {
     
    13751379
    13761380    private ISymbolicExpressionGrammar CreateGrammar() {
    1377       // whenever ProblemData is changed we create a new grammar with the necessary symbols
    1378       var g = new SimpleSymbolicExpressionGrammar();
    1379       var unaryFunc = new string[] { "sin", "cos", "sqr" };
    1380       var binaryFunc = new string[] { "+", "-", "*", "%" };
    1381       foreach (var func in unaryFunc) {
    1382         if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 1, 1);
    1383       }
    1384       foreach (var func in binaryFunc) {
    1385         if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 2, 2);
    1386       }
    1387 
    1388       foreach (var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value.Value)))
    1389         g.AddTerminalSymbol(variableName);
    1390 
    1391       // generate symbols for numeric parameters for which the value is optimized using AutoDiff
    1392       // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
    1393       var numericConstantsFactor = 2.0;
    1394       for (int i = 0; i < numericConstantsFactor * (ProblemData.AllowedInputVariables.Count() + TargetVariables.CheckedItems.Count()); i++) {
    1395         g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
    1396       }
    1397 
    1398       // generate symbols for latent variables
    1399       for (int i = 1; i <= NumberOfLatentVariables; i++) {
    1400         g.AddTerminalSymbol("λ" + i); // numeric parameter for which the value is optimized using AutoDiff
    1401       }
    1402 
    1403       return g;
    1404     }
    1405 
    1406 
    1407 
    1408 
    1409 
    1410     private ISymbolicExpressionTreeNode FixParameters(ISymbolicExpressionTreeNode n, double[] parameterValues, ref int nextParIdx) {
    1411       ISymbolicExpressionTreeNode translatedNode = null;
    1412       if (n.Symbol is StartSymbol) {
    1413         translatedNode = new StartSymbol().CreateTreeNode();
    1414       } else if (n.Symbol is ProgramRootSymbol) {
    1415         translatedNode = new ProgramRootSymbol().CreateTreeNode();
    1416       } else if (n.Symbol.Name == "+") {
    1417         translatedNode = new SimpleSymbol("+", 2).CreateTreeNode();
    1418       } else if (n.Symbol.Name == "-") {
    1419         translatedNode = new SimpleSymbol("-", 2).CreateTreeNode();
    1420       } else if (n.Symbol.Name == "*") {
    1421         translatedNode = new SimpleSymbol("*", 2).CreateTreeNode();
    1422       } else if (n.Symbol.Name == "%") {
    1423         translatedNode = new SimpleSymbol("%", 2).CreateTreeNode();
    1424       } else if (n.Symbol.Name == "sin") {
    1425         translatedNode = new SimpleSymbol("sin", 1).CreateTreeNode();
    1426       } else if (n.Symbol.Name == "cos") {
    1427         translatedNode = new SimpleSymbol("cos", 1).CreateTreeNode();
    1428       } else if (n.Symbol.Name == "sqr") {
    1429         translatedNode = new SimpleSymbol("sqr", 1).CreateTreeNode();
    1430       } else if (IsConstantNode(n)) {
    1431         translatedNode = new SimpleSymbol("c_" + nextParIdx, 0).CreateTreeNode();
    1432         nextParIdx++;
    1433       } else {
    1434         translatedNode = new SimpleSymbol(n.Symbol.Name, n.SubtreeCount).CreateTreeNode();
    1435       }
    1436       foreach (var child in n.Subtrees) {
    1437         translatedNode.AddSubtree(FixParameters(child, parameterValues, ref nextParIdx));
    1438       }
    1439       return translatedNode;
    1440     }
    1441 
    1442 
    1443     private ISymbolicExpressionTreeNode TranslateTreeNode(ISymbolicExpressionTreeNode n, double[] parameterValues, ref int nextParIdx) {
    1444       ISymbolicExpressionTreeNode translatedNode = null;
    1445       if (n.Symbol is StartSymbol) {
    1446         translatedNode = new StartSymbol().CreateTreeNode();
    1447       } else if (n.Symbol is ProgramRootSymbol) {
    1448         translatedNode = new ProgramRootSymbol().CreateTreeNode();
    1449       } else if (n.Symbol.Name == "+") {
    1450         translatedNode = new Addition().CreateTreeNode();
    1451       } else if (n.Symbol.Name == "-") {
    1452         translatedNode = new Subtraction().CreateTreeNode();
    1453       } else if (n.Symbol.Name == "*") {
    1454         translatedNode = new Multiplication().CreateTreeNode();
    1455       } else if (n.Symbol.Name == "%") {
    1456         translatedNode = new Division().CreateTreeNode();
    1457       } else if (n.Symbol.Name == "sin") {
    1458         translatedNode = new Sine().CreateTreeNode();
    1459       } else if (n.Symbol.Name == "cos") {
    1460         translatedNode = new Cosine().CreateTreeNode();
    1461       } else if (n.Symbol.Name == "sqr") {
    1462         translatedNode = new Square().CreateTreeNode();
    1463       } else if (IsConstantNode(n)) {
    1464         var constNode = (ConstantTreeNode)new Constant().CreateTreeNode();
    1465         constNode.Value = parameterValues[nextParIdx];
    1466         nextParIdx++;
    1467         translatedNode = constNode;
    1468       } else {
    1469         // assume a variable name
    1470         var varName = n.Symbol.Name;
    1471         var varNode = (VariableTreeNode)new Variable().CreateTreeNode();
    1472         varNode.Weight = 1.0;
    1473         varNode.VariableName = varName;
    1474         translatedNode = varNode;
    1475       }
    1476       foreach (var child in n.Subtrees) {
    1477         translatedNode.AddSubtree(TranslateTreeNode(child, parameterValues, ref nextParIdx));
    1478       }
    1479       return translatedNode;
     1381      var grammar = new TypeCoherentExpressionGrammar();
     1382      grammar.StartGrammarManipulation();
     1383
     1384      var problemData = ProblemData;
     1385      var ds = problemData.Dataset;
     1386      grammar.MaximumFunctionArguments = 0;
     1387      grammar.MaximumFunctionDefinitions = 0;
     1388      var allowedVariables = problemData.AllowedInputVariables.Concat(TargetVariables.CheckedItems.Select(chk => chk.Value.Value));
     1389      foreach (var varSymbol in grammar.Symbols.OfType<HeuristicLab.Problems.DataAnalysis.Symbolic.VariableBase>()) {
     1390        if (!varSymbol.Fixed) {
     1391          varSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => ds.VariableHasType<double>(x));
     1392          varSymbol.VariableNames = allowedVariables.Where(x => ds.VariableHasType<double>(x));
     1393        }
     1394      }
     1395      foreach (var factorSymbol in grammar.Symbols.OfType<BinaryFactorVariable>()) {
     1396        if (!factorSymbol.Fixed) {
     1397          factorSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => ds.VariableHasType<string>(x));
     1398          factorSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => ds.VariableHasType<string>(x));
     1399          factorSymbol.VariableValues = factorSymbol.VariableNames
     1400            .ToDictionary(varName => varName, varName => ds.GetStringValues(varName).Distinct().ToList());
     1401        }
     1402      }
     1403      foreach (var factorSymbol in grammar.Symbols.OfType<FactorVariable>()) {
     1404        if (!factorSymbol.Fixed) {
     1405          factorSymbol.AllVariableNames = problemData.InputVariables.Select(x => x.Value).Where(x => ds.VariableHasType<string>(x));
     1406          factorSymbol.VariableNames = problemData.AllowedInputVariables.Where(x => ds.VariableHasType<string>(x));
     1407          factorSymbol.VariableValues = factorSymbol.VariableNames
     1408            .ToDictionary(varName => varName,
     1409            varName => ds.GetStringValues(varName).Distinct()
     1410            .Select((n, i) => Tuple.Create(n, i))
     1411            .ToDictionary(tup => tup.Item1, tup => tup.Item2));
     1412        }
     1413      }
     1414
     1415      grammar.ConfigureAsDefaultRegressionGrammar();
     1416      grammar.GetSymbol("Logarithm").Enabled = false; // not supported yet
     1417      grammar.GetSymbol("Exponential").Enabled = false; // not supported yet
     1418
     1419      // configure initialization of constants
     1420      var constSy = (Constant)grammar.GetSymbol("Constant");
     1421      // max and min are only relevant for initialization
     1422      constSy.MaxValue = +1.0e-1; // small initial values for constant opt
     1423      constSy.MinValue = -1.0e-1;
     1424      constSy.MultiplicativeManipulatorSigma = 1.0; // allow large jumps for manipulation
     1425      constSy.ManipulatorMu = 0.0;
     1426      constSy.ManipulatorSigma = 1.0; // allow large jumps
     1427
     1428      // configure initialization of variables
     1429      var varSy = (Variable)grammar.GetSymbol("Variable");
     1430      // fix variable weights to 1.0
     1431      varSy.WeightMu = 1.0;
     1432      varSy.WeightSigma = 0.0;
     1433      varSy.WeightManipulatorMu = 0.0;
     1434      varSy.WeightManipulatorSigma = 0.0;
     1435      varSy.MultiplicativeWeightManipulatorSigma = 0.0;
     1436
     1437      foreach (var f in FunctionSet) {
     1438        grammar.GetSymbol(f.Value).Enabled = FunctionSet.ItemChecked(f);
     1439      }
     1440
     1441      grammar.FinishedGrammarManipulation();
     1442      return grammar;
     1443      // // whenever ProblemData is changed we create a new grammar with the necessary symbols
     1444      // var g = new SimpleSymbolicExpressionGrammar();
     1445      // var unaryFunc = new string[] { "sin", "cos", "sqr" };
     1446      // var binaryFunc = new string[] { "+", "-", "*", "%" };
     1447      // foreach (var func in unaryFunc) {
     1448      //   if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 1, 1);
     1449      // }
     1450      // foreach (var func in binaryFunc) {
     1451      //   if (FunctionSet.CheckedItems.Any(ci => ci.Value.Value == func)) g.AddSymbol(func, 2, 2);
     1452      // }
     1453      //
     1454      // foreach (var variableName in ProblemData.AllowedInputVariables.Union(TargetVariables.CheckedItems.Select(i => i.Value.Value)))
     1455      //   g.AddTerminalSymbol(variableName);
     1456      //
     1457      // // generate symbols for numeric parameters for which the value is optimized using AutoDiff
     1458      // // we generate multiple symbols to balance the probability for selecting a numeric parameter in the generation of random trees
     1459      // var numericConstantsFactor = 2.0;
     1460      // for (int i = 0; i < numericConstantsFactor * (ProblemData.AllowedInputVariables.Count() + TargetVariables.CheckedItems.Count()); i++) {
     1461      //   g.AddTerminalSymbol("θ" + i); // numeric parameter for which the value is optimized using AutoDiff
     1462      // }
     1463      //
     1464      // // generate symbols for latent variables
     1465      // for (int i = 1; i <= NumberOfLatentVariables; i++) {
     1466      //   g.AddTerminalSymbol("λ" + i); // numeric parameter for which the value is optimized using AutoDiff
     1467      // }
     1468      //
     1469      // return g;
    14801470    }
    14811471    #endregion
     
    15521542            }
    15531543            nodes.Add(node);
    1554             SetVariableValue(varName, 0.0);
     1544            SetVariableValue(varName, 0.0);  // this value is updated in the prediction loop
    15551545          }
    15561546        }
     
    15741564          nodes.ForEach(n => node2val[n] = Tuple.Create(val, dVal));
    15751565        } else {
    1576           var fakeNode = new SimpleSymbol(variableName, 0).CreateTreeNode();
     1566          var fakeNode = new VariableTreeNode(new Variable());
    15771567          var newNodeList = new List<ISymbolicExpressionTreeNode>();
    15781568          newNodeList.Add(fakeNode);
  • branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Solution.cs

    r16600 r16602  
    88using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    99using HeuristicLab.Problems.DataAnalysis;
     10using HeuristicLab.Problems.DataAnalysis.Symbolic;
    1011using HeuristicLab.Random;
    1112
     
    8889      var forecastEpisode = new IntRange(episode.Start, episode.End + forecastHorizon);
    8990
    90       double[] optL0;
    9191      var random = new FastRandom(12345);
    92       Problem.OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { forecastEpisode }, 100, numericIntegrationSteps, odeSolver, out optL0, out snmse);
     92      snmse = Problem.OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { forecastEpisode }, 100, numericIntegrationSteps, odeSolver);
    9393      var optimizationData = new Problem.OptimizationData(trees, targetVars, problemData, null, new[] { forecastEpisode }, numericIntegrationSteps, latentVariables, odeSolver);
    94       var predictions = Problem.Integrate(optimizationData, optL0).ToArray();
     94
     95
     96      var theta = Problem.ExtractParametersFromTrees(trees);
     97
     98      var predictions = Problem.Integrate(optimizationData, theta).ToArray();
    9599      return predictions.Select(p => p.Select(pi => pi.Item1).ToArray()).ToArray();
    96100    }
Note: See TracChangeset for help on using the changeset viewer.