Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/Tests/AutoDiffTest.cs @ 16738

Last change on this file since 16738 was 16738, checked in by gkronber, 5 years ago

#2994: another unit test, smaller fixes and checks for invalid parameters

File size: 9.0 KB
Line 
1using System;
2using System.Collections.Generic;
3using HeuristicLab.Problems.DataAnalysis;
4using HeuristicLab.Problems.DataAnalysis.Symbolic;
5using Microsoft.VisualStudio.TestTools.UnitTesting;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
namespace Tests {
  /// <summary>
  /// Tests for interval evaluation, vector evaluation, automatic differentiation (vector and
  /// interval) and symbolic differentiation of expressions parsed with
  /// <see cref="InfixExpressionParser"/> (ticket #2994).
  /// </summary>
  [TestClass]
  public class AutoDiffTestClass {
    // Absolute tolerance for floating-point comparisons. A small tolerance instead of exact
    // equality keeps the assertions robust against rounding differences caused by the order of
    // operations inside the evaluators.
    private const double Tol = 1e-9;

    // Looser tolerance for expected values that involve Math.Log.
    private const double LogTol = 1e-6;

    // Linear model used by the first four checks.
    private const string LinearModel = "2.0*x+y";

    /// <summary>
    /// Runs all checks in sequence. The checks do not depend on each other; they are split into
    /// private helpers for readability. A failing assertion throws, so a failure in an earlier
    /// check prevents later checks from running.
    /// </summary>
    [TestMethod]
    public void AutoDiffTest() {
      IntervalEvaluationOfLinearModel();
      VectorEvaluationOfLinearModel();
      VectorAutoDiffOfLinearModel();
      IntervalAutoDiffOfLinearModel();
      IntervalEvaluationOfFlowPsiSubExpressions();
      SymbolicDerivativesOfFlowPsiModel();
    }

    // Interval evaluation: 2.0*[-1, 1] + [2, 10] = [-2, 2] + [2, 10] = [0, 12].
    private static void IntervalEvaluationOfLinearModel() {
      var tree = new InfixExpressionParser().Parse(LinearModel);
      var result = new IntervalEvaluator().Evaluate(tree, CreateLinearModelIntervals());
      AssertInterval(LinearModel, 0, 12, result, Tol);
    }

    // Vector evaluation of 2.0*x+y on the training and test rows of the linear-model dataset.
    private static void VectorEvaluationOfLinearModel() {
      var tree = new InfixExpressionParser().Parse(LinearModel);
      var ds = CreateLinearModelDataset();
      var problemData = new RegressionProblemData(ds, CreateLinearModelVariableNames(), "f(x)");
      var evaluator = new VectorEvaluator();

      // The expected values assume the default partition puts rows 0-1 into the training set
      // and rows 2-4 into the test set.
      var train = evaluator.Evaluate(tree, ds, problemData.TrainingIndices.ToArray());
      AssertValues(new double[] { 3, 5 }, train, "train");

      var test = evaluator.Evaluate(tree, ds, problemData.TestIndices.ToArray());
      AssertValues(new double[] { 5, 7, 9 }, test, "test");
    }

    // Vector evaluation plus Jacobian w.r.t. the constant 2.0 (column 0) and the variable node
    // of y (column 1). d/d(2.0) of 2.0*x+y is x; the expected column 1 equals the value of y,
    // i.e. presumably the derivative w.r.t. the node's weight. Each Jacobian row is thus (x, y).
    private static void VectorAutoDiffOfLinearModel() {
      var tree = new InfixExpressionParser().Parse(LinearModel);
      var paramNodes = GetLinearModelParameters(tree);
      var ds = CreateLinearModelDataset();
      var problemData = new RegressionProblemData(ds, CreateLinearModelVariableNames(), "f(x)");
      var evaluator = new VectorAutoDiffEvaluator();

      var trainRows = problemData.TrainingIndices.ToArray();
      var train = new double[trainRows.Length];
      var trainJac = new double[trainRows.Length, paramNodes.Length];
      evaluator.Evaluate(tree, ds, trainRows, paramNodes, train, trainJac);
      AssertValues(new double[] { 3, 5 }, train, "train");
      AssertMatrix(new double[,] { { 1, 1 }, { 2, 1 } }, trainJac, "trainJac");

      var testRows = problemData.TestIndices.ToArray();
      var test = new double[testRows.Length];
      var testJac = new double[testRows.Length, paramNodes.Length];
      evaluator.Evaluate(tree, ds, testRows, paramNodes, test, testJac);
      AssertValues(new double[] { 5, 7, 9 }, test, "test");
      AssertMatrix(new double[,] { { 3, -1 }, { 4, -1 }, { 5, -1 } }, testJac, "testJac");
    }

    // Interval evaluation plus gradients of both bounds w.r.t. the same two parameters.
    // The lower bound 0 = 2.0*(-1) + 2 is reached at (x, y) = (-1, 2); the upper bound
    // 12 = 2.0*1 + 10 at (x, y) = (1, 10). The gradients are those points.
    private static void IntervalAutoDiffOfLinearModel() {
      var tree = new InfixExpressionParser().Parse(LinearModel);
      var paramNodes = GetLinearModelParameters(tree);
      var result = new IntervalEvaluator().Evaluate(tree, CreateLinearModelIntervals(), paramNodes,
        out double[] lowerGradient, out double[] upperGradient);

      AssertInterval(LinearModel, 0, 12, result, Tol);
      AssertValues(new double[] { -1, 2 }, lowerGradient, "lowerGradient");
      AssertValues(new double[] { 1, 10 }, upperGradient, "upperGradient");
    }

    // Interval evaluation of the flow psi model and its sub-expressions (as discussed with Fabrício).
    private static void IntervalEvaluationOfFlowPsiSubExpressions() {
      var parser = new InfixExpressionParser();
      var evaluator = new IntervalEvaluator();
      var intervals = new Dictionary<string, Interval> {
        { "x1", new Interval(60.0, 65.0) },
        { "x2", new Interval(30.0, 40.0) },
        { "x3", new Interval(5.0, 10.0) },
        { "x4", new Interval(0.5, 0.8) },
        { "x5", new Interval(0.2, 0.5) },
      };

      // Parses and evaluates an expression over the intervals above and checks both bounds.
      void Check(string expression, double lower, double upper, double delta) {
        AssertInterval(expression, lower, upper, evaluator.Evaluate(parser.Parse(expression), intervals), delta);
      }

      // log(min(x5) / max(x4)) = log(0.2 / 0.8) = log(0.25)
      var logMin = Math.Log(0.25);

      Check("x5/x4", 0.25, 1, Tol);
      Check("log(x5/x4)", logMin, 0, LogTol);
      Check("x3 * log(x5/x4)", 10 * logMin, 0, LogTol);
      Check("x1*x2*x5", 360, 1300, Tol);
      Check("x4/x5", 1, 4, Tol);
      Check("sqr(x4/x5)", 1, 16, Tol);
      Check("(1 - sqr(x4/x5)) ", -15, 0, Tol);
      Check("x1*x2*x5 *(1 - sqr(x4/x5))", -19500, 0, Tol);
      Check("x1*x2*x5 *(1 - sqr(x4/x5)) + x3 * log(x5/x4)", -19500 + 10 * logMin, 0, LogTol);
    }

    // Symbolic derivatives of the flow psi model w.r.t. each input variable, checked against the
    // formatter's output. TODO: interval bounds of the derivatives are not checked yet.
    private static void SymbolicDerivativesOfFlowPsiModel() {
      var parser = new InfixExpressionParser();
      var formatter = new InfixExpressionFormatter();
      var expr = parser.Parse("x1*x2*x5*(1 - sqr(x4/x5)) + x3 * log(x5/x4)");

      // df/dx1 = x2*x5*(1 - sqr(x4/x5))
      Assert.AreEqual("('x2' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))",
        formatter.Format(DerivativeCalculator.Derive(expr, "x1")));

      // df/dx2 = x1*x5*(1 - sqr(x4/x5))
      Assert.AreEqual("('x1' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))",
        formatter.Format(DerivativeCalculator.Derive(expr, "x2")));

      // df/dx3 = log(x5/x4)
      Assert.AreEqual("LOG(('x5' / 'x4'))",
        formatter.Format(DerivativeCalculator.Derive(expr, "x3")));

      // df/dx4 = x1*x2*x5 * (-2*(x4/x5)*(1/x5)) + x3 * 1/(x5/x4) * (-x5/sqr(x4))
      //        = -2*x1*x2*x4/x5 - x3/x4
      Assert.AreEqual("((('x1' * 'x2' * 'x5' * 'x4' * 2) / ('x5' * (-1*'x5'))) + (('x4' * 'x5' * 'x3') / ('x5' * SQR('x4') * (-1))))",
        formatter.Format(DerivativeCalculator.Derive(expr, "x4")));

      // df/dx5 = x3 * 1/(x5/x4) * (1/x4) + x1*x2*(1 - sqr(x4/x5)) + x1*x2*x5 * (-2*(x4/x5)*(-x4/sqr(x5)))
      //        = x3/x5 + x1*x2*(1 - sqr(x4/x5)) + 2*x1*x2*sqr(x4)/sqr(x5)
      Assert.AreEqual("((('x4' * 'x3') / ('x5' * 'x4')) + ('x1' * 'x2' * ((SQR(('x4' / 'x5')) * (-1)) + 1)) + (('x1' * 'x2' * 'x5' * ('x4' * 'x4') * 2) / ('x5' * SQR('x5') * 1)))",
        formatter.Format(DerivativeCalculator.Derive(expr, "x5")));
    }

    // Input intervals for the linear model: x in [-1, 1], y in [2, 10].
    private static Dictionary<string, Interval> CreateLinearModelIntervals() {
      return new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) },
      };
    }

    // Variable names of the linear-model dataset; "f(x)" is passed as the target (an all-zero column).
    private static string[] CreateLinearModelVariableNames() {
      return new string[] { "x", "y", "f(x)" };
    }

    // Five rows of (x, y, f(x)) for the linear model. A fresh instance is created on every call.
    private static Dataset CreateLinearModelDataset() {
      var values = new double[,] {
        { 1,  1, 0 },
        { 2,  1, 0 },
        { 3, -1, 0 },
        { 4, -1, 0 },
        { 5, -1, 0 },
      };
      return new Dataset(CreateLinearModelVariableNames(), values);
    }

    // Parameter nodes for auto-diff, in Jacobian column order: the first constant node
    // (the 2.0 in "2.0*x+y") and the variable node of y.
    private static ISymbolicExpressionTreeNode[] GetLinearModelParameters(ISymbolicExpressionTree tree) {
      var constantNode = tree.IterateNodesPostfix().First(n => n is ConstantTreeNode);
      var yNode = tree.IterateNodesPostfix().First(n => n is VariableTreeNode varNode && varNode.VariableName == "y");
      return new ISymbolicExpressionTreeNode[] { constantNode, yNode };
    }

    // Asserts both bounds of an interval; label identifies the failing expression.
    private static void AssertInterval(string label, double expectedLower, double expectedUpper, Interval actual, double delta) {
      Assert.AreEqual(expectedLower, actual.LowerBound, delta, label + ": lower bound");
      Assert.AreEqual(expectedUpper, actual.UpperBound, delta, label + ": upper bound");
    }

    // Asserts length and element-wise values (within Tol) of a vector.
    private static void AssertValues(double[] expected, double[] actual, string label) {
      Assert.AreEqual(expected.Length, actual.Length, label + ": length");
      for (int i = 0; i < expected.Length; i++) {
        Assert.AreEqual(expected[i], actual[i], Tol, $"{label}[{i}]");
      }
    }

    // Asserts dimensions and element-wise values (within Tol) of a matrix.
    private static void AssertMatrix(double[,] expected, double[,] actual, string label) {
      Assert.AreEqual(expected.GetLength(0), actual.GetLength(0), label + ": rows");
      Assert.AreEqual(expected.GetLength(1), actual.GetLength(1), label + ": columns");
      for (int r = 0; r < expected.GetLength(0); r++) {
        for (int c = 0; c < expected.GetLength(1); c++) {
          Assert.AreEqual(expected[r, c], actual[r, c], Tol, $"{label}[{r}, {c}]");
        }
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.