source: branches/2994-AutoDiffForIntervals/Tests/AutoDiffTest.cs @ 16682

Last change on this file since 16682 was 16682, checked in by gkronber, 3 months ago

#2994: worked on auto diff for intervals and vectors

File size: 4.8 KB
Line 
1using System;
2using System.Collections.Generic;
3using HeuristicLab.Problems.DataAnalysis;
4using HeuristicLab.Problems.DataAnalysis.Symbolic;
5using Microsoft.VisualStudio.TestTools.UnitTesting;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
9namespace Tests {
10  [TestClass]
11  public class AutoDiffTest {
12    [TestMethod]
13    public void Test() {
14      {
15        // eval
16        var parser = new InfixExpressionParser();
17        var t = parser.Parse("2.0*x+y");
18
19        // interval eval
20        var evaluator = new IntervalEvaluator();
21        var intervals = new Dictionary<string, Interval>();
22        intervals.Add("x", new Interval(-1.0, 1.0));
23        intervals.Add("y", new Interval(2.0, 10.0));
24        var resultInterval = evaluator.Evaluate(t, intervals);
25        Assert.AreEqual(0, resultInterval.LowerBound);
26        Assert.AreEqual(12, resultInterval.UpperBound);
27      }
28
29      {
30        // vector eval
31        var parser = new InfixExpressionParser();
32        var t = parser.Parse("2.0*x+y");
33
34        var evaluator = new VectorEvaluator();
35        var vars = new string[] { "x", "y", "f(x)" };
36        var values = new double[,] {
37          { 1,  1, 0 },
38          { 2,  1, 0 },
39          { 3, -1, 0 },
40          { 4, -1, 0 },
41          { 5, -1, 0 },
42        };
43
44        var ds = new Dataset(vars, values);
45        var problemData = new RegressionProblemData(ds, vars, "f(x)");
46        var train = evaluator.Evaluate(t, ds, problemData.TrainingIndices.ToArray());
47        Assert.AreEqual(2, train.Length);
48        Assert.AreEqual(3, train[0]);
49        Assert.AreEqual(5, train[1]);
50
51        var test = evaluator.Evaluate(t, ds, problemData.TestIndices.ToArray());
52        Assert.AreEqual(3, test.Length);
53        Assert.AreEqual(5, test[0]);
54        Assert.AreEqual(7, test[1]);
55        Assert.AreEqual(9, test[2]);
56      }
57
58      {
59        // vector eval and auto-diff
60        var parser = new InfixExpressionParser();
61        var t = parser.Parse("2.0*x+y");
62        var p0 = t.IterateNodesPostfix().First(n => n is ConstantTreeNode);
63        var p1 = t.IterateNodesPostfix().First(n => (n is VariableTreeNode var) && var.VariableName == "y");
64        var paramNodes = new ISymbolicExpressionTreeNode[] { p0, p1 };
65
66        var evaluator = new VectorAutoDiffEvaluator();
67        var vars = new string[] { "x", "y", "f(x)" };
68        var values = new double[,] {
69          { 1,  1, 0 },
70          { 2,  1, 0 },
71          { 3, -1, 0 },
72          { 4, -1, 0 },
73          { 5, -1, 0 },
74        };
75
76        var ds = new Dataset(vars, values);
77        var problemData = new RegressionProblemData(ds, vars, "f(x)");
78        evaluator.Evaluate(t, ds, problemData.TrainingIndices.ToArray(), paramNodes, out double[] train, out double[,] trainJac);
79        Assert.AreEqual(2, train.Length);
80        Assert.AreEqual(3, train[0]);
81        Assert.AreEqual(5, train[1]);
82        // check jac
83        Assert.AreEqual(1, trainJac[0, 0]);
84        Assert.AreEqual(1, trainJac[0, 1]);
85        Assert.AreEqual(2, trainJac[1, 0]);
86        Assert.AreEqual(1, trainJac[1, 1]);
87
88        evaluator.Evaluate(t, ds, problemData.TestIndices.ToArray(), paramNodes, out double[] test, out double[,] testJac);
89        Assert.AreEqual(3, test.Length);
90        Assert.AreEqual(5, test[0]);
91        Assert.AreEqual(7, test[1]);
92        Assert.AreEqual(9, test[2]);
93
94        // check jac
95        Assert.AreEqual(3, testJac[0, 0]);
96        Assert.AreEqual(-1, testJac[0, 1]);
97        Assert.AreEqual(4, testJac[1, 0]);
98        Assert.AreEqual(-1, testJac[1, 1]);
99        Assert.AreEqual(5, testJac[2, 0]);
100        Assert.AreEqual(-1, testJac[2, 1]);
101
102      }
103
104      {
105        // interval eval and auto-diff
106        var parser = new InfixExpressionParser();
107        var t = parser.Parse("2.0*x+y");
108        var p0 = t.IterateNodesPostfix().First(n => n is ConstantTreeNode);
109        var p1 = t.IterateNodesPostfix().First(n => (n is VariableTreeNode var) && var.VariableName == "y");
110        var paramNodes = new ISymbolicExpressionTreeNode[] { p0, p1 };
111
112        var evaluator = new IntervalEvaluator();
113        var intervals = new Dictionary<string, Interval>();
114        intervals.Add("x", new Interval(-1.0, 1.0));
115        intervals.Add("y", new Interval(2.0, 10.0));
116        var resultInterval = evaluator.Evaluate(t, intervals, paramNodes, out double[] lowerGradient, out double[] upperGradient);
117        Assert.AreEqual(0, resultInterval.LowerBound);
118        Assert.AreEqual(12, resultInterval.UpperBound);
119
120        Assert.AreEqual(-1, lowerGradient[0]);
121        Assert.AreEqual(2, lowerGradient[1]);
122        Assert.AreEqual(1, upperGradient[0]);
123        Assert.AreEqual(10, upperGradient[1]);
124      }
125
126    }
127  }
128}
Note: See TracBrowser for help on using the repository browser.