Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/Tests/AutoDiffTest.cs @ 16674

Last change on this file since 16674 was 16674, checked in by gkronber, 5 years ago

#2994: worked on AutoDiff implementation based on BatchInterpreter

File size: 3.6 KB
Line 
1using System;
2using System.Collections.Generic;
3using HeuristicLab.Problems.DataAnalysis;
4using HeuristicLab.Problems.DataAnalysis.Symbolic;
5using Microsoft.VisualStudio.TestTools.UnitTesting;
6using System.Linq;
7using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
8
namespace Tests {
  /// <summary>
  /// Checks the interval evaluator, the vectorized evaluator and the vectorized
  /// auto-diff evaluator on the linear model <c>2.0*x+y</c>.
  /// </summary>
  [TestClass]
  public class AutoDiffTest {
    // Model evaluated by every scenario in this class.
    private const string Expression = "2.0*x+y";
    // Name of the target column in the shared dataset; it is not a model input.
    private const string TargetVariable = "f(x)";

    /// <summary>
    /// Runs all evaluator scenarios. Kept as the single test entry point so the
    /// test name stays unchanged; each scenario lives in its own helper.
    /// </summary>
    [TestMethod]
    public void TestMethod1() {
      CheckIntervalEvaluator();
      CheckVectorEvaluator();
      CheckVectorAutoDiffEvaluator();
    }

    /// <summary>
    /// Interval evaluation: with x in [-1, 1] and y in [2, 10],
    /// 2x lies in [-2, 2] and therefore 2x + y lies in [0, 12].
    /// </summary>
    private static void CheckIntervalEvaluator() {
      var tree = new InfixExpressionParser().Parse(Expression);
      var intervals = new Dictionary<string, AlgebraicInterval> {
        { "x", new AlgebraicInterval(-1.0, 1.0) },
        { "y", new AlgebraicInterval(2.0, 10.0) }
      };

      var resultInterval = new IntervalEvaluator().Evaluate(tree, intervals);

      Assert.AreEqual(0, resultInterval.LowerBound);
      Assert.AreEqual(12, resultInterval.UpperBound);
    }

    /// <summary>
    /// Row-wise evaluation of the model on the training and test rows of the shared dataset.
    /// </summary>
    private static void CheckVectorEvaluator() {
      var tree = new InfixExpressionParser().Parse(Expression);
      var ds = CreateDataset();
      var problemData = CreateProblemData(ds);
      var evaluator = new VectorEvaluator();

      var train = evaluator.Evaluate(tree, ds, problemData.TrainingIndices.ToArray());
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]); // 2*1 + 1
      Assert.AreEqual(5, train[1]); // 2*2 + 1

      var test = evaluator.Evaluate(tree, ds, problemData.TestIndices.ToArray());
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]); // 2*3 - 1
      Assert.AreEqual(7, test[1]); // 2*4 - 1
      Assert.AreEqual(9, test[2]); // 2*5 - 1
    }

    /// <summary>
    /// Evaluation plus Jacobian with respect to two parameter nodes: the constant 2.0
    /// and the variable node for y. The expected values below encode that, per row,
    /// Jacobian column 0 equals x and column 1 equals y (presumably the derivative
    /// w.r.t. the y node's weight — confirm against VectorAutoDiffEvaluator).
    /// </summary>
    private static void CheckVectorAutoDiffEvaluator() {
      var tree = new InfixExpressionParser().Parse(Expression);
      var constantNode = tree.IterateNodesPostfix().First(n => n is ConstantTreeNode);
      var yNode = tree.IterateNodesPostfix()
        .First(n => n is VariableTreeNode variableNode && variableNode.VariableName == "y");
      var paramNodes = new ISymbolicExpressionTreeNode[] { constantNode, yNode };

      var ds = CreateDataset();
      var problemData = CreateProblemData(ds);
      var evaluator = new VectorAutoDiffEvaluator();

      evaluator.Evaluate(tree, ds, problemData.TrainingIndices.ToArray(), paramNodes, out double[] train, out double[,] trainJac);
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3, train[0]);
      Assert.AreEqual(5, train[1]);
      // Jacobian rows: (x, y) of training rows 0 and 1.
      Assert.AreEqual(1, trainJac[0, 0]);
      Assert.AreEqual(1, trainJac[0, 1]);
      Assert.AreEqual(2, trainJac[1, 0]);
      Assert.AreEqual(1, trainJac[1, 1]);

      evaluator.Evaluate(tree, ds, problemData.TestIndices.ToArray(), paramNodes, out double[] test, out double[,] testJac);
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5, test[0]);
      Assert.AreEqual(7, test[1]);
      Assert.AreEqual(9, test[2]);
      // Jacobian rows: (x, y) of test rows 2, 3 and 4.
      Assert.AreEqual(3, testJac[0, 0]);
      Assert.AreEqual(-1, testJac[0, 1]);
      Assert.AreEqual(4, testJac[1, 0]);
      Assert.AreEqual(-1, testJac[1, 1]);
      Assert.AreEqual(5, testJac[2, 0]);
      Assert.AreEqual(-1, testJac[2, 1]);
    }

    /// <summary>
    /// Creates the 5-row dataset shared by the vectorized scenarios. Columns are x, y and
    /// the target; the target column is all zeros and is not referenced by the model.
    /// </summary>
    private static Dataset CreateDataset() {
      var variableNames = new[] { "x", "y", TargetVariable };
      var values = new double[,] {
        { 1,  1, 0 },
        { 2,  1, 0 },
        { 3, -1, 0 },
        { 4, -1, 0 },
        { 5, -1, 0 },
      };
      return new Dataset(variableNames, values);
    }

    /// <summary>
    /// Wraps <paramref name="ds"/> in problem data with x and y as inputs. The target
    /// column is deliberately not listed as an allowed input variable.
    /// NOTE(review): the asserted lengths in the scenarios assume the default partition
    /// puts rows 0-1 in the training set and rows 2-4 in the test set.
    /// </summary>
    private static RegressionProblemData CreateProblemData(Dataset ds) =>
      new RegressionProblemData(ds, new[] { "x", "y" }, TargetVariable);
  }
}
Note: See TracBrowser for help on using the repository browser.