[16674] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using HeuristicLab.Problems.DataAnalysis;
|
---|
| 4 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 5 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 6 | using System.Linq;
|
---|
| 7 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 8 |
|
---|
namespace Tests {
  /// <summary>
  /// Tests for interval evaluation, vector evaluation, automatic differentiation and
  /// symbolic differentiation of expression trees produced by <see cref="InfixExpressionParser"/>.
  /// </summary>
  [TestClass]
  public class AutoDiffTestClass {
    // Variable names of the small dataset shared by the vector-evaluation scenarios.
    // Never mutated; only passed to Dataset / RegressionProblemData constructors.
    private static readonly string[] LinearVariableNames = { "x", "y", "f(x)" };

    /// <summary>
    /// Runs every evaluation and differentiation scenario in order. Each scenario lives in
    /// its own private method so its fixture and expected values can be read in isolation.
    /// </summary>
    [TestMethod]
    public void AutoDiffTest() {
      IntervalEvaluationOfLinearExpression();
      VectorEvaluationOfLinearExpression();
      VectorEvaluationWithAutoDiff();
      IntervalEvaluationOfPowerFunctions();
      IntervalEvaluationWithAutoDiff();
      IntervalEvaluationOfFlowPsiSubexpressions();
      SymbolicDerivativesOfFlowPsiExpression();
    }

    // Interval evaluation of 2.0*x+y with x in [-1, 1], y in [2, 10]:
    // [2*(-1) + 2, 2*1 + 10] = [0, 12].
    private static void IntervalEvaluationOfLinearExpression() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");

      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };
      var evaluator = new IntervalEvaluator();
      var resultInterval = evaluator.Evaluate(t, intervals);

      Assert.AreEqual(0.0, resultInterval.LowerBound);
      Assert.AreEqual(12.0, resultInterval.UpperBound);
    }

    // Point-wise vector evaluation of 2.0*x+y on the training and test partitions.
    private static void VectorEvaluationOfLinearExpression() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");

      var ds = CreateLinearDataset();
      var problemData = new RegressionProblemData(ds, LinearVariableNames, "f(x)");
      var evaluator = new VectorEvaluator();

      var train = evaluator.Evaluate(t, ds, problemData.TrainingIndices.ToArray());
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3.0, train[0]); // 2*1 + 1
      Assert.AreEqual(5.0, train[1]); // 2*2 + 1

      var test = evaluator.Evaluate(t, ds, problemData.TestIndices.ToArray());
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5.0, test[0]); // 2*3 - 1
      Assert.AreEqual(7.0, test[1]); // 2*4 - 1
      Assert.AreEqual(9.0, test[2]); // 2*5 - 1
    }

    // Vector evaluation of 2.0*x+y together with its Jacobian with respect to the
    // constant 2.0 (column 0, expected to equal x) and the y node (column 1, expected to equal y).
    private static void VectorEvaluationWithAutoDiff() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");
      var paramNodes = SelectParameterNodes(t.IterateNodesPostfix());

      var ds = CreateLinearDataset();
      var problemData = new RegressionProblemData(ds, LinearVariableNames, "f(x)");
      var evaluator = new VectorAutoDiffEvaluator();

      // Materialize each partition once; its length sizes the output buffers.
      var trainRows = problemData.TrainingIndices.ToArray();
      var train = new double[trainRows.Length];
      var trainJac = new double[trainRows.Length, paramNodes.Length];
      evaluator.Evaluate(t, ds, trainRows, paramNodes, train, trainJac);
      Assert.AreEqual(2, train.Length);
      Assert.AreEqual(3.0, train[0]);
      Assert.AreEqual(5.0, train[1]);
      // check jac
      Assert.AreEqual(1.0, trainJac[0, 0]);
      Assert.AreEqual(1.0, trainJac[0, 1]);
      Assert.AreEqual(2.0, trainJac[1, 0]);
      Assert.AreEqual(1.0, trainJac[1, 1]);

      var testRows = problemData.TestIndices.ToArray();
      var test = new double[testRows.Length];
      var testJac = new double[testRows.Length, paramNodes.Length];
      evaluator.Evaluate(t, ds, testRows, paramNodes, test, testJac);
      Assert.AreEqual(3, test.Length);
      Assert.AreEqual(5.0, test[0]);
      Assert.AreEqual(7.0, test[1]);
      Assert.AreEqual(9.0, test[2]);
      // check jac
      Assert.AreEqual(3.0, testJac[0, 0]);
      Assert.AreEqual(-1.0, testJac[0, 1]);
      Assert.AreEqual(4.0, testJac[1, 0]);
      Assert.AreEqual(-1.0, testJac[1, 1]);
      Assert.AreEqual(5.0, testJac[2, 0]);
      Assert.AreEqual(-1.0, testJac[2, 1]);
    }

    // Interval bounds of scaling, sqr and cube over a strictly positive (p), a strictly
    // negative (n) and a zero-straddling (x) input interval.
    private static void IntervalEvaluationOfPowerFunctions() {
      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-2.0, 3.0) },
        { "p", new Interval(1.0, 2.0) },
        { "n", new Interval(-2.0, -1.0) }
      };

      AssertInterval("10*x", intervals, -20, 30);
      AssertInterval("sqr(p)", intervals, 1, 4);
      AssertInterval("sqr(n)", intervals, 1, 4);
      AssertInterval("sqr(x)", intervals, 0, 9); // minimum at 0 lies inside [-2, 3]

      AssertInterval("cube(p)", intervals, 1, 8);
      AssertInterval("cube(n)", intervals, -8, -1);
      AssertInterval("cube(x)", intervals, -8, 27);
    }

    // Interval evaluation of 2.0*x+y plus the gradients of the lower and upper bound with
    // respect to the constant 2.0 and the y node.
    private static void IntervalEvaluationWithAutoDiff() {
      var parser = new InfixExpressionParser();
      var t = parser.Parse("2.0*x+y");
      var paramNodes = SelectParameterNodes(t.IterateNodesPostfix());

      var intervals = new Dictionary<string, Interval> {
        { "x", new Interval(-1.0, 1.0) },
        { "y", new Interval(2.0, 10.0) }
      };
      var evaluator = new IntervalEvaluator();
      var resultInterval = evaluator.Evaluate(t, intervals, paramNodes, out double[] lowerGradient, out double[] upperGradient);
      Assert.AreEqual(0.0, resultInterval.LowerBound);
      Assert.AreEqual(12.0, resultInterval.UpperBound);

      // The lower bound is attained at x = -1, y = 2; the upper bound at x = 1, y = 10.
      Assert.AreEqual(-1.0, lowerGradient[0]);
      Assert.AreEqual(2.0, lowerGradient[1]);
      Assert.AreEqual(1.0, upperGradient[0]);
      Assert.AreEqual(10.0, upperGradient[1]);
    }

    // Interval bounds of every subexpression of the flow psi model
    // x1*x2*x5*(1 - sqr(x4/x5)) + x3*log(x5/x4) (as discussed with Fabrício).
    private static void IntervalEvaluationOfFlowPsiSubexpressions() {
      var intervals = CreateFlowPsiIntervals();
      var parser = new InfixExpressionParser();
      var evaluator = new IntervalEvaluator();

      // x5/x4: [0.2/0.8, 0.5/0.5]
      var result = evaluator.Evaluate(parser.Parse("x5/x4"), intervals);
      Assert.AreEqual(0.25, result.LowerBound);
      Assert.AreEqual(1.0, result.UpperBound);

      // log(x5/x4): [log(0.25), log(1)]
      result = evaluator.Evaluate(parser.Parse("log(x5/x4)"), intervals);
      Assert.AreEqual(-1.386294361, result.LowerBound, 1e-6);
      Assert.AreEqual(0.0, result.UpperBound);

      // x3 * log(x5/x4): [10*log(0.25), 0]
      result = evaluator.Evaluate(parser.Parse("x3 * log(x5/x4)"), intervals);
      Assert.AreEqual(-13.86294361, result.LowerBound, 1e-6);
      Assert.AreEqual(0.0, result.UpperBound);

      // x1*x2*x5: [60*30*0.2, 65*40*0.5]
      result = evaluator.Evaluate(parser.Parse("x1*x2*x5"), intervals);
      Assert.AreEqual(360.0, result.LowerBound);
      Assert.AreEqual(1300.0, result.UpperBound);

      // x4/x5: [0.5/0.5, 0.8/0.2]
      result = evaluator.Evaluate(parser.Parse("x4/x5"), intervals);
      Assert.AreEqual(1.0, result.LowerBound, 1e-6);
      Assert.AreEqual(4.0, result.UpperBound);

      // sqr(x4/x5): [1, 16]
      result = evaluator.Evaluate(parser.Parse("sqr(x4/x5)"), intervals);
      Assert.AreEqual(1.0, result.LowerBound);
      Assert.AreEqual(16.0, result.UpperBound);

      // 1 - sqr(x4/x5): [1 - 16, 1 - 1]
      result = evaluator.Evaluate(parser.Parse("(1 - sqr(x4/x5)) "), intervals);
      Assert.AreEqual(-15.0, result.LowerBound);
      Assert.AreEqual(0.0, result.UpperBound);

      // [360, 1300] * [-15, 0]
      result = evaluator.Evaluate(parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5))"), intervals);
      Assert.AreEqual(-19500.0, result.LowerBound);
      Assert.AreEqual(0.0, result.UpperBound);

      // [-19500, 0] + [10*log(0.25), 0]
      result = evaluator.Evaluate(parser.Parse("x1*x2*x5 *(1 - sqr(x4/x5)) + x3 * log(x5/x4)"), intervals);
      Assert.AreEqual(-19513.86294, result.LowerBound, 1e-3);
      Assert.AreEqual(0.0, result.UpperBound);
    }

    // Symbolic partial derivatives of the flow psi model, checked via their infix formatting.
    private static void SymbolicDerivativesOfFlowPsiExpression() {
      var parser = new InfixExpressionParser();
      var formatter = new InfixExpressionFormatter();

      var expr = parser.Parse("x1*x2*x5*(1 - sqr(x4/x5)) + x3 * log(x5/x4)");

      // x2*x5*(1 - sqr(x4/x5))
      var dfdx1 = DerivativeCalculator.Derive(expr, "x1");
      Assert.AreEqual("('x2' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx1));

      // x1*x5*(1 - sqr(x4/x5))
      var dfdx2 = DerivativeCalculator.Derive(expr, "x2");
      Assert.AreEqual("('x1' * 'x5' * ((SQR(('x4' / 'x5')) * (-1)) + 1))", formatter.Format(dfdx2));

      // log(x5/x4)
      var dfdx3 = DerivativeCalculator.Derive(expr, "x3");
      Assert.AreEqual("LOG(('x5' / 'x4'))", formatter.Format(dfdx3));

      // x1*x2*x5 * (-2*(x4/x5)/x5) + x3 * 1/(x5/x4) * (-x5/sqr(x4))
      //   = -2*x1*x2*x4/x5 - x3/x4
      var dfdx4 = DerivativeCalculator.Derive(expr, "x4");
      Assert.AreEqual("((('x1' * 'x2' * 'x5' * 'x4' * 2) / ('x5' * (-1*'x5'))) + (('x4' * 'x5' * 'x3') / ('x5' * SQR('x4') * (-1))))", formatter.Format(dfdx4));

      // x3 * 1/(x5/x4) * 1/x4 + x1*x2*(1 - sqr(x4/x5)) + x1*x2*x5 * 2*(x4/x5)*x4/sqr(x5)
      //   = x3/x5 + x1*x2*(1 - sqr(x4/x5)) + 2*x1*x2*sqr(x4)/sqr(x5)
      var dfdx5 = DerivativeCalculator.Derive(expr, "x5");
      Assert.AreEqual("((('x4' * 'x3') / ('x5' * 'x4')) + ('x1' * 'x2' * ((SQR(('x4' / 'x5')) * (-1)) + 1)) + (('x1' * 'x2' * 'x5' * ('x4' * 'x4') * 2) / ('x5' * SQR('x5') * 1)))", formatter.Format(dfdx5));
    }

    /// <summary>
    /// Builds the five-row dataset over <see cref="LinearVariableNames"/> (x, y, f(x)) used by
    /// the vector-evaluation scenarios. The target column f(x) is all zeros.
    /// </summary>
    private static Dataset CreateLinearDataset() {
      var values = new double[,] {
        { 1, 1, 0 },
        { 2, 1, 0 },
        { 3, -1, 0 },
        { 4, -1, 0 },
        { 5, -1, 0 },
      };
      return new Dataset(LinearVariableNames, values);
    }

    /// <summary>
    /// Input intervals of the flow psi problem for variables x1..x5. A new dictionary is
    /// returned on every call so scenarios cannot affect each other.
    /// </summary>
    private static Dictionary<string, Interval> CreateFlowPsiIntervals() {
      return new Dictionary<string, Interval> {
        { "x1", new Interval(60.0, 65.0) },
        { "x2", new Interval(30.0, 40.0) },
        { "x3", new Interval(5.0, 10.0) },
        { "x4", new Interval(0.5, 0.8) },
        { "x5", new Interval(0.2, 0.5) }
      };
    }

    /// <summary>
    /// Picks the nodes of 2.0*x+y that the auto-diff evaluators differentiate with respect to:
    /// the first constant node (2.0) followed by the variable node for y.
    /// </summary>
    /// <param name="nodes">The tree's nodes, e.g. from <c>IterateNodesPostfix()</c>.</param>
    /// <returns>A two-element array: { constant node, y node }.</returns>
    private static ISymbolicExpressionTreeNode[] SelectParameterNodes(IEnumerable<ISymbolicExpressionTreeNode> nodes) {
      var nodeList = nodes.ToList(); // enumerate the tree once for both lookups
      var constantNode = nodeList.First(n => n is ConstantTreeNode);
      var yNode = nodeList.First(n => n is VariableTreeNode variableNode && variableNode.VariableName == "y");
      return new[] { constantNode, yNode };
    }

    /// <summary>
    /// Parses <paramref name="expression"/>, evaluates it with <see cref="IntervalEvaluator"/>
    /// over <paramref name="intervals"/> and asserts exact lower and upper bounds.
    /// </summary>
    private static void AssertInterval(string expression, Dictionary<string, Interval> intervals, double expectedLow, double expectedHigh) {
      var parser = new InfixExpressionParser();
      var t = parser.Parse(expression);
      var evaluator = new IntervalEvaluator();
      var result = evaluator.Evaluate(t, intervals);
      Assert.AreEqual(expectedLow, result.LowerBound);
      Assert.AreEqual(expectedHigh, result.UpperBound);
    }
  }
}