[17318] | 1 | using System;
|
---|
| 2 | using System.Collections.Generic;
|
---|
| 3 | using System.Linq;
|
---|
| 4 | using System.Text;
|
---|
| 5 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
| 6 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
| 7 | using HeuristicLab.Random;
|
---|
| 8 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
| 9 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Tests {
  /// <summary>
  /// Unit tests for the gradient output of <see cref="IntervalEvaluator"/>.
  /// Each test parses an expression and evaluates it over the given variable intervals.
  /// It then checks the resulting interval and the partial derivatives of its lower and upper
  /// bound with respect to every leaf node: a variable node's weight or a constant node's value
  /// (see <see cref="GetValue"/>). All variable weights are 1 in these tests.
  /// </summary>
  [TestClass]
  public class IntervalEvaluatorAutoDiffTest {
    // Tolerances for comparing autodiff gradients against central finite differences.
    // The relative tolerance alone degenerates to 0 when the autodiff gradient is exactly 0.
    // Near a point where an interval bound is not smooth, the finite-difference approximation is
    // only accurate to O(step). Example: 0 vs. ~2.5e-5 was observed for SQR(LOG(x)) with x in [1, 2]
    // and SQR(LOG(y)) with y in [0, 1]. Hence an additional absolute tolerance.
    private const double RelativeGradientTolerance = 1e-4;
    private const double AbsoluteGradientTolerance = 1e-4;

    // Step size (relative to the parameter value) used for the finite-difference approximations.
    private const double RelativeStep = 1e-4;

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffAdd() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x + y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      // lower = w_x * 1 + w_y * 0, upper = w_x * 2 + w_y * 1
      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(1.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffMul() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("x * y");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(2, r.UpperBound);

      // lower = (w_x * x) * (w_y * 0) = 0 for every x, upper = (w_x * 2) * (w_y * 1)
      Assert.AreEqual(0.0, lg[0]); // x
      Assert.AreEqual(2.0, ug[0]);
      Assert.AreEqual(0.0, lg[1]); // y
      Assert.AreEqual(2.0, ug[1]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqr() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sqr(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      // (w * x)^2 for x in [1, 2] and w = 1 is [1, 4]
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(4, r.UpperBound);

      // lower = (w * 1)^2 -> d/dw = 2 * w = 2; upper = (w * 2)^2 -> d/dw = 8 * w = 8
      Assert.AreEqual(2.0, lg[0]); // x
      Assert.AreEqual(8.0, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffExp() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("exp(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Exp(1), r.LowerBound);
      Assert.AreEqual(Math.Exp(2), r.UpperBound);

      // lower = exp(w * 1) -> d/dw = e; upper = exp(w * 2) -> d/dw = 2 * e^2
      Assert.AreEqual(Math.Exp(1), lg[0]); // x
      Assert.AreEqual(Math.Exp(2) * 2, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSin() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sin(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Sin(1), r.LowerBound); // sin(1) < sin(2)
      Assert.AreEqual(1, r.UpperBound); // 1..2 crosses pi / 2 and sin(pi/2)==1

      // lower = sin(w * 1) -> d/dw = cos(1); the maximum 1 is attained inside the interval -> 0
      Assert.AreEqual(Math.Cos(1), lg[0]); // x
      Assert.AreEqual(0, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCos() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("cos(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(3, 4) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(-1, r.LowerBound); // 3..4 crosses pi and cos(pi) == -1
      Assert.AreEqual(Math.Cos(4), r.UpperBound); // cos(3) < cos(4)

      // the minimum -1 is attained inside the interval -> 0; upper = cos(w * 4) -> d/dw = -4 * sin(4)
      Assert.AreEqual(0, lg[0]); // x
      Assert.AreEqual(-4 * Math.Sin(4), ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffSqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("sqrt(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(4, 9) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
        { "eps", new Interval(1e-10, 1) }
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      // d/dw sqrt(w * a) = 0.5 * a / sqrt(w * a): a = 4 -> 1, a = 9 -> 1.5
      Assert.AreEqual(1.0, lg[0]); // x
      Assert.AreEqual(1.5, ug[0]);

      t = parser.Parse("sqrt(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(1, r.LowerBound);
      Assert.AreEqual(Math.Sqrt(2), r.UpperBound);

      // a = 1 -> 0.5, a = 2 -> 0.5 * sqrt(2)
      Assert.AreEqual(0.5, lg[0]); // y
      Assert.AreEqual(0.5 * Math.Sqrt(2), ug[0], 1e-5);

      t = parser.Parse("sqrt(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0, r.LowerBound);
      Assert.AreEqual(1, r.UpperBound);

      // lower bound sqrt(w * 0): the derivative 0.5 * 0 / sqrt(0) is undefined and reported as NaN
      Assert.IsTrue(double.IsNaN(lg[0]), "expected NaN gradient of the lower bound of sqrt(z)"); // z
      Assert.AreEqual(0.5, ug[0], 1e-5);

      t = parser.Parse("sqrt(eps)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);

      // eps: lower bound sqrt(w * 1e-10) -> d/dw = 0.5 * sqrt(1e-10), finite and close to 0 near the singularity at 0
      Assert.AreEqual(0.5 * Math.Sqrt(1e-10), lg[0], 1e-6);
      Assert.AreEqual(0.5, ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffCqrt() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("cuberoot(x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(8, 27) },
        { "y", new Interval(1, 2) },
        { "z", new Interval(0, 1) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(2, r.LowerBound);
      Assert.AreEqual(3, r.UpperBound);

      // d/dw (w * a)^(1/3) = a / 3 * (w * a)^(-2/3): a = 8 -> 2/3, a = 27 -> 1
      Assert.AreEqual(2.0 / 3.0, lg[0]); // x
      Assert.AreEqual(1.0, ug[0]);

      t = parser.Parse("cuberoot(y)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(Math.Pow(1, 1.0 / 3.0), r.LowerBound);
      Assert.AreEqual(Math.Pow(2, 1.0 / 3.0), r.UpperBound);

      // a = 1 -> 1/3, a = 2 -> 1/3 * 2^(1/3)
      Assert.AreEqual(1.0 / 3.0, lg[0]); // y
      Assert.AreEqual(1.0 / 3.0 * Math.Pow(2, 1.0 / 3.0), ug[0], 1e-5);

      t = parser.Parse("cuberoot(z)");
      paramNodes = GetParameterNodes(t);
      r = eval.Evaluate(t, intervals, paramNodes, out lg, out ug);
      Assert.AreEqual(0.0, r.LowerBound);
      Assert.AreEqual(1.0, r.UpperBound);

      // lower bound (w * 0)^(1/3): the derivative is undefined at 0 and reported as NaN
      Assert.IsTrue(double.IsNaN(lg[0]), "expected NaN gradient of the lower bound of cuberoot(z)"); // z
      Assert.AreEqual(1.0 / 3.0, ug[0], 1e-5);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvalutorAutoDiffLog() {
      var eval = new IntervalEvaluator();
      var parser = new InfixExpressionParser();
      var t = parser.Parse("log(4*x)");
      var paramNodes = GetParameterNodes(t);
      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
      };
      var r = eval.Evaluate(t, intervals, paramNodes, out double[] lg, out double[] ug);
      Assert.AreEqual(Math.Log(4), r.LowerBound);
      Assert.AreEqual(Math.Log(8), r.UpperBound);

      // lg[0]/ug[0] belong to the first parameter node (the factor 4):
      // d/dc log(c * x) = 1 / c = 0.25 for both bounds
      Assert.AreEqual(0.25, lg[0]);
      Assert.AreEqual(0.25, ug[0]);
    }

    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void IntervalEvaluatorAutoDiffCompareWithNumericDifferences() {
      // create random trees and evaluate them over fixed variable intervals
      // calc the gradient for all parameters with autodiff
      // use central finite differences for an approximate gradient
      // compare both gradients

      var grammar = new TypeCoherentExpressionGrammar();
      grammar.ConfigureAsDefaultRegressionGrammar();
      // activate supported symbols
      grammar.Symbols.First(s => s is Square).Enabled = true;
      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
      grammar.Symbols.First(s => s is Cube).Enabled = true;
      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
      grammar.Symbols.First(s => s is Sine).Enabled = true;
      grammar.Symbols.First(s => s is Cosine).Enabled = true;
      grammar.Symbols.First(s => s is Exponential).Enabled = true;
      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
      grammar.Symbols.First(s => s is Absolute).Enabled = true;
      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
      grammar.Symbols.First(s => s is Constant).Enabled = false;

      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
      varSy.AllVariableNames = new string[] { "x", "y" };
      varSy.VariableNames = varSy.AllVariableNames;
      // all variable weights are exactly 1
      varSy.WeightMu = 1.0;
      varSy.WeightSigma = 0.0;
      var rand = new FastRandom(1234);

      var intervals = new Dictionary<string, Interval>() {
        { "x", new Interval(1, 2) },
        { "y", new Interval(0, 1) },
      };

      var eval = new IntervalEvaluator();

      var formatter = new InfixExpressionFormatter();
      var sb = new StringBuilder();
      const int N = 10000;
      for (int iter = 0; iter < N; iter++) {
        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
        var parameterNodes = GetParameterNodes(t);

        eval.Evaluate(t, intervals, parameterNodes, out double[] lowerGradient, out double[] upperGradient);

        ApproximateIntervalGradient(t, intervals, parameterNodes, eval, out double[] refLowerGradient, out double[] refUpperGradient);

        // compare autodiff and numeric diff
        for (int p = 0; p < parameterNodes.Length; p++) {
          if (!GradientsAgree(lowerGradient[p], refLowerGradient[p])) {
            sb.AppendLine($"{lowerGradient[p]} <> {refLowerGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
          if (!GradientsAgree(upperGradient[p], refUpperGradient[p])) {
            sb.AppendLine($"{upperGradient[p]} <> {refUpperGradient[p]} for {parameterNodes[p]} in {formatter.Format(t)}");
          }
        }
      }
      if (sb.Length > 0) {
        Console.WriteLine(sb.ToString());
        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
      }
    }

    #region helper

    /// <summary>
    /// Returns the leaf nodes of <paramref name="tree"/> in postfix order. These are the parameters
    /// the gradient arrays returned by the evaluator are indexed by.
    /// </summary>
    private static ISymbolicExpressionTreeNode[] GetParameterNodes(ISymbolicExpressionTree tree) {
      return tree.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
    }

    /// <summary>
    /// Returns true if an autodiff gradient component agrees with its finite-difference approximation:
    /// both are NaN, both are exactly equal, or they are within the relative or absolute tolerance.
    /// </summary>
    private static bool GradientsAgree(double autoDiff, double numeric) {
      if (double.IsNaN(autoDiff) && double.IsNaN(numeric)) {
        return true;
      }
      if (autoDiff == numeric) {
        return true;
      }
      var diff = Math.Abs(autoDiff - numeric);
      return diff <= Math.Abs(autoDiff) * RelativeGradientTolerance || diff <= AbsoluteGradientTolerance;
    }

    /// <summary>
    /// Step width for central differences around <paramref name="x"/>: relative to the value, with
    /// an absolute fallback for x == 0 (a purely relative step would be 0 and the quotient NaN).
    /// </summary>
    private static double GetFiniteDifferenceStep(double x) {
      return x != 0.0 ? x * RelativeStep : RelativeStep;
    }

    /// <summary>
    /// Evaluates <paramref name="expr"/> with <see cref="VectorAutoDiffEvaluator"/> for row 0 of
    /// <paramref name="ds"/>. Returns the gradient of the output with respect to all leaf nodes.
    /// Not used by the tests in this class.
    /// </summary>
    private double[] CalculateGradient(string expr, IDataset ds) {
      var eval = new VectorAutoDiffEvaluator();
      var parser = new InfixExpressionParser();

      var rows = new int[1];
      var fi = new double[1];

      var t = parser.Parse(expr);
      var parameterNodes = GetParameterNodes(t);
      var jac = new double[1, parameterNodes.Length];
      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);

      var g = new double[parameterNodes.Length];
      for (int i = 0; i < g.Length; i++) g[i] = jac[0, i];
      return g;
    }

    /// <summary>
    /// Approximates the Jacobian of the tree output (rows x parameters) by central finite differences.
    /// Each parameter is temporarily changed and always restored. Not used by the tests in this class.
    /// </summary>
    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
      var jac = new double[rows.Length, parameterNodes.Length];
      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var step = GetFiniteDifferenceStep(x);
        try {
          // calculate output for increased parameter value
          SetValue(parameterNodes[p], x + step / 2);
          var fPlus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          // calculate output for decreased parameter value
          SetValue(parameterNodes[p], x - step / 2);
          var fMinus = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();

          // difference quotient
          for (int i = 0; i < rows.Length; i++) {
            jac[i, p] = (fPlus[i] - fMinus[i]) / step;
          }
        } finally {
          // restore original value
          SetValue(parameterNodes[p], x);
        }
      }
      return jac;
    }

    /// <summary>
    /// Approximates the gradients of the lower and upper interval bound with respect to every
    /// parameter node by central finite differences. Each parameter is temporarily changed and
    /// always restored.
    /// </summary>
    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
      lowerGradient = new double[parameterNodes.Length];
      upperGradient = new double[parameterNodes.Length];

      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var step = GetFiniteDifferenceStep(x);
        try {
          // calculate output for increased parameter value
          SetValue(parameterNodes[p], x + step / 2);
          var r1 = eval.Evaluate(t, intervals);

          // calculate output for decreased parameter value
          SetValue(parameterNodes[p], x - step / 2);
          var r2 = eval.Evaluate(t, intervals);

          lowerGradient[p] = (r1.LowerBound - r2.LowerBound) / step;
          upperGradient[p] = (r1.UpperBound - r2.UpperBound) / step;
        } finally {
          // restore original value
          SetValue(parameterNodes[p], x);
        }
      }
    }

    /// <summary>
    /// Sets the parameter value of a leaf node: the weight of a variable node or the value of a constant node.
    /// </summary>
    /// <exception cref="ArgumentException">The node is neither a variable nor a constant node.</exception>
    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
      if (node is VariableTreeNode varNode) {
        varNode.Weight = v;
      } else if (node is ConstantTreeNode constNode) {
        constNode.Value = v;
      } else {
        throw new ArgumentException($"Unsupported parameter node type {node?.GetType().Name ?? "null"}.", nameof(node));
      }
    }

    /// <summary>
    /// Returns the parameter value of a leaf node: the weight of a variable node or the value of a constant node.
    /// </summary>
    /// <exception cref="ArgumentException">The node is neither a variable nor a constant node.</exception>
    private double GetValue(ISymbolicExpressionTreeNode node) {
      if (node is VariableTreeNode varNode) return varNode.Weight;
      if (node is ConstantTreeNode constNode) return constNode.Value;
      throw new ArgumentException($"Unsupported parameter node type {node?.GetType().Name ?? "null"}.", nameof(node));
    }
    #endregion
  }
}