Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
10/04/19 09:32:41 (5 years ago)
Author:
gkronber
Message:

#2994: worked on unit tests for interval autodiff (work-in-progress, doesn't compile)

Location:
branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4
Files:
2 edited

Legend:

Unmodified
Added
Removed
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/AutoDiffInterpreterTest.cs

    r17308 r17310  
    109109    }
    110110
     111    [TestMethod]
     112    [TestCategory("Problems.DataAnalysis")]
     113    [TestProperty("Time", "short")]
     114    public void TestIntervalAutoDiffUsingNumericDifferences() {
     115
     116      // create random trees and evaluate on random data
     117      // calc gradient for all parameters
     118      // use numeric differences for approximate gradient calculation
     119      // compare gradients
     120
     121      var grammar = new TypeCoherentExpressionGrammar();
     122      grammar.ConfigureAsDefaultRegressionGrammar();
     123      // activate supported symbols
     124      grammar.Symbols.First(s => s is Square).Enabled = true;
     125      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
     126      grammar.Symbols.First(s => s is Cube).Enabled = true;
     127      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
     128      grammar.Symbols.First(s => s is Sine).Enabled = true;
     129      grammar.Symbols.First(s => s is Cosine).Enabled = true;
     130      grammar.Symbols.First(s => s is Exponential).Enabled = true;
     131      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
     132      grammar.Symbols.First(s => s is Absolute).Enabled = false; // XXX not yet supported by old interval calculator
     133      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
     134      grammar.Symbols.First(s => s is Constant).Enabled = false;
     135
     136      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
     137      varSy.AllVariableNames = new string[] { "x", "y" };
     138      varSy.VariableNames = varSy.AllVariableNames;
     139      varSy.WeightMu = 1.0;
     140      varSy.WeightSigma = 0.0;
     141      var rand = new FastRandom(1234);
     142
     143      var intervals = new Dictionary<string, Interval>() {
     144        { "x", new Interval(1, 2) },
     145        { "y", new Interval(0, 1) },
     146      };
     147
     148      var eval = new IntervalEvaluator();
     149
     150      var formatter = new InfixExpressionFormatter();
     151      var sb = new StringBuilder();
     152      int N = 10000;
     153      int iter = 0;
     154      while (iter < N) {
     155        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
     156        var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     157
     158        eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient);
     159
     160        ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient);
     161
     162        // compare autodiff and numeric diff
     163        for(int p=0;p<parameterNodes.Length;p++) {
     164          // lower
     165          if(double.IsNaN(lowerGradient[p]) && double.IsNaN(refLowerGradient[p])) {
     166
     167          } else if(lowerGradient[p] == refLowerGradient[p]){
     168
     169          } else if(Math.Abs(lowerGradient[p] - refLowerGradient[p]) < Math.Abs(lowerGradient[p]) * 1e-4) {
     170
     171          } else {
     172            sb.AppendLine($"{lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
     173          }
     174          // upper
     175          if (double.IsNaN(upperGradient[p]) && double.IsNaN(refUpperGradient[p])) {
     176
     177          } else if (upperGradient[p] == refUpperGradient[p]) {
     178
     179          } else if (Math.Abs(upperGradient[p] - refUpperGradient[p]) < Math.Abs(upperGradient[p]) * 1e-4) {
     180
     181          } else {
     182            sb.AppendLine($"{upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
     183          }
     184        }
     185
     186        iter++;
     187      }
     188      if (sb.Length > 0) {
     189        Console.WriteLine(sb.ToString());
     190        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
     191      }
     192    }
     193
    111194    #region helper
    112195
     
    158241    }
    159242
     243    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
     244      lowerGradient = new double[parameterNodes.Length];
     245      upperGradient = new double[parameterNodes.Length];
     246
     247      for(int p=0;p<parameterNodes.Length;p++) {
     248        var x = GetValue(parameterNodes[p]);
     249        var x_diff = x * 1e-4; // relative change
     250
     251        // calculate output for increased parameter value
     252        SetValue(parameterNodes[p], x + x_diff / 2);
     253        var r1 = eval.Evaluate(t, intervals);
     254        lowerGradient[p] = r1.LowerBound;
     255        upperGradient[p] = r1.UpperBound;
     256
     257        // calculate output for decreased parameter value
     258        SetValue(parameterNodes[p], x - x_diff / 2);
     259        var r2 = eval.Evaluate(t, intervals);
     260        lowerGradient[p] -= r2.LowerBound;
     261        upperGradient[p] -= r2.UpperBound;
     262
     263        lowerGradient[p] /= x_diff;
     264        upperGradient[p] /= x_diff;
     265
     266        // restore original value
     267        SetValue(parameterNodes[p], x);
     268      }
     269    }
     270
    160271    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
    161272      var varNode = node as VariableTreeNode;
  • branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/AutoDiffIntervalTest.cs

    r17303 r17310  
    11using System;
    22using System.Collections.Generic;
     3using System.Linq;
    34using HeuristicLab.Problems.DataAnalysis.Symbolic;
    45using Microsoft.VisualStudio.TestTools.UnitTesting;
     
    3233        Assert.IsTrue(double.IsNaN(b.LowerBound.Value));
    3334      } else {
    34         Assert.AreEqual(a.LowerBound.Value.Value, b.LowerBound.Value.Value, Math.Abs(a.LowerBound.Value.Value)*1e-4); // relative error < 0.1%
     35        Assert.AreEqual(a.LowerBound.Value.Value, b.LowerBound.Value.Value, Math.Abs(a.LowerBound.Value.Value) * 1e-4); // relative error < 0.1%
    3536      }
    3637
     
    226227      AssertAreEqualInterval(new AlgebraicInterval(-2, -1), new AlgebraicInterval(-8, -1).IntRoot(3));
    227228    }
     229
     230    [TestMethod]
     231    [TestCategory("Problems.DataAnalysis")]
     232    [TestProperty("Time", "short")]
     233    public void TestIntervalAddAutoDiff() {
     234      var eval = new IntervalEvaluator();
     235      var parser = new InfixExpressionParser();
     236      var t = parser.Parse("x + y");
     237      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     238      var intervals = new Dictionary<string, Interval>() {
     239        { "x", new Interval(1, 2) },
     240        { "y", new Interval(0, 1) }
     241      };
     242      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     243      Assert.AreEqual(1, r.LowerBound);
     244      Assert.AreEqual(3, r.UpperBound);
     245
     246      Assert.AreEqual(1.0, lg[0]); // x
     247      Assert.AreEqual(2.0, ug[0]);
     248      Assert.AreEqual(0.0, lg[1]); // y
     249      Assert.AreEqual(1.0, ug[1]);
     250    }
     251
     252    [TestMethod]
     253    [TestCategory("Problems.DataAnalysis")]
     254    [TestProperty("Time", "short")]
     255    public void TestIntervalMulAutoDiff() {
     256      var eval = new IntervalEvaluator();
     257      var parser = new InfixExpressionParser();
     258      var t = parser.Parse("x * y");
     259      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     260      var intervals = new Dictionary<string, Interval>() {
     261        { "x", new Interval(1, 2) },
     262        { "y", new Interval(0, 1) }
     263      };
     264      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     265      Assert.AreEqual(0, r.LowerBound);
     266      Assert.AreEqual(2, r.UpperBound);
     267
     268      Assert.AreEqual(0.0, lg[0]); // x
     269      Assert.AreEqual(2.0, ug[0]);
     270      Assert.AreEqual(0.0, lg[1]); // y
     271      Assert.AreEqual(2.0, ug[1]);
     272    }
     273
     274    [TestMethod]
     275    [TestCategory("Problems.DataAnalysis")]
     276    [TestProperty("Time", "short")]
     277    public void TestIntervalSqrAutoDiff() {
     278      var eval = new IntervalEvaluator();
     279      var parser = new InfixExpressionParser();
     280      var t = parser.Parse("sqr(x)");
     281      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     282      var intervals = new Dictionary<string, Interval>() {
     283        { "x", new Interval(1, 2) },
     284        { "y", new Interval(0, 1) }
     285      };
     286      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     287      Assert.AreEqual(XXX, r.LowerBound);
     288      Assert.AreEqual(XXX, r.UpperBound);
     289
     290      Assert.AreEqual(XXX, lg[0]); // x
     291      Assert.AreEqual(XXX, ug[0]);
     292
     293      for  { "x", new Interval(1, 2) },
     294        { "y", new Interval(0, 1) },
     295
     296      0 <> -2,50012500572888E-05 for y in SQR(LOG('y'))
     297      0 <> 2, 49987500573946E-05 for x in SQR(LOG('x'))
     298    }
     299
     300    [TestMethod]
     301    [TestCategory("Problems.DataAnalysis")]
     302    [TestProperty("Time", "short")]
     303    public void TestIntervalExpAutoDiff() {
     304      var eval = new IntervalEvaluator();
     305      var parser = new InfixExpressionParser();
     306      var t = parser.Parse("exp(x)");
     307      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     308      var intervals = new Dictionary<string, Interval>() {
     309        { "x", new Interval(1, 2) },
     310      };
     311      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     312      Assert.AreEqual(Math.Exp(1), r.LowerBound);
     313      Assert.AreEqual(Math.Exp(2), r.UpperBound);
     314
     315      Assert.AreEqual(Math.Exp(1), lg[0]); // x
     316      Assert.AreEqual(Math.Exp(2) * 2, ug[0]);
     317    }
     318
     319    [TestMethod]
     320    [TestCategory("Problems.DataAnalysis")]
     321    [TestProperty("Time", "short")]
     322    public void TestIntervalSinAutoDiff() {
     323      var eval = new IntervalEvaluator();
     324      var parser = new InfixExpressionParser();
     325      var t = parser.Parse("sin(x)");
     326      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     327      var intervals = new Dictionary<string, Interval>() {
     328        { "x", new Interval(1, 2) },
     329      };
     330      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     331      Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2)
     332      Assert.AreEqual(1, r.UpperBound); //  1..2 crosses pi / 2 and sin(pi/2)==1
     333
     334      Assert.AreEqual(Math.Cos(1), lg[0]); // x
     335      Assert.AreEqual(0, ug[0]);
     336    }
     337
     338    [TestMethod]
     339    [TestCategory("Problems.DataAnalysis")]
     340    [TestProperty("Time", "short")]
     341    public void TestIntervalCosAutoDiff() {
     342      var eval = new IntervalEvaluator();
     343      var parser = new InfixExpressionParser();
     344      var t = parser.Parse("cos(x)");
     345      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     346      var intervals = new Dictionary<string, Interval>() {
     347        { "x", new Interval(3, 4) },
     348      };
     349      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     350      Assert.AreEqual(-1, r.LowerBound); //  3..4 crosses pi and cos(pi) == -1
     351      Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4)
     352
     353      Assert.AreEqual(0, lg[0]); // x
     354      Assert.AreEqual(-4*Math.Sin(4), ug[0]);
     355    }
     356
     357    [TestMethod]
     358    [TestCategory("Problems.DataAnalysis")]
     359    [TestProperty("Time", "short")]
     360    public void TestIntervalSqrtAutoDiff() {
     361      var eval = new IntervalEvaluator();
     362      var parser = new InfixExpressionParser();
     363      var t = parser.Parse("sqrt(x)");
     364      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     365      var intervals = new Dictionary<string, Interval>() {
     366        { "x", new Interval(4, 9) },
     367        { "y", new Interval(1, 2) },
     368        { "z", new Interval(0, 1) },
     369      };
     370      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     371      Assert.AreEqual(2, r.LowerBound);
     372      Assert.AreEqual(3, r.UpperBound);
     373
     374      Assert.AreEqual(1.0, lg[0]); // x
     375      Assert.AreEqual(1.5, ug[0]);
     376
     377      t = parser.Parse("sqrt(y)");
     378      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     379      r = eval.Evaluate(t, intervals, paramNodes, out lg, out  ug);
     380      Assert.AreEqual(1, r.LowerBound);
     381      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);
     382
     383      Assert.AreEqual(0.5, lg[0]); // y
     384      Assert.AreEqual(0.5*Math.Sqrt(2), ug[0], 1e-5);
     385
     386      t = parser.Parse("sqrt(z)");
     387      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     388      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
     389      Assert.AreEqual(0, r.LowerBound);
     390      Assert.AreEqual(1, r.UpperBound);
     391
     392      Assert.AreEqual(0, lg[0]); // z
     393      Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5);
     394    }
     395
     396    [TestMethod]
     397    [TestCategory("Problems.DataAnalysis")]
     398    [TestProperty("Time", "short")]
     399    public void TestIntervalCqrtAutoDiff() {
     400      var eval = new IntervalEvaluator();
     401      var parser = new InfixExpressionParser();
     402      var t = parser.Parse("cuberoot(x)");
     403      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     404      var intervals = new Dictionary<string, Interval>() {
     405        { "x", new Interval(8, 27) },
     406        { "y", new Interval(1, 2) },
     407        { "z", new Interval(0, 1) },
     408      };
     409      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     410      Assert.AreEqual(2, r.LowerBound);
     411      Assert.AreEqual(3, r.UpperBound);
     412
     413      Assert.AreEqual(0.0, lg[0]); // x
     414      Assert.AreEqual(0.0, ug[0]); XXXX
     415
     416      t = parser.Parse("sqrt(y)");
     417      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     418      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
     419      Assert.AreEqual(0.0, r.LowerBound);
     420      Assert.AreEqual(0.0, r.UpperBound);
     421
     422      Assert.AreEqual(0.0, lg[0]); // y
     423      Assert.AreEqual(0.0, ug[0], 1e-5);
     424
     425      t = parser.Parse("sqrt(z)");
     426      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     427      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
     428      Assert.AreEqual(0.0, r.LowerBound);
     429      Assert.AreEqual(0.0, r.UpperBound);
     430
     431      Assert.AreEqual(0.0, lg[0]); // z
     432      Assert.AreEqual(0.0, ug[0], 1e-5);
     433    }
     434
     435    [TestMethod]
     436    [TestCategory("Problems.DataAnalysis")]
     437    [TestProperty("Time", "short")]
     438    public void TestIntervalLogAutoDiff() {
     439      var eval = new IntervalEvaluator();
     440      var parser = new InfixExpressionParser();
     441      var t = parser.Parse("log(4*x)");
     442      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
     443      var intervals = new Dictionary<string, Interval>() {
     444        { "x", new Interval(1, 2) },
     445      };
     446      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
     447      Assert.AreEqual(Math.Log(4), r.LowerBound);
     448      Assert.AreEqual(Math.Log(8), r.UpperBound);
     449
     450      Assert.AreEqual(0.25, lg[0]); // x
     451      Assert.AreEqual(0.25, ug[0]);
     452
     453    }
    228454  }
    229455}
Note: See TracChangeset for help on using the changeset viewer.