source: branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/IntervalEvaluatorAutoDiffTest.cs @ 17319

Last change on this file since 17319 was 17319, checked in by gkronber, 12 months ago

#2994: recheck cbrt() and sqrt() autoDiff tests

File size: 16.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Problems.DataAnalysis.Symbolic;
7using HeuristicLab.Random;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
10namespace HeuristicLab.Problems.DataAnalysis.Tests {
11  [TestClass]
12  public class IntervalEvaluatorAutoDiffTest {
13    [TestMethod]
14    [TestCategory("Problems.DataAnalysis")]
15    [TestProperty("Time", "short")]
16    public void IntervalEvalutorAutoDiffAdd() {
17      var eval = new IntervalEvaluator();
18      var parser = new InfixExpressionParser();
19      var t = parser.Parse("x + y");
20      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
21      var intervals = new Dictionary<string, Interval>() {
22        { "x", new Interval(1, 2) },
23        { "y", new Interval(0, 1) }
24      };
25      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
26      Assert.AreEqual(1, r.LowerBound);
27      Assert.AreEqual(3, r.UpperBound);
28
29      Assert.AreEqual(1.0, lg[0]); // x
30      Assert.AreEqual(2.0, ug[0]);
31      Assert.AreEqual(0.0, lg[1]); // y
32      Assert.AreEqual(1.0, ug[1]);
33    }
34
35    [TestMethod]
36    [TestCategory("Problems.DataAnalysis")]
37    [TestProperty("Time", "short")]
38    public void IntervalEvalutorAutoDiffMul() {
39      var eval = new IntervalEvaluator();
40      var parser = new InfixExpressionParser();
41      var t = parser.Parse("x * y");
42      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
43      var intervals = new Dictionary<string, Interval>() {
44        { "x", new Interval(1, 2) },
45        { "y", new Interval(0, 1) }
46      };
47      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
48      Assert.AreEqual(0, r.LowerBound);
49      Assert.AreEqual(2, r.UpperBound);
50
51      Assert.AreEqual(0.0, lg[0]); // x
52      Assert.AreEqual(2.0, ug[0]);
53      Assert.AreEqual(0.0, lg[1]); // y
54      Assert.AreEqual(2.0, ug[1]);
55    }
56
57    [TestMethod]
58    [TestCategory("Problems.DataAnalysis")]
59    [TestProperty("Time", "short")]
60    public void IntervalEvalutorAutoDiffSqr() {
61      var eval = new IntervalEvaluator();
62      var parser = new InfixExpressionParser();
63      var t = parser.Parse("sqr(x)");
64      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
65      var intervals = new Dictionary<string, Interval>() {
66        { "x", new Interval(1, 2) },
67        { "y", new Interval(0, 1) }
68      };
69      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
70      // TODO
71      // Assert.AreEqual(XXX, r.LowerBound);
72      // Assert.AreEqual(XXX, r.UpperBound);
73      //
74      // Assert.AreEqual(XXX, lg[0]); // x
75      // Assert.AreEqual(XXX, ug[0]);
76      //
77      // for  { "x", new Interval(1, 2) },
78      //   { "y", new Interval(0, 1) },
79      //
80      // 0 <> -2,50012500572888E-05 for y in SQR(LOG('y'))
81      // 0 <> 2, 49987500573946E-05 for x in SQR(LOG('x'))
82    }
83
84    [TestMethod]
85    [TestCategory("Problems.DataAnalysis")]
86    [TestProperty("Time", "short")]
87    public void IntervalEvalutorAutoDiffExp() {
88      var eval = new IntervalEvaluator();
89      var parser = new InfixExpressionParser();
90      var t = parser.Parse("exp(x)");
91      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
92      var intervals = new Dictionary<string, Interval>() {
93        { "x", new Interval(1, 2) },
94      };
95      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
96      Assert.AreEqual(Math.Exp(1), r.LowerBound);
97      Assert.AreEqual(Math.Exp(2), r.UpperBound);
98
99      Assert.AreEqual(Math.Exp(1), lg[0]); // x
100      Assert.AreEqual(Math.Exp(2) * 2, ug[0]);
101    }
102
103    [TestMethod]
104    [TestCategory("Problems.DataAnalysis")]
105    [TestProperty("Time", "short")]
106    public void IntervalEvalutorAutoDiffSin() {
107      var eval = new IntervalEvaluator();
108      var parser = new InfixExpressionParser();
109      var t = parser.Parse("sin(x)");
110      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
111      var intervals = new Dictionary<string, Interval>() {
112        { "x", new Interval(1, 2) },
113      };
114      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
115      Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2)
116      Assert.AreEqual(1, r.UpperBound); //  1..2 crosses pi / 2 and sin(pi/2)==1
117
118      Assert.AreEqual(Math.Cos(1), lg[0]); // x
119      Assert.AreEqual(0, ug[0]);
120    }
121
122    [TestMethod]
123    [TestCategory("Problems.DataAnalysis")]
124    [TestProperty("Time", "short")]
125    public void IntervalEvalutorAutoDiffCos() {
126      var eval = new IntervalEvaluator();
127      var parser = new InfixExpressionParser();
128      var t = parser.Parse("cos(x)");
129      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
130      var intervals = new Dictionary<string, Interval>() {
131        { "x", new Interval(3, 4) },
132      };
133      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
134      Assert.AreEqual(-1, r.LowerBound); //  3..4 crosses pi and cos(pi) == -1
135      Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4)
136
137      Assert.AreEqual(0, lg[0]); // x
138      Assert.AreEqual(-4 * Math.Sin(4), ug[0]);
139    }
140
141    [TestMethod]
142    [TestCategory("Problems.DataAnalysis")]
143    [TestProperty("Time", "short")]
144    public void IntervalEvalutorAutoDiffSqrt() {
145      var eval = new IntervalEvaluator();
146      var parser = new InfixExpressionParser();
147      var t = parser.Parse("sqrt(x)");
148      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
149      var intervals = new Dictionary<string, Interval>() {
150        { "x", new Interval(4, 9) },
151        { "y", new Interval(1, 2) },
152        { "z", new Interval(0, 1) },
153        { "eps", new Interval(1e-10, 1) }           
154      };
155      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
156      Assert.AreEqual(2, r.LowerBound);
157      Assert.AreEqual(3, r.UpperBound);
158
159      Assert.AreEqual(1.0, lg[0]); // x
160      Assert.AreEqual(1.5, ug[0]);
161
162      t = parser.Parse("sqrt(y)");
163      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
164      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
165      Assert.AreEqual(1, r.LowerBound);
166      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);
167
168      Assert.AreEqual(0.5, lg[0]); // y
169      Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5);
170
171      t = parser.Parse("sqrt(z)");
172      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
173      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
174      Assert.AreEqual(0, r.LowerBound);
175      Assert.AreEqual(1, r.UpperBound);
176
177      Assert.AreEqual(double.NaN, lg[0]); // z
178      Assert.AreEqual(0.5, ug[0], 1e-5);
179
180      t = parser.Parse("sqrt(eps)");
181      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
182      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
183
184      Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-6); // z          --> lim x -> 0 (sqrt(x)) = 0
185      Assert.AreEqual(0.5, ug[0], 1e-5);
186    }
187
188    [TestMethod]
189    [TestCategory("Problems.DataAnalysis")]
190    [TestProperty("Time", "short")]
191    public void IntervalEvalutorAutoDiffCqrt() {
192      var eval = new IntervalEvaluator();
193      var parser = new InfixExpressionParser();
194      var t = parser.Parse("cuberoot(x)");
195      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
196      var intervals = new Dictionary<string, Interval>() {
197        { "x", new Interval(8, 27) },
198        { "y", new Interval(1, 2) },
199        { "z", new Interval(0, 1) },
200      };
201      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
202      Assert.AreEqual(2, r.LowerBound);
203      Assert.AreEqual(3, r.UpperBound);
204
205      Assert.AreEqual(2.0 / 3.0, lg[0]); // x
206      Assert.AreEqual(1.0, ug[0]);
207
208      t = parser.Parse("cuberoot(y)");
209      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
210      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
211      Assert.AreEqual(Math.Pow(1, 1.0 / 3.0), r.LowerBound);
212      Assert.AreEqual(Math.Pow(2, 1.0 / 3.0), r.UpperBound);
213
214      Assert.AreEqual(1.0 / 3.0, lg[0]); // y
215      Assert.AreEqual(1.0 / 3.0 * Math.Pow(2, 1.0 / 3.0), ug[0], 1e-5);
216
217      t = parser.Parse("cuberoot(z)");
218      paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
219      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
220      Assert.AreEqual(0.0, r.LowerBound);
221      Assert.AreEqual(1.0, r.UpperBound);
222
223      Assert.AreEqual(double.NaN, lg[0]); // z
224      Assert.AreEqual(1.0 / 3.0, ug[0], 1e-5);
225    }
226
227    [TestMethod]
228    [TestCategory("Problems.DataAnalysis")]
229    [TestProperty("Time", "short")]
230    public void IntervalEvalutorAutoDiffLog() {
231      var eval = new IntervalEvaluator();
232      var parser = new InfixExpressionParser();
233      var t = parser.Parse("log(4*x)");
234      var paramNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
235      var intervals = new Dictionary<string, Interval>() {
236        { "x", new Interval(1, 2) },
237      };
238      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
239      Assert.AreEqual(Math.Log(4), r.LowerBound);
240      Assert.AreEqual(Math.Log(8), r.UpperBound);
241
242      Assert.AreEqual(0.25, lg[0]); // x
243      Assert.AreEqual(0.25, ug[0]);
244
245    }
246
247    [TestMethod]
248    [TestCategory("Problems.DataAnalysis")]
249    [TestProperty("Time", "short")]
250    public void IntervalEvaluatorAutoDiffCompareWithNumericDifferences() {
251
252      // create random trees and evaluate on random data
253      // calc gradient for all parameters
254      // use numeric differences for approximate gradient calculation
255      // compare gradients
256
257      var grammar = new TypeCoherentExpressionGrammar();
258      grammar.ConfigureAsDefaultRegressionGrammar();
259      // activate supported symbols
260      grammar.Symbols.First(s => s is Square).Enabled = true;
261      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
262      grammar.Symbols.First(s => s is Cube).Enabled = true;
263      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
264      grammar.Symbols.First(s => s is Sine).Enabled = true;
265      grammar.Symbols.First(s => s is Cosine).Enabled = true;
266      grammar.Symbols.First(s => s is Exponential).Enabled = true;
267      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
268      grammar.Symbols.First(s => s is Absolute).Enabled = true;
269      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
270      grammar.Symbols.First(s => s is Constant).Enabled = false;
271
272      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
273      varSy.AllVariableNames = new string[] { "x", "y" };
274      varSy.VariableNames = varSy.AllVariableNames;
275      varSy.WeightMu = 1.0;
276      varSy.WeightSigma = 0.0;
277      var rand = new FastRandom(1234);
278
279      var intervals = new Dictionary<string, Interval>() {
280        { "x", new Interval(1, 2) },
281        { "y", new Interval(0, 1) },
282      };
283
284      var eval = new IntervalEvaluator();
285
286      var formatter = new InfixExpressionFormatter();
287      var sb = new StringBuilder();
288      int N = 10000;
289      int iter = 0;
290      while (iter < N) {
291        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
292        var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
293
294        eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient);
295
296        ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient);
297
298        // compare autodiff and numeric diff
299        for(int p=0;p<parameterNodes.Length;p++) {
300          // lower
301          if(double.IsNaN(lowerGradient[p]) && double.IsNaN(refLowerGradient[p])) {
302
303          } else if(lowerGradient[p] == refLowerGradient[p]){
304
305          } else if(Math.Abs(lowerGradient[p] - refLowerGradient[p]) <= Math.Abs(lowerGradient[p]) * 1e-4) {
306
307          } else {
308            sb.AppendLine($"{lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
309          }
310          // upper
311          if (double.IsNaN(upperGradient[p]) && double.IsNaN(refUpperGradient[p])) {
312
313          } else if (upperGradient[p] == refUpperGradient[p]) {
314
315          } else if (Math.Abs(upperGradient[p] - refUpperGradient[p]) <= Math.Abs(upperGradient[p]) * 1e-4) {
316
317          } else {
318            sb.AppendLine($"{upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
319          }
320        }
321
322        iter++;
323      }
324      if (sb.Length > 0) {
325        Console.WriteLine(sb.ToString());
326        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
327      }
328    }
329
330    #region helper
331
332    private double[] CalculateGradient(string expr, IDataset ds) {
333      var eval = new VectorAutoDiffEvaluator();
334      var parser = new InfixExpressionParser();
335
336      var rows = new int[1];
337      var fi = new double[1];
338
339      var t = parser.Parse(expr);
340      var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
341      var jac = new double[1, parameterNodes.Length];
342      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);
343
344      var g = new double[parameterNodes.Length];
345      for (int i = 0; i < g.Length; i++) g[i] = jac[0, i];
346      return g;
347    }
348
349
350    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
351      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
352      var jac = new double[rows.Length, parameterNodes.Length];
353      for (int p = 0; p < parameterNodes.Length; p++) {
354
355        var x = GetValue(parameterNodes[p]);
356        var x_diff = x * 1e-4; // relative change
357
358        // calculate output for increased parameter value
359        SetValue(parameterNodes[p], x + x_diff / 2);
360        var f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
361        for (int i = 0; i < rows.Length; i++) {
362          jac[i, p] = f[i];
363        }
364
365        // calculate output for decreased parameter value
366        SetValue(parameterNodes[p], x - x_diff / 2);
367        f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
368        for (int i = 0; i < rows.Length; i++) {
369          jac[i, p] -= f[i]; // calc difference (and scale for x_diff)
370          jac[i, p] /= x_diff;
371        }
372
373        // restore original value
374        SetValue(parameterNodes[p], x);
375      }
376      return jac;
377    }
378
379    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
380      lowerGradient = new double[parameterNodes.Length];
381      upperGradient = new double[parameterNodes.Length];
382
383      for(int p=0;p<parameterNodes.Length;p++) {
384        var x = GetValue(parameterNodes[p]);
385        var x_diff = x * 1e-4; // relative change
386
387        // calculate output for increased parameter value
388        SetValue(parameterNodes[p], x + x_diff / 2);
389        var r1 = eval.Evaluate(t, intervals);
390        lowerGradient[p] = r1.LowerBound;
391        upperGradient[p] = r1.UpperBound;
392
393        // calculate output for decreased parameter value
394        SetValue(parameterNodes[p], x - x_diff / 2);
395        var r2 = eval.Evaluate(t, intervals);
396        lowerGradient[p] -= r2.LowerBound;
397        upperGradient[p] -= r2.UpperBound;
398
399        lowerGradient[p] /= x_diff;
400        upperGradient[p] /= x_diff;
401
402        // restore original value
403        SetValue(parameterNodes[p], x);
404      }
405    }
406
407    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
408      var varNode = node as VariableTreeNode;
409      var constNode = node as ConstantTreeNode;
410      if (varNode != null) varNode.Weight = v;
411      else if (constNode != null) constNode.Value = v;
412      else throw new InvalidProgramException();
413    }
414
415    private double GetValue(ISymbolicExpressionTreeNode node) {
416      var varNode = node as VariableTreeNode;
417      var constNode = node as ConstantTreeNode;
418      if (varNode != null) return varNode.Weight;
419      else if (constNode != null) return constNode.Value;
420      throw new InvalidProgramException();
421    }
422    #endregion
423  }
424}
Note: See TracBrowser for help on using the repository browser.