Changeset 17310
- Timestamp:
- 10/04/19 09:32:41 (5 years ago)
- Location:
- branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4
- Files:
-
- 2 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/AutoDiffInterpreterTest.cs
r17308 r17310 109 109 } 110 110 111 [TestMethod] 112 [TestCategory("Problems.DataAnalysis")] 113 [TestProperty("Time", "short")] 114 public void TestIntervalAutoDiffUsingNumericDifferences() { 115 116 // create random trees and evaluate on random data 117 // calc gradient for all parameters 118 // use numeric differences for approximate gradient calculation 119 // compare gradients 120 121 var grammar = new TypeCoherentExpressionGrammar(); 122 grammar.ConfigureAsDefaultRegressionGrammar(); 123 // activate supported symbols 124 grammar.Symbols.First(s => s is Square).Enabled = true; 125 grammar.Symbols.First(s => s is SquareRoot).Enabled = true; 126 grammar.Symbols.First(s => s is Cube).Enabled = true; 127 grammar.Symbols.First(s => s is CubeRoot).Enabled = true; 128 grammar.Symbols.First(s => s is Sine).Enabled = true; 129 grammar.Symbols.First(s => s is Cosine).Enabled = true; 130 grammar.Symbols.First(s => s is Exponential).Enabled = true; 131 grammar.Symbols.First(s => s is Logarithm).Enabled = true; 132 grammar.Symbols.First(s => s is Absolute).Enabled = false; // XXX not yet supported by old interval calculator 133 grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator 134 grammar.Symbols.First(s => s is Constant).Enabled = false; 135 136 var varSy = (Variable)grammar.Symbols.First(s => s is Variable); 137 varSy.AllVariableNames = new string[] { "x", "y" }; 138 varSy.VariableNames = varSy.AllVariableNames; 139 varSy.WeightMu = 1.0; 140 varSy.WeightSigma = 0.0; 141 var rand = new FastRandom(1234); 142 143 var intervals = new Dictionary<string, Interval>() { 144 { "x", new Interval(1, 2) }, 145 { "y", new Interval(0, 1) }, 146 }; 147 148 var eval = new IntervalEvaluator(); 149 150 var formatter = new InfixExpressionFormatter(); 151 var sb = new StringBuilder(); 152 int N = 10000; 153 int iter = 0; 154 while (iter < N) { 155 var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5); 156 var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 157 158 eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient); 159 160 ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient); 161 162 // compare autodiff and numeric diff 163 for(int p=0;p<parameterNodes.Length;p++) { 164 // lower 165 if(double.IsNaN(lowerGradient[p]) && double.IsNaN(refLowerGradient[p])) { 166 167 } else if(lowerGradient[p] == refLowerGradient[p]){ 168 169 } else if(Math.Abs(lowerGradient[p] - refLowerGradient[p]) < Math.Abs(lowerGradient[p]) * 1e-4) { 170 171 } else { 172 sb.AppendLine($"{lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}"); 173 } 174 // upper 175 if (double.IsNaN(upperGradient[p]) && double.IsNaN(refUpperGradient[p])) { 176 177 } else if (upperGradient[p] == refUpperGradient[p]) { 178 179 } else if (Math.Abs(upperGradient[p] - refUpperGradient[p]) < Math.Abs(upperGradient[p]) * 1e-4) { 180 181 } else { 182 sb.AppendLine($"{upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}"); 183 } 184 } 185 186 iter++; 187 } 188 if (sb.Length > 0) { 189 Console.WriteLine(sb.ToString()); 190 Assert.Fail("There were differences when validating AutoDiff using numeric differences"); 191 } 192 } 193 111 194 #region helper 112 195 … … 158 241 } 159 242 243 private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) { 244 lowerGradient = new double[parameterNodes.Length]; 245 upperGradient = new double[parameterNodes.Length]; 246 247 for(int p=0;p<parameterNodes.Length;p++) { 248 var x = GetValue(parameterNodes[p]); 249 var x_diff = x * 1e-4; // relative change 250 251 // calculate output for increased parameter value 252 SetValue(parameterNodes[p], x + x_diff / 2); 253 var r1 = eval.Evaluate(t, intervals); 254 lowerGradient[p] = r1.LowerBound; 255 upperGradient[p] = r1.UpperBound; 256 257 // calculate output for decreased parameter value 258 SetValue(parameterNodes[p], x - x_diff / 2); 259 var r2 = eval.Evaluate(t, intervals); 260 lowerGradient[p] -= r2.LowerBound; 261 upperGradient[p] -= r2.UpperBound; 262 263 lowerGradient[p] /= x_diff; 264 upperGradient[p] /= x_diff; 265 266 // restore original value 267 SetValue(parameterNodes[p], x); 268 } 269 } 270 160 271 private void SetValue(ISymbolicExpressionTreeNode node, double v) { 161 272 var varNode = node as VariableTreeNode; -
branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/AutoDiffIntervalTest.cs
r17303 r17310 1 1 using System; 2 2 using System.Collections.Generic; 3 using System.Linq; 3 4 using HeuristicLab.Problems.DataAnalysis.Symbolic; 4 5 using Microsoft.VisualStudio.TestTools.UnitTesting; … … 32 33 Assert.IsTrue(double.IsNaN(b.LowerBound.Value)); 33 34 } else { 34 Assert.AreEqual(a.LowerBound.Value.Value, b.LowerBound.Value.Value, Math.Abs(a.LowerBound.Value.Value) *1e-4); // relative error < 0.1%35 Assert.AreEqual(a.LowerBound.Value.Value, b.LowerBound.Value.Value, Math.Abs(a.LowerBound.Value.Value) * 1e-4); // relative error < 0.1% 35 36 } 36 37 … … 226 227 AssertAreEqualInterval(new AlgebraicInterval(-2, -1), new AlgebraicInterval(-8, -1).IntRoot(3)); 227 228 } 229 230 [TestMethod] 231 [TestCategory("Problems.DataAnalysis")] 232 [TestProperty("Time", "short")] 233 public void TestIntervalAddAutoDiff() { 234 var eval = new IntervalEvaluator(); 235 var parser = new InfixExpressionParser(); 236 var t = parser.Parse("x + y"); 237 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 238 var intervals = new Dictionary<string, Interval>() { 239 { "x", new Interval(1, 2) }, 240 { "y", new Interval(0, 1) } 241 }; 242 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 243 Assert.AreEqual(1, r.LowerBound); 244 Assert.AreEqual(3, r.UpperBound); 245 246 Assert.AreEqual(1.0, lg[0]); // x 247 Assert.AreEqual(2.0, ug[0]); 248 Assert.AreEqual(0.0, lg[1]); // y 249 Assert.AreEqual(1.0, ug[1]); 250 } 251 252 [TestMethod] 253 [TestCategory("Problems.DataAnalysis")] 254 [TestProperty("Time", "short")] 255 public void TestIntervalMulAutoDiff() { 256 var eval = new IntervalEvaluator(); 257 var parser = new InfixExpressionParser(); 258 var t = parser.Parse("x * y"); 259 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 260 var intervals = new Dictionary<string, Interval>() { 261 { "x", new Interval(1, 2) }, 262 { "y", new Interval(0, 1) } 263 }; 264 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 265 Assert.AreEqual(0, r.LowerBound); 266 Assert.AreEqual(2, r.UpperBound); 267 268 Assert.AreEqual(0.0, lg[0]); // x 269 Assert.AreEqual(2.0, ug[0]); 270 Assert.AreEqual(0.0, lg[1]); // y 271 Assert.AreEqual(2.0, ug[1]); 272 } 273 274 [TestMethod] 275 [TestCategory("Problems.DataAnalysis")] 276 [TestProperty("Time", "short")] 277 public void TestIntervalSqrAutoDiff() { 278 var eval = new IntervalEvaluator(); 279 var parser = new InfixExpressionParser(); 280 var t = parser.Parse("sqr(x)"); 281 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 282 var intervals = new Dictionary<string, Interval>() { 283 { "x", new Interval(1, 2) }, 284 { "y", new Interval(0, 1) } 285 }; 286 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 287 Assert.AreEqual(XXX, r.LowerBound); 288 Assert.AreEqual(XXX, r.UpperBound); 289 290 Assert.AreEqual(XXX, lg[0]); // x 291 Assert.AreEqual(XXX, ug[0]); 292 293 for { "x", new Interval(1, 2) }, 294 { "y", new Interval(0, 1) }, 295 296 0 <> -2,50012500572888E-05 for y in SQR(LOG('y')) 297 0 <> 2, 49987500573946E-05 for x in SQR(LOG('x')) 298 } 299 300 [TestMethod] 301 [TestCategory("Problems.DataAnalysis")] 302 [TestProperty("Time", "short")] 303 public void TestIntervalExpAutoDiff() { 304 var eval = new IntervalEvaluator(); 305 var parser = new InfixExpressionParser(); 306 var t = parser.Parse("exp(x)"); 307 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 308 var intervals = new Dictionary<string, Interval>() { 309 { "x", new Interval(1, 2) }, 310 }; 311 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 312 Assert.AreEqual(Math.Exp(1), r.LowerBound); 313 Assert.AreEqual(Math.Exp(2), r.UpperBound); 314 315 Assert.AreEqual(Math.Exp(1), lg[0]); // x 316 Assert.AreEqual(Math.Exp(2) * 2, ug[0]); 317 } 318 319 [TestMethod] 320 [TestCategory("Problems.DataAnalysis")] 321 [TestProperty("Time", "short")] 322 public void TestIntervalSinAutoDiff() { 323 var eval = new IntervalEvaluator(); 324 var parser = new InfixExpressionParser(); 325 var t = parser.Parse("sin(x)"); 326 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 327 var intervals = new Dictionary<string, Interval>() { 328 { "x", new Interval(1, 2) }, 329 }; 330 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 331 Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2) 332 Assert.AreEqual(1, r.UpperBound); // 1..2 crosses pi / 2 and sin(pi/2)==1 333 334 Assert.AreEqual(Math.Cos(1), lg[0]); // x 335 Assert.AreEqual(0, ug[0]); 336 } 337 338 [TestMethod] 339 [TestCategory("Problems.DataAnalysis")] 340 [TestProperty("Time", "short")] 341 public void TestIntervalCosAutoDiff() { 342 var eval = new IntervalEvaluator(); 343 var parser = new InfixExpressionParser(); 344 var t = parser.Parse("cos(x)"); 345 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 346 var intervals = new Dictionary<string, Interval>() { 347 { "x", new Interval(3, 4) }, 348 }; 349 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 350 Assert.AreEqual(-1, r.LowerBound); // 3..4 crosses pi and cos(pi) == -1 351 Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4) 352 353 Assert.AreEqual(0, lg[0]); // x 354 Assert.AreEqual(-4*Math.Sin(4), ug[0]); 355 } 356 357 [TestMethod] 358 [TestCategory("Problems.DataAnalysis")] 359 [TestProperty("Time", "short")] 360 public void TestIntervalSqrtAutoDiff() { 361 var eval = new IntervalEvaluator(); 362 var parser = new InfixExpressionParser(); 363 var t = parser.Parse("sqrt(x)"); 364 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 365 var intervals = new Dictionary<string, Interval>() { 366 { "x", new Interval(4, 9) }, 367 { "y", new Interval(1, 2) }, 368 { "z", new Interval(0, 1) }, 369 }; 370 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 371 Assert.AreEqual(2, r.LowerBound); 372 Assert.AreEqual(3, r.UpperBound); 373 374 Assert.AreEqual(1.0, lg[0]); // x 375 Assert.AreEqual(1.5, ug[0]); 376 377 t = parser.Parse("sqrt(y)"); 378 paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 379 r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug); 380 Assert.AreEqual(1, r.LowerBound); 381 Assert.AreEqual(Math.Sqrt(2), r.UpperBound); 382 383 Assert.AreEqual(0.5, lg[0]); // y 384 Assert.AreEqual(0.5*Math.Sqrt(2), ug[0], 1e-5); 385 386 t = parser.Parse("sqrt(z)"); 387 paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 388 r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug); 389 Assert.AreEqual(0, r.LowerBound); 390 Assert.AreEqual(1, r.UpperBound); 391 392 Assert.AreEqual(0, lg[0]); // z 393 Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5); 394 } 395 396 [TestMethod] 397 [TestCategory("Problems.DataAnalysis")] 398 [TestProperty("Time", "short")] 399 public void TestIntervalCqrtAutoDiff() { 400 var eval = new IntervalEvaluator(); 401 var parser = new InfixExpressionParser(); 402 var t = parser.Parse("cuberoot(x)"); 403 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 404 var intervals = new Dictionary<string, Interval>() { 405 { "x", new Interval(8, 27) }, 406 { "y", new Interval(1, 2) }, 407 { "z", new Interval(0, 1) }, 408 }; 409 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 410 Assert.AreEqual(2, r.LowerBound); 411 Assert.AreEqual(3, r.UpperBound); 412 413 Assert.AreEqual(0.0, lg[0]); // x 414 Assert.AreEqual(0.0, ug[0]); XXXX 415 416 t = parser.Parse("sqrt(y)"); 417 paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 418 r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug); 419 Assert.AreEqual(0.0, r.LowerBound); 420 Assert.AreEqual(0.0, r.UpperBound); 421 422 Assert.AreEqual(0.0, lg[0]); // y 423 Assert.AreEqual(0.0, ug[0], 1e-5); 424 425 t = parser.Parse("sqrt(z)"); 426 paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 427 r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug); 428 Assert.AreEqual(0.0, r.LowerBound); 429 Assert.AreEqual(0.0, r.UpperBound); 430 431 Assert.AreEqual(0.0, lg[0]); // z 432 Assert.AreEqual(0.0, ug[0], 1e-5); 433 } 434 435 [TestMethod] 436 [TestCategory("Problems.DataAnalysis")] 437 [TestProperty("Time", "short")] 438 public void TestIntervalLogAutoDiff() { 439 var eval = new IntervalEvaluator(); 440 var parser = new InfixExpressionParser(); 441 var t = parser.Parse("log(4*x)"); 442 var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray(); 443 var intervals = new Dictionary<string, Interval>() { 444 { "x", new Interval(1, 2) }, 445 }; 446 var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug); 447 Assert.AreEqual(Math.Log(4), r.LowerBound); 448 Assert.AreEqual(Math.Log(8), r.UpperBound); 449 450 Assert.AreEqual(0.25, lg[0]); // x 451 Assert.AreEqual(0.25, ug[0]); 452 453 } 228 454 } 229 455 }
Note: See TracChangeset
for help on using the changeset viewer.