Changeset 16602 for branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling
Timestamp: 02/13/19 13:43:03 (6 years ago)
Location:  branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3
Files:     3 edited
Legend: unmarked lines are unchanged context; lines prefixed with "+" were added in this changeset, lines prefixed with "-" were removed.
branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/OdeParameterIdentification.cs
r16597 → r16602

In CreateSolution the parsed expressions are now used directly; the translation into SimpleSymbol trees is dropped:

      public void CreateSolution(Problem problem, string[] modelStructure, int maxIterations, IRandom rand) {
        var parser = new InfixExpressionParser();
-       var trees = modelStructure.Select(expr => Convert(parser.Parse(expr))).ToArray();
+       var trees = modelStructure.Select(expr => parser.Parse(expr)).ToArray();
        var names = problem.Encoding.Encodings.Select(enc => enc.Name).ToArray();
        if (trees.Length != names.Length) throw new ArgumentException("The number of expressions must match the number of target variables exactly");

The now-unused translation helpers (lines 192-255) are commented out:

-   private ISymbolicExpressionTree Convert(ISymbolicExpressionTree tree) — wrapped Convert(tree.Root) in a new SymbolicExpressionTree
-   private static Dictionary<Type, string> sym2str — mapped Addition, Subtraction, Multiplication, Sine, Cosine, Square to "+", "-", "*", "sin", "cos", "sqr"
-   private ISymbolicExpressionTreeNode Convert(ISymbolicExpressionTreeNode node) — translated symbolic regression nodes to SimpleSymbol nodes: Division became protected division "%" (a single-argument Division became θ % x, more than two arguments threw), Constant became a "θ" terminal, Variable nodes were copied by name (weights other than 1.0 threw an ArgumentException), and any other symbol threw an ArgumentException
-   private ISymbolicExpressionTreeNode Make(string op, ISymbolicExpressionTreeNode[] children) — created unary nodes directly and chained binary nodes left-associatively for more than two arguments
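Because the translation step is gone, the expressions handed to CreateSolution must already parse to the standard symbolic regression symbols (Addition, Division, Constant, Variable, ...). The following is a minimal usage sketch under that assumption; the problem setup, the expressions, and the coefficient values are illustrative only and not part of the changeset, and whether literal coefficients are parsed as Constant nodes or folded into variable weights depends on the InfixExpressionParser version:

    // hedged sketch: seeding an ODE identification run from infix model structures (r16602 behaviour)
    var problem = new Problem();                 // HeuristicLab.Problems.DynamicalSystemsModelling.Problem
    var rand = new FastRandom(1234);             // any IRandom implementation
    var modelStructure = new[] {
      "1.5*x - 0.2*x*y",                         // dx/dt; numeric literals are intended to become ConstantTreeNodes
      "0.1*x*y - 3.0*y"                          // dy/dt; one expression per target variable
    };
    // the number of expressions must match problem.Encoding.Encodings (one per target variable);
    // the parsed constants serve as starting values for the parameter optimization
    new OdeParameterIdentification().CreateSolution(problem, modelStructure, 100 /* maxIterations */, rand);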
branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Problem.cs
r16601 → r16602

Parameters (new line 74):
+     public IFixedValueParameter<IntValue> MaximumParameterOptimizationIterationsParameter {

Evaluate(Individual individual, IRandom random):
  Two stale comments about writing back/retrieving optimized parameters are removed. OptimizeForEpisodes now returns the NMSE directly, and the optimized parameter vector is no longer stored in the individual (it lives in the tree nodes instead). The same change is applied to the per-episode and the all-episodes branch:
-       double[] optTheta;
-       double nmse;
-       OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { episode }, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver, out optTheta, out nmse);
-       individual["OptTheta_" + eIdx] = new DoubleArray(optTheta); // write back optimized parameters so that we can use them in the Analysis method
+       double nmse = OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { episode }, MaximumParameterOptimizationIterations, NumericIntegrationSteps, OdeSolver);

OptimizeForEpisodes:
  The signature changes from a void method with out parameters (out double[] optTheta, out double nmse) to
+   public static double OptimizeForEpisodes(ISymbolicExpressionTree[] trees, IRegressionProblemData problemData, string[] targetVars, string[] latentVariables, IRandom random, IEnumerable<IntRange> episodes, int maxParameterOptIterations, int numericIntegrationSteps, string odeSolver)
  The new body reads the initial parameter values from the constant nodes of the trees, pre-tunes them, refines them by integration, clamps degenerate results, and writes the optimized values back into the trees before returning the NMSE:
+       var constantNodes = trees.Select(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray()).ToArray();
+       var initialTheta = constantNodes.Select(nodes => nodes.Select(n => n.Value).ToArray()).ToArray();
+       // optimize parameters by fitting f(x,y) to calculated differences dy/dt(t)
+       double nmse = PreTuneParameters(trees, problemData, targetVars, latentVariables, random, episodes, maxParameterOptIterations, initialTheta, out double[] pretunedParameters);
+       // optimize parameters using integration of f(x,y) to calculate y(t)
+       nmse = OptimizeParameters(trees, problemData, targetVars, latentVariables, episodes, maxParameterOptIterations, pretunedParameters, numericIntegrationSteps, odeSolver, out double[] optTheta);
+       if (double.IsNaN(nmse) || double.IsInfinity(nmse) || nmse > 100 * trees.Length * episodes.Sum(ep => ep.Size))
+         return 100 * trees.Length * episodes.Sum(ep => ep.Size);
+       // update tree nodes with optimized values
+       var paramIdx = 0;
+       for (var treeIdx = 0; treeIdx < constantNodes.Length; treeIdx++) {
+         for (int i = 0; i < constantNodes[treeIdx].Length; i++)
+           constantNodes[treeIdx][i].Value = optTheta[paramIdx++];
+       }
+       return nmse;

PreTuneParameters:
  gains a double[][] initialTheta parameter and no longer initializes the parameter vector randomly in [-0.1, 0.1] (the old vector also contained initial values for the latent variables; the new one holds only the constants of the current tree). Each tree is optimized starting from its stored constant values, with box constraints and a try/catch so that ALGLIB failures fall back to the initial values instead of aborting the evaluation:
+         var p = new double[initialTheta[treeIdx].Length];
+         var lowerBounds = Enumerable.Repeat(-100.0, p.Length).ToArray();
+         var upperBounds = Enumerable.Repeat(100.0, p.Length).ToArray();
+         Array.Copy(initialTheta[treeIdx], p, p.Length);
+         alglib.minlmcreatevj(targetValuesDiff.Length, p, out state);
+         alglib.minlmsetcond(state, 0.0, 0.0, 0.0, maxParameterOptIterations);
+         alglib.minlmsetbc(state, lowerBounds, upperBounds);
+         alglib.minlmoptimize(state, EvaluateObjectiveVector, EvaluateObjectiveVectorAndJacobian, null, myState);
+         alglib.minlmresults(state, out optTheta, out report);
+         if (report.terminationtype < 0) { optTheta = initialTheta[treeIdx]; }
  The per-tree error is now checked against maxTreeNmse = 100 * episodes.Sum(ep => ep.Size): if it is NaN, infinite, or larger, the initial parameters are kept and maxTreeNmse is added to the total instead. The old clamp of the summed NMSE at the end of the method is removed (OptimizeForEpisodes handles that now).

OptimizeParameters (integration-based):
  also gets box constraints before the optimizer run:
+         var lowerBounds = Enumerable.Repeat(-100.0, initialTheta.Length).ToArray();
+         var upperBounds = Enumerable.Repeat(100.0, initialTheta.Length).ToArray();
          alglib.minlmcreatevj(rowsForDataExtraction.Length * trees.Length, initialTheta, out state);
+         alglib.minlmsetbc(state, lowerBounds, upperBounds);

Analyze:
  The OptimizeParametersForEpisodes branch now starts with throw new NotSupportedException();. The regular branch reads the parameters from the trees instead of the individual:
-       var optTheta = ((DoubleArray)bestIndividualAndQuality.Item1["OptTheta"]).ToArray(); // see evaluate
+       var optTheta = Problem.ExtractParametersFromTrees(trees);
  The FixParameters/TranslateTreeNode post-processing of the result trees is dropped: the original trees are added to the results as-is and simplified directly via TreeSimplifier.Simplify(tree).

New helper (lines 697-701):
+   public static double[] ExtractParametersFromTrees(ISymbolicExpressionTree[] trees) {
+     return trees
+       .SelectMany(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().Select(n => n.Value))
+       .ToArray();
+   }

InterpretRec (value and gradient overloads):
+   // TODO: use an existing interpreter implementation instead
  The switch over symbol names ("+", "*", "-", "%", "sin", "cos", "sqr") is replaced by type checks on the symbol classes (Constant, Variable, Addition, Multiplication, Subtraction, Division, Sine, Cosine, Square). Unary minus and protected division (0 when the divisor is almost 0) are kept unchanged; any other symbol now throws NotSupportedException("unsupported symbol") instead of silently being treated as a terminal value.

CreateFunctionSet:
  the function set entries are renamed from the simple-symbol strings to the symbol class names:
-     "+", "*", "%", "-", "sin", "cos", "sqr"
+     "Addition", "Multiplication", "Division", "Subtraction", "Sine", "Cosine", "Square"

Node helpers:
-     IsConstantNode:   return n.Symbol.Name[0] == 'θ';
+     IsConstantNode:   return n is ConstantTreeNode;
-     GetConstantValue: return 0.0; // TODO: needs to be updated when we write back values to the tree
+     GetConstantValue: return ((ConstantTreeNode)n).Value;
-     GetVariableName:  return n.Symbol.Name;
+     GetVariableName:  return ((VariableTreeNode)n).VariableName;

CreateGrammar:
  The hand-rolled SimpleSymbolicExpressionGrammar (with "θi" terminals for numeric parameters and "λi" terminals for latent variables) is commented out and replaced by a configured TypeCoherentExpressionGrammar: function arguments and definitions are disabled, Variable and factor symbols are restricted to the allowed input variables and the selected target variables, the grammar is configured as the default regression grammar with Logarithm and Exponential disabled (not supported yet), constants are initialized in [-0.1, +0.1] with manipulator sigma 1.0 (to allow large jumps during manipulation), variable weights are fixed to 1.0 with all weight manipulators set to 0, and symbols are enabled according to the checked FunctionSet items.

FixParameters and TranslateTreeNode, which rebuilt the trees with SimpleSymbol/DataAnalysis symbols and substituted the optimized parameter values, are removed; the trees now carry their values directly in ConstantTreeNodes.

NodeValueLookup:
-       var fakeNode = new SimpleSymbol(variableName, 0).CreateTreeNode();
+       var fakeNode = new VariableTreeNode(new Variable());
  and SetVariableValue(varName, 0.0) gets the comment "// this value is updated in the prediction loop".
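The common thread of these changes is that the ConstantTreeNode values are now the single store for the ODE parameters. The following condensed sketch assembles that round-trip from the pieces above (the optimization step in the middle is elided); the helper logic mirrors ExtractParametersFromTrees and the write-back loop in OptimizeForEpisodes, while the surrounding variables and the required usings (System.Linq, HeuristicLab.Encodings.SymbolicExpressionTreeEncoding, HeuristicLab.Problems.DataAnalysis.Symbolic) are assumed context:

    // read the current parameter vector from the trees (what ExtractParametersFromTrees does)
    var constantNodes = trees
      .Select(t => t.IterateNodesPrefix().OfType<ConstantTreeNode>().ToArray())
      .ToArray();
    double[] theta = constantNodes.SelectMany(nodes => nodes.Select(n => n.Value)).ToArray();

    // ... PreTuneParameters / OptimizeParameters update theta, box-constrained to [-100, 100] ...

    // write the optimized values back, in the same prefix order over all trees
    int paramIdx = 0;
    for (int treeIdx = 0; treeIdx < constantNodes.Length; treeIdx++)
      for (int i = 0; i < constantNodes[treeIdx].Length; i++)
        constantNodes[treeIdx][i].Value = theta[paramIdx++];

Because the values travel with the trees, Evaluate, Analyze, and Solution no longer need to pass an OptTheta array around, which is exactly the individual["OptTheta"] plumbing removed in this changeset.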
branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DynamicalSystemsModelling/3.3/Solution.cs
r16600 → r16602

+   using HeuristicLab.Problems.DataAnalysis.Symbolic;

In the forecast method the parameters are no longer returned by OptimizeForEpisodes but read back from the trees:

        var forecastEpisode = new IntRange(episode.Start, episode.End + forecastHorizon);
-       double[] optL0;
        var random = new FastRandom(12345);
-       Problem.OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { forecastEpisode }, 100, numericIntegrationSteps, odeSolver, out optL0, out snmse);
+       snmse = Problem.OptimizeForEpisodes(trees, problemData, targetVars, latentVariables, random, new[] { forecastEpisode }, 100, numericIntegrationSteps, odeSolver);
        var optimizationData = new Problem.OptimizationData(trees, targetVars, problemData, null, new[] { forecastEpisode }, numericIntegrationSteps, latentVariables, odeSolver);
-       var predictions = Problem.Integrate(optimizationData, optL0).ToArray();
+       var theta = Problem.ExtractParametersFromTrees(trees);
+       var predictions = Problem.Integrate(optimizationData, theta).ToArray();
        return predictions.Select(p => p.Select(pi => pi.Item1).ToArray()).ToArray();
      }
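One consequence worth noting: with r16602 the re-optimization on the forecast episode writes the tuned values back into the ConstantTreeNodes of the trees it receives, so a caller that wants to keep the solution's original parameters would have to work on clones. Whether Solution already passes clones is not visible in this hunk; the cloning below is therefore an assumption for illustration, while the call sequence follows the changeset:

    // hedged sketch: preserving the original parameters before forecast re-tuning
    var workingTrees = trees.Select(t => (ISymbolicExpressionTree)t.Clone()).ToArray();
    var random = new FastRandom(12345);
    // 1) parameters are tuned in place on the cloned trees; only the achieved NMSE is returned
    var snmse = Problem.OptimizeForEpisodes(workingTrees, problemData, targetVars, latentVariables, random,
      new[] { forecastEpisode }, 100, numericIntegrationSteps, odeSolver);
    // 2) collect the tuned parameters (prefix order over all trees)
    var theta = Problem.ExtractParametersFromTrees(workingTrees);
    // 3) integrate the ODE system with those parameters
    var optimizationData = new Problem.OptimizationData(workingTrees, targetVars, problemData, null,
      new[] { forecastEpisode }, numericIntegrationSteps, latentVariables, odeSolver);
    var predictions = Problem.Integrate(optimizationData, theta).ToArray();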