Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/Tests/AutoDiffTest.cs

Last change on this file was 16744, checked in by gkronber, 5 years ago

#2994: slightly changed calculation of integer powers for intervals and added unit tests.

File size: 10.1 KB
Line 
1using System;
2using System.Collections.Generic;
3using HeuristicLab.Problems.DataAnalysis;
4using HeuristicLab.Problems.DataAnalysis.Symbolic;
5using Microsoft.VisualStudio.TestTools.UnitTesting;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
namespace Tests {
  [TestClass]
  public class AutoDiffTestClass {
    /// <summary>
    /// Exercises the interval evaluator, the vector evaluator, the vector auto-diff evaluator,
    /// interval gradients and the symbolic derivative calculator. Each scenario lives in its own
    /// private helper; they run in sequence and the first failing assertion ends the test.
    /// </summary>
    [TestMethod]
    public void AutoDiffTest() {
      IntervalEvaluation();
      VectorEvaluation();
      VectorEvaluationWithAutoDiff();
      IntervalsOfPowers();
      IntervalEvaluationWithAutoDiff();
      IntervalsOfFlowPsiSubExpressions();
      DerivativesOfFlowPsi();
    }

    // Interval evaluation of 2.0*x+y with x in [-1, 1] and y in [2, 10] must give [0, 12].
    private static void IntervalEvaluation() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");

      var evaluator = new IntervalEvaluator();
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };
      var resultInterval = evaluator.Evaluate(t, intervals);
      Assert.AreEqual(0, resultInterval.LowerBound);
      Assert.AreEqual(12, resultInterval.UpperBound);
    }

    // Point-wise evaluation of 2.0*x+y on the training rows and the test rows of the shared dataset.
    private static void VectorEvaluation() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");

      var evaluator = new VectorEvaluator();
      CreateVectorTestData(out var ds, out var problemData);

      // training rows: x = 1, 2 and y = 1
      var train = evaluator.Evaluate(t, ds, problemData.TrainingIndices.ToArray());
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]);
      Assert.AreEqual(5, train[1]);

      // test rows: x = 3, 4, 5 and y = -1
      var test = evaluator.Evaluate(t, ds, problemData.TestIndices.ToArray());
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]);
      Assert.AreEqual(7, test[1]);
      Assert.AreEqual(9, test[2]);
    }

    // Point-wise evaluation of 2.0*x+y together with its Jacobian w.r.t. the constant 2.0 and the y node.
    private static void VectorEvaluationWithAutoDiff() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");
      var paramNodes = GetParameterNodes(t.IterateNodesPostfix());

      var evaluator = new VectorAutoDiffEvaluator();
      CreateVectorTestData(out var ds, out var problemData);

      // Materialize the row indices once; they size the output buffers and drive the evaluation.
      var trainRows = problemData.TrainingIndices.ToArray();
      var train = new double[trainRows.Length];
      var trainJac = new double[train.Length, paramNodes.Length];
      evaluator.Evaluate(t, ds, trainRows, paramNodes, train, trainJac);
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]);
      Assert.AreEqual(5, train[1]);
      // Jacobian column 0 is the x value of the row, column 1 the y value of the row.
      Assert.AreEqual(1, trainJac[0, 0]);
      Assert.AreEqual(1, trainJac[0, 1]);
      Assert.AreEqual(2, trainJac[1, 0]);
      Assert.AreEqual(1, trainJac[1, 1]);

      var testRows = problemData.TestIndices.ToArray();
      var test = new double[testRows.Length];
      var testJac = new double[test.Length, paramNodes.Length];
      evaluator.Evaluate(t, ds, testRows, paramNodes, test, testJac);
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]);
      Assert.AreEqual(7, test[1]);
      Assert.AreEqual(9, test[2]);

      Assert.AreEqual(3, testJac[0, 0]);
      Assert.AreEqual(-1, testJac[0, 1]);
      Assert.AreEqual(4, testJac[1, 0]);
      Assert.AreEqual(-1, testJac[1, 1]);
      Assert.AreEqual(5, testJac[2, 0]);
      Assert.AreEqual(-1, testJac[2, 1]);
    }

    // Scaling, sqr and cube over a positive (p), a negative (n) and a zero-straddling (x) interval.
    private static void IntervalsOfPowers() {
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-2.0, 3.0) },
        { "p", new Interval(1.0, 2.0) },
        { "n", new Interval(-2.0, -1.0) }
      };

      AssertInterval("10*x", intervals, -20, 30);
      AssertInterval("sqr(p)", intervals, 1, 4);
      AssertInterval("sqr(n)", intervals, 1, 4);
      AssertInterval("sqr(x)", intervals, 0, 9);

      AssertInterval("cube(p)", intervals, 1, 8);
      AssertInterval("cube(n)", intervals, -8, -1);
      AssertInterval("cube(x)", intervals, -8, 27);
    }

    // Gradients of both interval bounds of 2.0*x+y w.r.t. the constant 2.0 and the coefficient of y.
    private static void IntervalEvaluationWithAutoDiff() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");
      var paramNodes = GetParameterNodes(t.IterateNodesPostfix());

      var evaluator = new IntervalEvaluator();
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };
      var resultInterval = evaluator.Evaluate(t, intervals, paramNodes, out double[] lowerGradient, out double[] upperGradient);
      Assert.AreEqual(0, resultInterval.LowerBound);
      Assert.AreEqual(12, resultInterval.UpperBound);

      // lower bound 0 = 2.0*(-1) + 2, so its gradient is (x.lower, y.lower) = (-1, 2)
      Assert.AreEqual(-1, lowerGradient[0]);
      Assert.AreEqual(2, lowerGradient[1]);
      // upper bound 12 = 2.0*1 + 10, so its gradient is (x.upper, y.upper) = (1, 10)
      Assert.AreEqual(1, upperGradient[0]);
      Assert.AreEqual(10, upperGradient[1]);
    }

    // Interval bounds of the sub-expressions of the flow psi model (as discussed with Fabrício).
    private static void IntervalsOfFlowPsiSubExpressions() {
      var intervals = new Dictionary<string, Interval> {
        { "x1", new Interval(60.0, 65.0) },
        { "x2", new Interval(30.0, 40.0) },
        { "x3", new Interval(5.0, 10.0) },
        { "x4", new Interval(0.5, 0.8) },
        { "x5", new Interval(0.2, 0.5) }
      };

      var parser = new InfixExpressionParser();

      var t1 = parser.Parse("x5/x4");
      var t2 = parser.Parse("log(x5/x4)");
      var t3 = parser.Parse("x3 * log(x5/x4)");
      var t4 = parser.Parse("x1*x2*x5");
      var t5 = parser.Parse("x4/x5");
      var t6 = parser.Parse("sqr(x4/x5)");
      var t7 = parser.Parse("(1 - sqr(x4/x5)) ");
      var t8 = parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5))");
      var t9 = parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5)) + x3 * log(x5/x4)");

      var evaluator = new IntervalEvaluator();
      var result = evaluator.Evaluate(t1, intervals);
      Assert.AreEqual(0.25, result.LowerBound);
      Assert.AreEqual(1, result.UpperBound);

      // log(0.25) = -1.386294361...
      result = evaluator.Evaluate(t2, intervals);
      Assert.AreEqual(-1.386294361, result.LowerBound, 1e-6);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(t3, intervals);
      Assert.AreEqual(-13.86294361, result.LowerBound, 1e-6);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(t4, intervals);
      Assert.AreEqual(360, result.LowerBound);
      Assert.AreEqual(1300, result.UpperBound);

      result = evaluator.Evaluate(t5, intervals);
      Assert.AreEqual(1, result.LowerBound, 1e-6);
      Assert.AreEqual(4, result.UpperBound);

      result = evaluator.Evaluate(t6, intervals);
      Assert.AreEqual(1, result.LowerBound);
      Assert.AreEqual(16, result.UpperBound);

      result = evaluator.Evaluate(t7, intervals);
      Assert.AreEqual(-15, result.LowerBound);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(t8, intervals);
      Assert.AreEqual(-19500, result.LowerBound);
      Assert.AreEqual(0, result.UpperBound);

      result = evaluator.Evaluate(t9, intervals);
      Assert.AreEqual(-19513.86294, result.LowerBound, 1e-3);
      Assert.AreEqual(0, result.UpperBound);
    }

    // Symbolic partial derivatives of the flow psi model x1*x2*x5*(1 - sqr(x4/x5)) + x3*log(x5/x4),
    // checked against their exact formatted (unsimplified) form.
    private static void DerivativesOfFlowPsi() {
      var parser = new InfixExpressionParser();
      var formatter = new InfixExpressionFormatter();

      var expr = parser.Parse("x1*x2*x5*(1 - sqr(x4/x5)) + x3 * log(x5/x4)");

      // x2 x5 (1 - sqr(x4/x5))
      var dfdx1 = DerivativeCalculator.Derive(expr, "x1");
      Assert.AreEqual("('x2' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx1));

      // x1 x5 (1 - sqr(x4/x5))
      var dfdx2 = DerivativeCalculator.Derive(expr, "x2");
      Assert.AreEqual("('x1' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx2));

      // log(x5/x4)
      var dfdx3 = DerivativeCalculator.Derive(expr, "x3");
      Assert.AreEqual("LOG(('x5' / 'x4'))", formatter.Format(dfdx3));

      // -2*x1*x2*x5*x4/x5*1/x5 - x3*1/(x5/x4)*x5/sqr(x4)  =  -2*x1*x2*x4/x5 - x3/x4
      var dfdx4 = DerivativeCalculator.Derive(expr, "x4");
      Assert.AreEqual("((('x1' * 'x2' * 'x5' * 'x4' * 2) / ('x5' * (-1*'x5'))) + (('x4' * 'x5' * 'x3') / ('x5' * SQR('x4') * (-1))))", formatter.Format(dfdx4));

      // x3/x5 + x1 x2 (1 - sqr(x4/x5)) + 2*x1*x2*sqr(x4)/sqr(x5)
      var dfdx5 = DerivativeCalculator.Derive(expr, "x5");
      Assert.AreEqual("((('x4' * 'x3') / ('x5' * 'x4')) + ('x1' * 'x2' * ((SQR(('x4' / 'x5')) * (-1)) + 1)) + (('x1' * 'x2' * 'x5' * ('x4' * 'x4') * 2) / ('x5' * SQR('x5') * 1)))", formatter.Format(dfdx5));
    }

    // Builds the 5-row dataset (columns x, y and an all-zero target f(x)) shared by the vector tests
    // and a regression problem on it; the callers expect rows 0-1 to be training and rows 2-4 test rows.
    private static void CreateVectorTestData(out Dataset ds, out RegressionProblemData problemData) {
      var vars = new string[] { "x", "y", "f(x)" };
      var values = new double[,] {
        { 1,  1, 0 },
        { 2,  1, 0 },
        { 3, -1, 0 },
        { 4, -1, 0 },
        { 5, -1, 0 },
      };

      ds = new Dataset(vars, values);
      problemData = new RegressionProblemData(ds, vars, "f(x)");
    }

    // Selects the two parameter nodes of "2.0*x+y": the constant 2.0 and the variable node for y.
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(IEnumerable<ISymbolicExpressionTreeNode> postfixNodes) {
      var nodes = postfixNodes.ToList();
      var constantNode = nodes.First(n => n is ConstantTreeNode);
      var yNode = nodes.First(n => n is VariableTreeNode varNode && varNode.VariableName == "y");
      return new ISymbolicExpressionTreeNode[] { constantNode, yNode };
    }

    // Parses the expression, evaluates it over the given variable intervals and asserts the exact bounds.
    // The expression is part of the failure message because many expressions are checked in one test.
    private static void AssertInterval(string expression, Dictionary<string, Interval> intervals, double expectedLow, double expectedHigh) {
      var parser = new InfixExpressionParser();
      var t = parser.Parse(expression);
      var evaluator = new IntervalEvaluator();
      var result = evaluator.Evaluate(t, intervals);
      Assert.AreEqual(expectedLow, result.LowerBound, $"lower bound of '{expression}'");
      Assert.AreEqual(expectedHigh, result.UpperBound, $"upper bound of '{expression}'");
    }
  }
}
Note: See TracBrowser for help on using the repository browser.