Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/IntervalEvaluatorAutoDiffTest.cs @ 17325

Last change on this file since 17325 was 17325, checked in by gkronber, 5 years ago

#2994: worked on ConstrainedNLS

File size: 18.0 KB
Rev Line
[17318]1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Problems.DataAnalysis.Symbolic;
7using HeuristicLab.Random;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
namespace HeuristicLab.Problems.DataAnalysis.Tests {
  /// <summary>
  /// Tests for the gradients of interval bounds computed by <see cref="IntervalEvaluator"/>.
  /// Each test evaluates an expression over input intervals and checks the resulting interval
  /// together with the gradients of its lower and upper bound with respect to the leaf nodes
  /// of the tree (variable weights or constant values, see <see cref="GetValue"/>).
  /// </summary>
  [TestClass]
  public class IntervalEvaluatorAutoDiffTest {
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffAdd() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x + y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(1.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffMul() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x * y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(2, r.UpperBound);

      Assert.AreEqual(0.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(2.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqr() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "unit", new Interval(0, 1) },
        { "neg", new Interval(-1, 0) },
      };
      var t = parser.Parse("sqr(x)");
      var paramNodes = GetParameterNodes(t);
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(4, r.UpperBound);

      Assert.AreEqual(2.0, lg[0]); // x
      Assert.AreEqual(8.0, ug[0]);

      // log of [0, 1] has an unbounded lower end, so the squared upper bound is +inf
      t = parser.Parse("sqr(log(unit))");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0.0, r.LowerBound);
      Assert.AreEqual(double.PositiveInfinity, r.UpperBound);

      Assert.AreEqual(0.0, lg[0]); // unit
      Assert.AreEqual(double.NaN, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffExp() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("exp(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Exp(1), r.LowerBound);
      Assert.AreEqual(Math.Exp(2), r.UpperBound);

      Assert.AreEqual(Math.Exp(1), lg[0]); // x
      Assert.AreEqual(Math.Exp(2) * 2, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSin() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sin(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2)
      Assert.AreEqual(1, r.UpperBound); //  1..2 crosses pi / 2 and sin(pi/2)==1

      Assert.AreEqual(Math.Cos(1), lg[0]); // x
      Assert.AreEqual(0, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCos() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(3, 4) },
        { "z", new Interval(1, 2) }
      };
      var t = parser.Parse("cos(x)");
      var paramNodes = GetParameterNodes(t);
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(-1, r.LowerBound); //  3..4 crosses pi and cos(pi) == -1
      Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4)

      Assert.AreEqual(0, lg[0]); // x
      Assert.AreEqual(-4 * Math.Sin(4), ug[0]);

      // cos(1..2) includes negative values, so the lower bound of the log is NaN
      t = parser.Parse("LOG(COS('z'))");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(double.NaN, r.LowerBound);
      Assert.AreEqual(Math.Log(Math.Cos(1)), r.UpperBound);

      Assert.AreEqual(-2 * Math.Sin(2) / Math.Cos(2), lg[0], 1e-5); // z
      Assert.AreEqual(-1 * Math.Sin(1) / Math.Cos(1), ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sqrt(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(4, 9) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
        { "eps", new Interval(1e-10, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(1.5, ug[0]);

      t = parser.Parse("sqrt(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);

      Assert.AreEqual(0.5, lg[0]); // y
      Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5);

      t = parser.Parse("sqrt(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(1, r.UpperBound);

      Assert.AreEqual(double.NaN, lg[0]); // z
      Assert.AreEqual(0.5, ug[0], 1e-5);

      t = parser.Parse("sqrt(eps)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);

      // d/dw sqrt(w * eps) at w = 1 is 0.5 * sqrt(eps), which tends to 0 as eps -> 0
      Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-5); // eps
      Assert.AreEqual(0.5, ug[0], 1e-5);

      t = parser.Parse("sqrt(y - z)"); // 1..2 - 0..1
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);

      // Infinite expectations are compared exactly. The delta overload of Assert.AreEqual
      // computes |expected - actual| > delta, which is false for NaN and would therefore
      // silently accept a NaN gradient.
      Assert.AreEqual(double.PositiveInfinity, lg[0]); // y
      Assert.AreEqual(1 / Math.Sqrt(2), ug[0], 1e-5);
      Assert.AreEqual(double.NegativeInfinity, lg[1]); // z
      Assert.AreEqual(0.0, ug[1], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("cuberoot(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(8, 27) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(2.0 / 3.0, lg[0]); // x
      Assert.AreEqual(1.0, ug[0]);

      t = parser.Parse("cuberoot(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(Math.Pow(1, 1.0 / 3.0), r.LowerBound);
      Assert.AreEqual(Math.Pow(2, 1.0 / 3.0), r.UpperBound);

      Assert.AreEqual(1.0 / 3.0, lg[0]); // y
      Assert.AreEqual(1.0 / 3.0 * Math.Pow(2, 1.0 / 3.0), ug[0], 1e-5);

      t = parser.Parse("cuberoot(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0.0, r.LowerBound);
      Assert.AreEqual(1.0, r.UpperBound);

      Assert.AreEqual(double.NaN, lg[0]); // z
      Assert.AreEqual(1.0 / 3.0, ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffLog() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("log(4*x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Log(4), r.LowerBound);
      Assert.AreEqual(Math.Log(8), r.UpperBound);

      // paramNodes[0] is the constant 4 (first leaf in postfix order): d/dc log(c * x) = 1 / c
      Assert.AreEqual(0.25, lg[0]); // constant 4
      Assert.AreEqual(0.25, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvaluatorAutoDiffCompareWithNumericDifferences() {
      // Create random trees and evaluate them on fixed input intervals.
      // Calculate the bound gradients for all parameters via AutoDiff,
      // approximate them via numeric (central) differences, and compare.

      var grammar = new TypeCoherentExpressionGrammar();
      grammar.ConfigureAsDefaultRegressionGrammar();
      // activate supported symbols
      grammar.Symbols.First(s => s is Square).Enabled = true;
      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
      grammar.Symbols.First(s => s is Cube).Enabled = true;
      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
      grammar.Symbols.First(s => s is Sine).Enabled = true;
      grammar.Symbols.First(s => s is Cosine).Enabled = true;
      grammar.Symbols.First(s => s is Exponential).Enabled = true;
      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
      grammar.Symbols.First(s => s is Absolute).Enabled = true;
      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
      grammar.Symbols.First(s => s is Constant).Enabled = false;

      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
      varSy.AllVariableNames = new string[] { "x", "y" };
      varSy.VariableNames = varSy.AllVariableNames;
      varSy.WeightMu = 1.0;
      varSy.WeightSigma = 0.0;
      var rand = new FastRandom(1234);

      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) },
      };

      var eval = new IntervalEvaluator();

      var formatter = new InfixExpressionFormatter();
      var sb = new StringBuilder();
      const int N = 10000;
      for (int iter = 0; iter < N; iter++) {
        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
        var parameterNodes = GetParameterNodes(t);

        eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient);

        ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient);

        // compare autodiff and numeric diff (lower bound first, then upper bound)
        for (int p = 0; p < parameterNodes.Length; p++) {
          if (!GradientsAgree(lowerGradient[p], refLowerGradient[p])) {
            sb.AppendLine($"{lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
          if (!GradientsAgree(upperGradient[p], refUpperGradient[p])) {
            sb.AppendLine($"{upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
        }
      }
      if (sb.Length > 0) {
        Console.WriteLine(sb.ToString());
        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
      }
    }

    #region helper

    /// <summary>
    /// Returns the leaf nodes of <paramref name="tree"/> in postfix order. These are the
    /// parameter nodes the evaluators compute gradients for.
    /// </summary>
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree) {
      return tree.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
    }

    /// <summary>
    /// Decides whether an AutoDiff gradient and its numeric approximation agree. They agree
    /// when both are NaN, when they are exactly equal (which also covers equal infinities), or
    /// when their difference is within <paramref name="relativeTolerance"/> of |autoDiff|.
    /// </summary>
    private static bool GradientsAgree(double autoDiff, double numeric, double relativeTolerance = 1e-4) {
      if (double.IsNaN(autoDiff) && double.IsNaN(numeric)) return true;
      if (autoDiff == numeric) return true;
      return Math.Abs(autoDiff - numeric) <= Math.Abs(autoDiff) * relativeTolerance;
    }

    /// <summary>
    /// Step size for central differences around <paramref name="x"/>. The step is relative
    /// (1e-4 * x) for non-zero values. For x == 0 a relative step would be zero and the
    /// difference quotient would divide by zero, so an absolute step of 1e-4 is used instead.
    /// </summary>
    private static double FiniteDifferenceStep(double x) {
      const double relativeStep = 1e-4;
      return x != 0.0 ? x * relativeStep : relativeStep;
    }

    /// <summary>
    /// Evaluates <paramref name="expr"/> for row 0 of <paramref name="ds"/> with
    /// <see cref="VectorAutoDiffEvaluator"/>. Returns the Jacobian row, i.e. the gradient with
    /// respect to the leaf nodes of the parsed tree.
    /// </summary>
    private double[] CalculateGradient(string expr, IDataset ds) {
      var eval = new VectorAutoDiffEvaluator();
      var parser = new InfixExpressionParser();

      var rows = new int[1];
      var fi = new double[1];

      var t = parser.Parse(expr);
      var parameterNodes = GetParameterNodes(t);
      var jac = new double[1, parameterNodes.Length];
      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);

      var g = new double[parameterNodes.Length];
      for (int i = 0; i < g.Length; i++) g[i] = jac[0, i];
      return g;
    }

    /// <summary>
    /// Approximates the Jacobian of the tree output (one row per entry in <paramref name="rows"/>)
    /// with respect to <paramref name="parameterNodes"/> using central differences. Each
    /// parameter is restored to its original value afterwards, even if evaluation throws.
    /// </summary>
    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
      var jac = new double[rows.Length, parameterNodes.Length];
      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var x_diff = FiniteDifferenceStep(x);
        try {
          // output for increased parameter value
          SetValue(parameterNodes[p], x + x_diff / 2);
          var fPlus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          // output for decreased parameter value
          SetValue(parameterNodes[p], x - x_diff / 2);
          var fMinus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          // difference quotient
          for (int i = 0; i < rows.Length; i++) {
            jac[i, p] = (fPlus[i] - fMinus[i]) / x_diff;
          }
        } finally {
          // restore original value
          SetValue(parameterNodes[p], x);
        }
      }
      return jac;
    }

    /// <summary>
    /// Approximates the gradients of the lower and upper interval bound with respect to
    /// <paramref name="parameterNodes"/> using central differences. Each parameter is restored
    /// to its original value afterwards, even if evaluation throws.
    /// </summary>
    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
      lowerGradient = new double[parameterNodes.Length];
      upperGradient = new double[parameterNodes.Length];

      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var x_diff = FiniteDifferenceStep(x);
        try {
          // output for increased parameter value
          SetValue(parameterNodes[p], x + x_diff / 2);
          var r1 = eval.Evaluate(t, intervals);

          // output for decreased parameter value
          SetValue(parameterNodes[p], x - x_diff / 2);
          var r2 = eval.Evaluate(t, intervals);

          lowerGradient[p] = (r1.LowerBound - r2.LowerBound) / x_diff;
          upperGradient[p] = (r1.UpperBound - r2.UpperBound) / x_diff;
        } finally {
          // restore original value
          SetValue(parameterNodes[p], x);
        }
      }
    }

    /// <summary>
    /// Sets the weight of a variable node or the value of a constant node.
    /// Throws <see cref="InvalidProgramException"/> for any other node type.
    /// </summary>
    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
      if (node is VariableTreeNode varNode) varNode.Weight = v;
      else if (node is ConstantTreeNode constNode) constNode.Value = v;
      else throw new InvalidProgramException();
    }

    /// <summary>
    /// Gets the weight of a variable node or the value of a constant node.
    /// Throws <see cref="InvalidProgramException"/> for any other node type.
    /// </summary>
    private double GetValue(ISymbolicExpressionTreeNode node) {
      if (node is VariableTreeNode varNode) return varNode.Weight;
      if (node is ConstantTreeNode constNode) return constNode.Value;
      throw new InvalidProgramException();
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.