[17318] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Text;
|
---|
| 5 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 6 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 7 | using HeuristicLab.Random;
|
---|
| 8 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 9 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Tests {
  /// <summary>
  /// Unit tests for the gradient computation of <see cref="IntervalEvaluator"/>.
  /// Each test parses an infix expression, evaluates it over the given variable intervals
  /// and checks the resulting interval bounds together with the gradients of the lower and
  /// upper bound with respect to the leaf nodes of the tree (taken in postfix order).
  /// A randomized test cross-checks these gradients against central finite differences
  /// obtained by perturbing the leaf values (variable weight or constant value).
  /// </summary>
  [TestClass]
  public class IntervalEvaluatorAutoDiffTest {
    // Relative tolerance when comparing autodiff gradients with finite differences.
    private const double RelativeTolerance = 1e-4;
    // Absolute floor for the tolerance: a purely relative bound is 0 when the autodiff
    // gradient is exactly 0 and would reject tiny round-off residue of the finite difference.
    private const double AbsoluteTolerance = 1e-8;
    // Relative step size for central finite differences.
    private const double RelativeStep = 1e-4;
    // Step used when the parameter value is 0 (a relative step would be 0 and yield 0/0 = NaN).
    private const double FallbackStep = 1e-4;

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffAdd() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x + y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(1.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffMul() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x * y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(2, r.UpperBound);

      Assert.AreEqual(0.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(2.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqr() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "unit", new Interval(0, 1) },
        { "neg", new Interval(-1, 0) },
      };
      var t = parser.Parse("sqr(x)");
      var paramNodes = GetParameterNodes(t);
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(4, r.UpperBound);

      Assert.AreEqual(2.0, lg[0]); // x
      Assert.AreEqual(8.0, ug[0]);

      // log over [0, 1] is unbounded below, so sqr(log(unit)) is [0, +inf]
      t = parser.Parse("sqr(log(unit))");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0.0, r.LowerBound);
      Assert.AreEqual(double.PositiveInfinity, r.UpperBound);

      Assert.AreEqual(0.0, lg[0]); // unit
      Assert.AreEqual(double.NaN, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffExp() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("exp(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Exp(1), r.LowerBound);
      Assert.AreEqual(Math.Exp(2), r.UpperBound);

      Assert.AreEqual(Math.Exp(1), lg[0]); // x
      Assert.AreEqual(Math.Exp(2) * 2, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSin() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sin(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2)
      Assert.AreEqual(1, r.UpperBound); // 1..2 crosses pi / 2 and sin(pi/2)==1

      Assert.AreEqual(Math.Cos(1), lg[0]); // x
      Assert.AreEqual(0, ug[0]); // upper bound is the constant maximum 1
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCos() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(3, 4) },
        { "z", new Interval(1, 2) }
      };
      var t = parser.Parse("cos(x)");
      var paramNodes = GetParameterNodes(t);
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(-1, r.LowerBound); // 3..4 crosses pi and cos(pi) == -1
      Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4)

      Assert.AreEqual(0, lg[0]); // x (lower bound is the constant minimum -1)
      Assert.AreEqual(-4 * Math.Sin(4), ug[0]);

      // cos over [1, 2] contains negative values (cos(2) < 0), so the log's lower bound is NaN
      t = parser.Parse("LOG(COS('z'))");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(double.NaN, r.LowerBound);
      Assert.AreEqual(Math.Log(Math.Cos(1)), r.UpperBound);

      Assert.AreEqual(-2 * Math.Sin(2) / Math.Cos(2), lg[0], 1e-5); // z
      Assert.AreEqual(-1 * Math.Sin(1) / Math.Cos(1), ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sqrt(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(4, 9) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
        { "eps", new Interval(1e-10, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(1.5, ug[0]);

      t = parser.Parse("sqrt(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);

      Assert.AreEqual(0.5, lg[0]); // y
      Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5);

      t = parser.Parse("sqrt(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(1, r.UpperBound);

      Assert.AreEqual(double.NaN, lg[0]); // z (lower end 0: 0 / (2 * sqrt(0)))
      Assert.AreEqual(0.5, ug[0], 1e-5);

      t = parser.Parse("sqrt(eps)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);

      // eps: 0.5 * sqrt(1e-10), which tends to 0 as the lower end of the interval approaches 0
      Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-5);
      Assert.AreEqual(0.5, ug[0], 1e-5);

      t = parser.Parse("sqrt(y - z)"); // 1..2 - 0..1
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);

      // Infinite gradients are compared exactly: with a delta, |inf - inf| is NaN and the
      // check would rely on NaN comparison semantics instead of actual equality.
      Assert.AreEqual(double.PositiveInfinity, lg[0]); // y
      Assert.AreEqual(1 / Math.Sqrt(2), ug[0], 1e-5);
      Assert.AreEqual(double.NegativeInfinity, lg[1]); // z
      Assert.AreEqual(0.0, ug[1], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("cuberoot(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(8, 27) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      Assert.AreEqual(2.0 / 3.0, lg[0]); // x
      Assert.AreEqual(1.0, ug[0]);

      t = parser.Parse("cuberoot(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(Math.Pow(1, 1.0 / 3.0), r.LowerBound);
      Assert.AreEqual(Math.Pow(2, 1.0 / 3.0), r.UpperBound);

      Assert.AreEqual(1.0 / 3.0, lg[0]); // y
      Assert.AreEqual(1.0 / 3.0 * Math.Pow(2, 1.0 / 3.0), ug[0], 1e-5);

      t = parser.Parse("cuberoot(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0.0, r.LowerBound);
      Assert.AreEqual(1.0, r.UpperBound);

      Assert.AreEqual(double.NaN, lg[0]); // z (lower end 0)
      Assert.AreEqual(1.0 / 3.0, ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffLog() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("log(4*x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Log(4), r.LowerBound);
      Assert.AreEqual(Math.Log(8), r.UpperBound);

      // first leaf in postfix order; NOTE(review): presumably the constant 4 rather than x — confirm
      Assert.AreEqual(0.25, lg[0]);
      Assert.AreEqual(0.25, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvaluatorAutoDiffCompareWithNumericDifferences() {
      // create random trees and evaluate on random data
      // calc gradient for all parameters
      // use numeric differences for approximate gradient calculation
      // compare gradients

      var grammar = new TypeCoherentExpressionGrammar();
      grammar.ConfigureAsDefaultRegressionGrammar();
      // activate supported symbols
      grammar.Symbols.First(s => s is Square).Enabled = true;
      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
      grammar.Symbols.First(s => s is Cube).Enabled = true;
      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
      grammar.Symbols.First(s => s is Sine).Enabled = true;
      grammar.Symbols.First(s => s is Cosine).Enabled = true;
      grammar.Symbols.First(s => s is Exponential).Enabled = true;
      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
      grammar.Symbols.First(s => s is Absolute).Enabled = true;
      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
      grammar.Symbols.First(s => s is Constant).Enabled = false;

      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
      varSy.AllVariableNames = new string[] { "x", "y" };
      varSy.VariableNames = varSy.AllVariableNames;
      // all variable weights are 1.0 (sigma 0), so the relative finite-difference step is never 0 here
      varSy.WeightMu = 1.0;
      varSy.WeightSigma = 0.0;
      var rand = new FastRandom(1234);

      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) },
      };

      var eval = new IntervalEvaluator();

      var formatter = new InfixExpressionFormatter();
      var sb = new StringBuilder();
      const int N = 10000;
      for (int iter = 0; iter < N; iter++) {
        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
        var parameterNodes = GetParameterNodes(t);

        eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient);

        ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient);

        // compare autodiff and numeric diff
        for (int p = 0; p < parameterNodes.Length; p++) {
          if (!GradientsAgree(lowerGradient[p], refLowerGradient[p])) {
            sb.AppendLine($"lower: {lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
          if (!GradientsAgree(upperGradient[p], refUpperGradient[p])) {
            sb.AppendLine($"upper: {upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
        }
      }
      if (sb.Length > 0) {
        Console.WriteLine(sb.ToString());
        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
      }
    }

    #region helper

    /// <summary>
    /// Returns the leaf nodes of <paramref name="t"/> in postfix order; these are the
    /// parameters whose gradients the evaluators compute.
    /// </summary>
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree t) {
      return t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
    }

    /// <summary>
    /// Decides whether an autodiff gradient matches its finite-difference reference.
    /// Both NaN, exactly equal (including equal infinities), or within a relative
    /// tolerance that is floored by a small absolute tolerance.
    /// NOTE(review): an infinite autodiff value is accepted against any finite reference,
    /// because the relative bound |autoDiff| * RelativeTolerance is then infinite.
    /// </summary>
    private static bool GradientsAgree(double autoDiff, double numeric) {
      if (double.IsNaN(autoDiff) && double.IsNaN(numeric)) {
        return true;
      }
      if (autoDiff == numeric) {
        return true;
      }
      var tolerance = Math.Max(Math.Abs(autoDiff) * RelativeTolerance, AbsoluteTolerance);
      return Math.Abs(autoDiff - numeric) <= tolerance;
    }

    /// <summary>
    /// Step size for a central finite difference around <paramref name="x"/>: relative to
    /// <paramref name="x"/>, or <see cref="FallbackStep"/> when <paramref name="x"/> is 0.
    /// </summary>
    private static double FiniteDifferenceStep(double x) {
      var step = x * RelativeStep; // relative change
      return step != 0.0 ? step : FallbackStep;
    }

    /// <summary>
    /// Computes the gradient of the single-row output of <paramref name="expr"/> on
    /// <paramref name="ds"/> (row 0) with <see cref="VectorAutoDiffEvaluator"/>.
    /// NOTE(review): not referenced by any test in this class.
    /// </summary>
    private double[] CalculateGradient(string expr, IDataset ds) {
      var eval = new VectorAutoDiffEvaluator();
      var parser = new InfixExpressionParser();

      var rows = new int[1];
      var fi = new double[1];

      var t = parser.Parse(expr);
      var parameterNodes = GetParameterNodes(t);
      var jac = new double[1, parameterNodes.Length];
      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);

      var g = new double[parameterNodes.Length];
      for (int i = 0; i < g.Length; i++) {
        g[i] = jac[0, i];
      }
      return g;
    }

    /// <summary>
    /// Approximates the Jacobian (rows x parameters) of the tree output by central finite
    /// differences on each parameter node's value. The tree is restored afterwards.
    /// NOTE(review): not referenced by any test in this class.
    /// </summary>
    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
      var jac = new double[rows.Length, parameterNodes.Length];
      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var h = FiniteDifferenceStep(x);
        try {
          // output for increased parameter value
          SetValue(parameterNodes[p], x + h / 2);
          var fPlus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          // output for decreased parameter value
          SetValue(parameterNodes[p], x - h / 2);
          var fMinus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          for (int i = 0; i < rows.Length; i++) {
            jac[i, p] = (fPlus[i] - fMinus[i]) / h;
          }
        } finally {
          // restore original value even if evaluation fails
          SetValue(parameterNodes[p], x);
        }
      }
      return jac;
    }

    /// <summary>
    /// Approximates the gradients of the lower and upper interval bound with respect to each
    /// parameter node by central finite differences, using the plain (non-autodiff)
    /// <c>Evaluate(t, intervals)</c> overload. The tree is restored afterwards.
    /// </summary>
    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
      lowerGradient = new double[parameterNodes.Length];
      upperGradient = new double[parameterNodes.Length];

      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var h = FiniteDifferenceStep(x);
        try {
          // interval for increased parameter value
          SetValue(parameterNodes[p], x + h / 2);
          var r1 = eval.Evaluate(t, intervals);

          // interval for decreased parameter value
          SetValue(parameterNodes[p], x - h / 2);
          var r2 = eval.Evaluate(t, intervals);

          lowerGradient[p] = (r1.LowerBound - r2.LowerBound) / h;
          upperGradient[p] = (r1.UpperBound - r2.UpperBound) / h;
        } finally {
          // restore original value even if evaluation fails
          SetValue(parameterNodes[p], x);
        }
      }
    }

    /// <summary>
    /// Sets the numeric parameter of a leaf: the weight of a variable node or the value of a
    /// constant node.
    /// </summary>
    /// <exception cref="ArgumentException">The node is neither a variable nor a constant node.</exception>
    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
      if (node is VariableTreeNode varNode) {
        varNode.Weight = v;
      } else if (node is ConstantTreeNode constNode) {
        constNode.Value = v;
      } else {
        throw new ArgumentException("Only variable and constant nodes carry a numeric parameter.", nameof(node));
      }
    }

    /// <summary>
    /// Gets the numeric parameter of a leaf: the weight of a variable node or the value of a
    /// constant node.
    /// </summary>
    /// <exception cref="ArgumentException">The node is neither a variable nor a constant node.</exception>
    private double GetValue(ISymbolicExpressionTreeNode node) {
      if (node is VariableTreeNode varNode) {
        return varNode.Weight;
      }
      if (node is ConstantTreeNode constNode) {
        return constNode.Value;
      }
      throw new ArgumentException("Only variable and constant nodes carry a numeric parameter.", nameof(node));
    }
    #endregion
  }
}
|
---|