source: branches/3087_Ceres_Integration/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/SymbolicDataAnalysisExpressionTreeInterpreterTest.cs @ 18012

Last change on this file since 18012 was 18012, checked in by gkronber, 2 months ago

#3087 fixed compile error in unit tests

File size: 37.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Globalization;
25using System.Linq;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.NativeInterpreter;
28using HeuristicLab.Random;
29using Microsoft.VisualStudio.TestTools.UnitTesting;
30namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Tests {
31
32
33  [TestClass]
34  public class SymbolicDataAnalysisExpressionTreeInterpreterTest {
35    private const int N = 1000;
36    private const int Rows = 1000;
37    private const int Columns = 50;
38
39    private static Dataset ds = new Dataset(new string[] { "Y", "A", "B" }, new double[,] {
40        { 1.0, 1.0, 1.0 },
41        { 2.0, 2.0, 2.0 },
42        { 3.0, 1.0, 2.0 },
43        { 4.0, 1.0, 1.0 },
44        { 5.0, 2.0, 2.0 },
45        { 6.0, 1.0, 2.0 },
46        { 7.0, 1.0, 1.0 },
47        { 8.0, 2.0, 2.0 },
48        { 9.0, 1.0, 2.0 },
49        { 10.0, 1.0, 1.0 },
50        { 11.0, 2.0, 2.0 },
51        { 12.0, 1.0, 2.0 }
52      });
53
54    [TestMethod]
55    [TestCategory("Problems.DataAnalysis.Symbolic")]
56    [TestProperty("Time", "long")]
57    public void StandardInterpreterTestTypeCoherentGrammarPerformance() {
58      TestTypeCoherentGrammarPerformance(new SymbolicDataAnalysisExpressionTreeInterpreter(), 12.5e6);
59    }
60    [TestMethod]
61    [TestCategory("Problems.DataAnalysis.Symbolic")]
62    [TestProperty("Time", "long")]
63    public void StandardInterpreterTestFullGrammarPerformance() {
64      TestFullGrammarPerformance(new SymbolicDataAnalysisExpressionTreeInterpreter(), 12.5e6);
65    }
66    [TestMethod]
67    [TestCategory("Problems.DataAnalysis.Symbolic")]
68    [TestProperty("Time", "long")]
69    public void StandardInterpreterTestArithmeticGrammarPerformance() {
70      TestArithmeticGrammarPerformance(new SymbolicDataAnalysisExpressionTreeInterpreter(), 12.5e6);
71    }
72
73    [TestMethod]
74    [TestCategory("Problems.DataAnalysis.Symbolic")]
75    [TestProperty("Time", "long")]
76    public void CompiledInterpreterTestTypeCoherentGrammarPerformance() {
77      TestTypeCoherentGrammarPerformance(new SymbolicDataAnalysisExpressionCompiledTreeInterpreter(), 12.5e6);
78    }
79    [TestMethod]
80    [TestCategory("Problems.DataAnalysis.Symbolic")]
81    [TestProperty("Time", "long")]
82    public void CompiledInterpreterTestFullGrammarPerformance() {
83      TestFullGrammarPerformance(new SymbolicDataAnalysisExpressionCompiledTreeInterpreter(), 12.5e6);
84    }
85    [TestMethod]
86    [TestCategory("Problems.DataAnalysis.Symbolic")]
87    [TestProperty("Time", "long")]
88    public void CompiledInterpreterTestArithmeticGrammarPerformance() {
89      TestArithmeticGrammarPerformance(new SymbolicDataAnalysisExpressionCompiledTreeInterpreter(), 12.5e6);
90    }
91
92    [TestMethod]
93    [TestCategory("Problems.DataAnalysis.Symbolic")]
94    [TestProperty("Time", "long")]
95    public void ILEmittingInterpreterTestTypeCoherentGrammarPerformance() {
96      TestTypeCoherentGrammarPerformance(new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(), 7.5e6);
97    }
98    [TestMethod]
99    [TestCategory("Problems.DataAnalysis.Symbolic")]
100    [TestProperty("Time", "long")]
101    public void ILEmittingInterpreterTestArithmeticGrammarPerformance() {
102      TestArithmeticGrammarPerformance(new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter(), 7.5e6);
103    }
104
105    [TestMethod]
106    [TestCategory("Problems.DataAnalysis.Symbolic")]
107    [TestProperty("Time", "long")]
108    public void LinearInterpreterTestTypeCoherentGrammarPerformance() {
109      TestTypeCoherentGrammarPerformance(new SymbolicDataAnalysisExpressionTreeLinearInterpreter(), 12.5e6);
110    }
111    [TestMethod]
112    [TestCategory("Problems.DataAnalysis.Symbolic")]
113    [TestProperty("Time", "long")]
114    public void LinearInterpreterTestFullGrammarPerformance() {
115      TestFullGrammarPerformance(new SymbolicDataAnalysisExpressionTreeLinearInterpreter(), 12.5e6);
116    }
117    [TestMethod]
118    [TestCategory("Problems.DataAnalysis.Symbolic")]
119    [TestProperty("Time", "long")]
120    public void LinearInterpreterTestArithmeticGrammarPerformance() {
121      TestArithmeticGrammarPerformance(new SymbolicDataAnalysisExpressionTreeLinearInterpreter(), 12.5e6);
122    }
123
124    [TestMethod]
125    [TestCategory("Problems.DataAnalysis.Symbolic")]
126    [TestProperty("Time", "long")]
127    public void NativeInterpreterTestTypeCoherentGrammarPerformance() {
128      TestTypeCoherentGrammarPerformance(new NativeInterpreter(), 12.5e6);
129    }
130    [TestMethod]
131    [TestCategory("Problems.DataAnalysis.Symbolic")]
132    [TestProperty("Time", "long")]
133    public void NativeInterpreterTestFullGrammarPerformance() {
134      TestFullGrammarPerformance(new NativeInterpreter(), 12.5e6);
135    }
136    [TestMethod]
137    [TestCategory("Problems.DataAnalysis.Symbolic")]
138    [TestProperty("Time", "long")]
139    public void NativeInterpreterTestArithmeticGrammarPerformance() {
140      TestArithmeticGrammarPerformance(new NativeInterpreter(), 12.5e6);
141    }
142
143    [TestMethod]
144    [TestCategory("Problems.DataAnalysis.Symbolic")]
145    [TestProperty("Time", "long")]
146    public void NativeInterpreterTestCeres() {
147      var parser = new InfixExpressionParser();
148      var random = new FastRandom(1234);
149      const int nRows = 20;
150
151      var x1 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
152      var x2 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
153      var x3 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
154
155      var optimalAlpha = new double[] { -2, -3, -5 };
156      var y = Enumerable.Range(0, nRows).Select(i =>
157          Math.Exp(x1[i] * optimalAlpha[0]) +
158          Math.Exp(x2[i] * optimalAlpha[1]) +
159          Math.Exp(x3[i] * optimalAlpha[2])).ToArray();
160
161      var initialAlpha = Enumerable.Range(0, 3).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
162      var ds = new Dataset(new[] { "x1", "x2", "x3", "y" }, new[] { x1, x2, x3, y });
163
164      var expr = "EXP(x1) + EXP(x2) + EXP(x3)";
165      var tree = parser.Parse(expr);
166      var rows = Enumerable.Range(0, nRows).ToArray();
167      var options = new SolverOptions {
168        Minimizer = CeresTypes.Minimizer.TRUST_REGION,
169        Iterations = 20,
170        TrustRegionStrategy = CeresTypes.TrustRegionStrategy.LEVENBERG_MARQUARDT,
171        LinearSolver = CeresTypes.LinearSolver.DENSE_QR
172      };
173
174      var nodesToOptimize = new HashSet<ISymbolicExpressionTreeNode>(tree.IterateNodesPrefix().Where(x => x is VariableTreeNode));
175      int idx = 0;
176      foreach(var node in nodesToOptimize) {
177        (node as VariableTreeNode).Weight = initialAlpha[idx++];
178        Console.WriteLine((node as VariableTreeNode).Weight);
179
180      }
181
182      var summary = new OptimizationSummary();
183      var parameters = ParameterOptimizer.OptimizeTree(tree, ds, rows, "y", nodesToOptimize, options, ref summary);
184
185      Console.Write("Optimized parameters: ");
186      foreach (var t in parameters) {
187        Console.Write(t.Value + " ");
188      }
189      Console.WriteLine();
190
191      Console.WriteLine("Optimization summary:");
192      Console.WriteLine("Initial cost:         " + summary.InitialCost);
193      Console.WriteLine("Final cost:           " + summary.FinalCost);
194      Console.WriteLine("Successful steps:     " + summary.SuccessfulSteps);
195      Console.WriteLine("Unsuccessful steps:   " + summary.UnsuccessfulSteps);
196      Console.WriteLine("Residual evaluations: " + summary.ResidualEvaluations);
197      Console.WriteLine("Jacobian evaluations: " + summary.JacobianEvaluations);
198    }
199
200    [TestMethod]
201    [TestCategory("Problems.DataAnalysis.Symbolic")]
202    [TestProperty("Time", "long")]
203    public void NativeInterpreterTestCeresVariableProjection() {
204      var parser = new InfixExpressionParser();
205      var random = new FastRandom(1234);
206      const int nRows = 20;
207
208      var x1 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
209      var x2 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
210      var x3 = Enumerable.Range(0, nRows).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
211
212      var optimalAlpha = new double[] { -2, -3, -5 };
213      var y = Enumerable.Range(0, nRows).Select(i =>
214        Math.Exp(x1[i] * optimalAlpha[0]) +
215        Math.Exp(x2[i] * optimalAlpha[1]) +
216        Math.Exp(x3[i] * optimalAlpha[2])).ToArray();
217
218      var initialAlpha = Enumerable.Range(0, 3).Select(_ => UniformDistributedRandom.NextDouble(random, -1, 1)).ToArray();
219      var ds = new Dataset(new[] { "x1", "x2", "x3", "y" }, new[] { x1, x2, x3, y });
220
221      var expr = new[] { "EXP(x1)", "EXP(x2)", "EXP(x3)" };
222      var trees = expr.Select(x => parser.Parse(x)).ToArray();
223      var rows = Enumerable.Range(0, nRows).ToArray();
224      var options = new SolverOptions {
225        Minimizer = CeresTypes.Minimizer.TRUST_REGION,
226        Iterations = 100,
227        TrustRegionStrategy = CeresTypes.TrustRegionStrategy.LEVENBERG_MARQUARDT,
228        LinearSolver = CeresTypes.LinearSolver.DENSE_QR
229      };
230
231      var summary = new OptimizationSummary();
232
233      var nodesToOptimize = new HashSet<ISymbolicExpressionTreeNode>(trees.SelectMany(t => t.IterateNodesPrefix().Where(x => x is VariableTreeNode)));
234      int idx = 0;
235      Console.Write("Initial parameters: ");
236      foreach (var node in nodesToOptimize) {
237        (node as VariableTreeNode).Weight = initialAlpha[idx++];
238        Console.Write((node as VariableTreeNode).Weight + " ");
239      }
240      Console.WriteLine();
241
242      var coeff = new double[trees.Length + 1];
243      var parameters = ParameterOptimizer.OptimizeTree(trees, ds, rows, "y", nodesToOptimize, options, coeff, ref summary);
244      Console.Write("Optimized parameters: ");
245      foreach (var t in parameters) {
246        Console.Write(t.Value + " ");
247      }
248      Console.WriteLine();
249
250      Console.Write("Coefficients: ");
251      foreach (var v in coeff) Console.Write(v + " ");
252      Console.WriteLine();
253
254      Console.WriteLine("Optimization summary:");
255      Console.WriteLine("Initial cost:         " + summary.InitialCost);
256      Console.WriteLine("Final cost:           " + summary.FinalCost);
257      Console.WriteLine("Successful steps:     " + summary.SuccessfulSteps);
258      Console.WriteLine("Unsuccessful steps:   " + summary.UnsuccessfulSteps);
259      Console.WriteLine("Residual evaluations: " + summary.ResidualEvaluations);
260      Console.WriteLine("Jacobian evaluations: " + summary.JacobianEvaluations);
261    }
262
263    [TestMethod]
264    [TestCategory("Problems.DataAnalysis.Symbolic")]
265    [TestProperty("Time", "long")]
266    public void BatchInterpreterTestTypeCoherentGrammarPerformance() {
267      TestTypeCoherentGrammarPerformance(new SymbolicDataAnalysisExpressionTreeBatchInterpreter(), 12.5e6);
268    }
269    [TestMethod]
270    [TestCategory("Problems.DataAnalysis.Symbolic")]
271    [TestProperty("Time", "long")]
272    public void BatchInterpreterTestArithmeticGrammarPerformance() {
273      TestArithmeticGrammarPerformance(new SymbolicDataAnalysisExpressionTreeBatchInterpreter(), 12.5e6);
274    }
275
276    private void TestTypeCoherentGrammarPerformance(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double nodesPerSecThreshold) {
277      var twister = new MersenneTwister(31415);
278      var dataset = Util.CreateRandomDataset(twister, Rows, Columns);
279
280      var grammar = new TypeCoherentExpressionGrammar();
281      grammar.ConfigureAsDefaultRegressionGrammar();
282
283      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 100, 0, 0);
284      foreach (ISymbolicExpressionTree tree in randomTrees) {
285        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
286      }
287      double nodesPerSec = Util.CalculateEvaluatedNodesPerSec(randomTrees, interpreter, dataset, 3);
288      //mkommend: commented due to performance issues on the builder
289      // Assert.IsTrue(nodesPerSec > nodesPerSecThreshold); // evaluated nodes per seconds must be larger than 15mNodes/sec
290    }
291
292    private void TestFullGrammarPerformance(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double nodesPerSecThreshold) {
293      var twister = new MersenneTwister(31415);
294      var dataset = Util.CreateRandomDataset(twister, Rows, Columns);
295
296      var grammar = new FullFunctionalExpressionGrammar();
297      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 100, 0, 0);
298      foreach (ISymbolicExpressionTree tree in randomTrees) {
299        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
300      }
301      double nodesPerSec = Util.CalculateEvaluatedNodesPerSec(randomTrees, interpreter, dataset, 3);
302      //mkommend: commented due to performance issues on the builder
303      //Assert.IsTrue(nodesPerSec > nodesPerSecThreshold); // evaluated nodes per seconds must be larger than 15mNodes/sec
304    }
305
306    private void TestArithmeticGrammarPerformance(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, double nodesPerSecThreshold) {
307      var twister = new MersenneTwister(31415);
308      var dataset = Util.CreateRandomDataset(twister, Rows, Columns);
309
310      var grammar = new ArithmeticExpressionGrammar();
311      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 100, 0, 0);
312      foreach (SymbolicExpressionTree tree in randomTrees) {
313        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
314      }
315
316      double nodesPerSec = Util.CalculateEvaluatedNodesPerSec(randomTrees, interpreter, dataset, 3);
317      //mkommend: commented due to performance issues on the builder
318      //Assert.IsTrue(nodesPerSec > nodesPerSecThreshold); // evaluated nodes per seconds must be larger than 15mNodes/sec
319    }
320
321
322    /// <summary>
323    ///A test for Evaluate
324    ///</summary>
325    [TestMethod]
326    [TestCategory("Problems.DataAnalysis.Symbolic")]
327    [TestProperty("Time", "short")]
328    public void StandardInterpreterTestEvaluation() {
329      var interpreter = new SymbolicDataAnalysisExpressionTreeInterpreter();
330      EvaluateTerminals(interpreter, ds);
331      EvaluateOperations(interpreter, ds);
332      EvaluateLaggedOperations(interpreter, ds);
333      EvaluateSpecialFunctions(interpreter, ds);
334      EvaluateAdf(interpreter, ds);
335    }
336
337    /// <summary>
338    ///A test for Evaluate
339    ///</summary>
340    [TestMethod]
341    [TestCategory("Problems.DataAnalysis.Symbolic")]
342    [TestProperty("Time", "short")]
343    public void ILEmittingInterpreterTestEvaluation() {
344      var interpreter = new SymbolicDataAnalysisExpressionTreeILEmittingInterpreter();
345      EvaluateTerminals(interpreter, ds);
346      EvaluateOperations(interpreter, ds);
347      EvaluateLaggedOperations(interpreter, ds);
348      EvaluateSpecialFunctions(interpreter, ds);
349    }
350
351    [TestMethod]
352    [TestCategory("Problems.DataAnalysis.Symbolic")]
353    [TestProperty("Time", "short")]
354    public void CompiledInterpreterTestEvaluation() {
355      var interpreter = new SymbolicDataAnalysisExpressionCompiledTreeInterpreter();
356      EvaluateTerminals(interpreter, ds);
357      EvaluateOperations(interpreter, ds);
358      EvaluateSpecialFunctions(interpreter, ds);
359    }
360
361    [TestMethod]
362    [TestCategory("Problems.DataAnalysis.Symbolic")]
363    [TestProperty("Time", "short")]
364    public void LinearInterpreterTestEvaluation() {
365      var interpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
366      //ADFs are not supported by the linear interpreter
367      EvaluateTerminals(interpreter, ds);
368      EvaluateOperations(interpreter, ds);
369      EvaluateLaggedOperations(interpreter, ds);
370      EvaluateSpecialFunctions(interpreter, ds);
371    }
372
373    [TestMethod]
374    [TestCategory("Problems.DataAnalysis.Symbolic")]
375    [TestProperty("Time", "long")]
376    public void TestInterpretersEstimatedValuesConsistency() {
377      var twister = new MersenneTwister();
378      int seed = twister.Next(0, int.MaxValue);
379      twister.Seed((uint)seed);
380      const int numRows = 100;
381      var dataset = Util.CreateRandomDataset(twister, numRows, Columns);
382
383      var grammar = new TypeCoherentExpressionGrammar();
384
385      var interpreters = new ISymbolicDataAnalysisExpressionTreeInterpreter[] {
386        new SymbolicDataAnalysisExpressionTreeLinearInterpreter(),
387        new SymbolicDataAnalysisExpressionTreeInterpreter(),
388      };
389
390      var rows = Enumerable.Range(0, numRows).ToList();
391      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 10, 0, 0);
392      foreach (ISymbolicExpressionTree tree in randomTrees) {
393        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
394      }
395
396      for (int i = 0; i < randomTrees.Length; ++i) {
397        var tree = randomTrees[i];
398        var valuesMatrix = interpreters.Select(x => x.GetSymbolicExpressionTreeValues(tree, dataset, rows)).ToList();
399        for (int m = 0; m < interpreters.Length - 1; ++m) {
400          var sum = valuesMatrix[m].Sum();
401          for (int n = m + 1; n < interpreters.Length; ++n) {
402            var s = valuesMatrix[n].Sum();
403            if (double.IsNaN(sum) && double.IsNaN(s)) continue;
404
405            string errorMessage = string.Format("Interpreters {0} and {1} do not agree on tree {2} (seed = {3}).", interpreters[m].Name, interpreters[n].Name, i, seed);
406            Assert.AreEqual(sum, s, 1e-12, errorMessage);
407          }
408        }
409      }
410    }
411
412    [TestMethod]
413    [TestCategory("Problems.DataAnalysis.Symbolic")]
414    [TestProperty("Time", "long")]
415    public void TestCompiledInterpreterEstimatedValuesConsistency() {
416      const double delta = 1e-8;
417
418      var twister = new MersenneTwister();
419      int seed = twister.Next(0, int.MaxValue);
420      twister.Seed((uint)seed);
421
422      Console.WriteLine(seed);
423
424      const int numRows = 100;
425      var dataset = Util.CreateRandomDataset(twister, numRows, Columns);
426
427      var grammar = new TypeCoherentExpressionGrammar();
428      grammar.ConfigureAsDefaultRegressionGrammar();
429      grammar.Symbols.First(x => x.Name == "Power Functions").Enabled = true;
430      grammar.Symbols.First(x => x is Cube).Enabled = true;
431      grammar.Symbols.First(x => x is CubeRoot).Enabled = true;
432      grammar.Symbols.First(x => x is Square).Enabled = true;
433      grammar.Symbols.First(x => x is SquareRoot).Enabled = true;
434      grammar.Symbols.First(x => x is Absolute).Enabled = true;
435      grammar.Symbols.First(x => x is Sine).Enabled = true;
436      grammar.Symbols.First(x => x is Cosine).Enabled = true;
437      grammar.Symbols.First(x => x is Tangent).Enabled = true;
438      grammar.Symbols.First(x => x is Root).Enabled = false;
439      grammar.Symbols.First(x => x is Power).Enabled = false;
440
441      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 10, 0, 0);
442      foreach (ISymbolicExpressionTree tree in randomTrees) {
443        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
444      }
445
446      var interpreters = new ISymbolicDataAnalysisExpressionTreeInterpreter[] {
447        new SymbolicDataAnalysisExpressionCompiledTreeInterpreter(),
448        new SymbolicDataAnalysisExpressionTreeInterpreter(),
449        new SymbolicDataAnalysisExpressionTreeLinearInterpreter(),
450      };
451      var rows = Enumerable.Range(0, numRows).ToList();
452      var formatter = new SymbolicExpressionTreeHierarchicalFormatter();
453
454      for (int i = 0; i < randomTrees.Length; ++i) {
455        var tree = randomTrees[i];
456        var valuesMatrix = interpreters.Select(x => x.GetSymbolicExpressionTreeValues(tree, dataset, rows).ToList()).ToList();
457        for (int m = 0; m < interpreters.Length - 1; ++m) {
458          for (int n = m + 1; n < interpreters.Length; ++n) {
459            for (int row = 0; row < numRows; ++row) {
460              var v1 = valuesMatrix[m][row];
461              var v2 = valuesMatrix[n][row];
462              if (double.IsNaN(v1) && double.IsNaN(v2)) continue;
463              if (v1 != v2 && Math.Abs(1.0 - v1 / v2) >= delta) {
464                Console.WriteLine(formatter.Format(tree));
465                foreach (var node in tree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList()) {
466                  var rootNode = (SymbolicExpressionTreeTopLevelNode)grammar.ProgramRootSymbol.CreateTreeNode();
467                  if (rootNode.HasLocalParameters) rootNode.ResetLocalParameters(twister);
468                  rootNode.SetGrammar(grammar.CreateExpressionTreeGrammar());
469
470                  var startNode = (SymbolicExpressionTreeTopLevelNode)grammar.StartSymbol.CreateTreeNode();
471                  if (startNode.HasLocalParameters) startNode.ResetLocalParameters(twister);
472                  startNode.SetGrammar(grammar.CreateExpressionTreeGrammar());
473
474                  rootNode.AddSubtree(startNode);
475                  var t = new SymbolicExpressionTree(rootNode);
476                  var start = t.Root.GetSubtree(0);
477                  var p = node.Parent;
478                  start.AddSubtree(node);
479                  Console.WriteLine(node);
480
481                  var y1 = interpreters[m].GetSymbolicExpressionTreeValues(t, dataset, new[] { row }).First();
482                  var y2 = interpreters[n].GetSymbolicExpressionTreeValues(t, dataset, new[] { row }).First();
483
484                  if (double.IsNaN(y1) && double.IsNaN(y2)) continue;
485                  string prefix = Math.Abs(y1 - y2) > delta ? "++" : "==";
486                  Console.WriteLine("\t{0} Row {1}: {2} {3}, Deviation = {4}", prefix, row, y1, y2, Math.Abs(y1 - y2));
487                  node.Parent = p;
488                }
489              }
490              string errorMessage = string.Format("Interpreters {0} and {1} do not agree on tree {2} and row {3} (seed = {4}).", interpreters[m].Name, interpreters[n].Name, i, row, seed);
491              Assert.IsTrue(double.IsNaN(v1) && double.IsNaN(v2) ||
492                            v1 == v2 || // in particular 0 = 0
493                            Math.Abs(1.0 - v1 / v2) < delta, errorMessage);
494            }
495          }
496        }
497      }
498    }
499
500    private void EvaluateTerminals(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds) {
501      // constants
502      Evaluate(interpreter, ds, "(+ 1.5 3.5)", 0, 5.0);
503
504      // variables
505      Evaluate(interpreter, ds, "(variable 2.0 a)", 0, 2.0);
506      Evaluate(interpreter, ds, "(variable 2.0 a)", 1, 4.0);
507    }
508
509    private void EvaluateAdf(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds) {
510
511      // ADF     
512      Evaluate(interpreter, ds, @"(PROG
513                                    (MAIN
514                                      (CALL ADF0))
515                                    (defun ADF0 1.0))", 1, 1.0);
516      Evaluate(interpreter, ds, @"(PROG
517                                    (MAIN
518                                      (* (CALL ADF0) (CALL ADF0)))
519                                    (defun ADF0 2.0))", 1, 4.0);
520      Evaluate(interpreter, ds, @"(PROG
521                                    (MAIN
522                                      (CALL ADF0 2.0 3.0))
523                                    (defun ADF0
524                                      (+ (ARG 0) (ARG 1))))", 1, 5.0);
525      Evaluate(interpreter, ds, @"(PROG
526                                    (MAIN (CALL ADF1 2.0 3.0))
527                                    (defun ADF0
528                                      (- (ARG 1) (ARG 0)))
529                                    (defun ADF1
530                                      (+ (CALL ADF0 (ARG 1) (ARG 0))
531                                         (CALL ADF0 (ARG 0) (ARG 1)))))", 1, 0.0);
532      Evaluate(interpreter, ds, @"(PROG
533                                    (MAIN (CALL ADF1 (variable 2.0 a) 3.0))
534                                    (defun ADF0
535                                      (- (ARG 1) (ARG 0)))
536                                    (defun ADF1                                                                             
537                                      (CALL ADF0 (ARG 1) (ARG 0))))", 1, 1.0);
538      Evaluate(interpreter, ds,
539               @"(PROG
540                                    (MAIN (CALL ADF1 (variable 2.0 a) 3.0))
541                                    (defun ADF0
542                                      (- (ARG 1) (ARG 0)))
543                                    (defun ADF1                                                                             
544                                      (+ (CALL ADF0 (ARG 1) (ARG 0))
545                                         (CALL ADF0 (ARG 0) (ARG 1)))))", 1, 0.0);
546    }
547
548    private void EvaluateSpecialFunctions(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds) {
549      // special functions
550      Action<double> checkAiry = (x) => {
551        double ai, aip, bi, bip;
552        alglib.airy(x, out ai, out aip, out bi, out bip);
553        Evaluate(interpreter, ds, "(airya " + x + ")", 0, ai);
554        Evaluate(interpreter, ds, "(airyb " + x + ")", 0, bi);
555      };
556
557      Action<double> checkBessel = (x) => {
558        Evaluate(interpreter, ds, "(bessel " + x + ")", 0, alglib.besseli0(x));
559      };
560
561      Action<double> checkSinCosIntegrals = (x) => {
562        double si, ci;
563        alglib.sinecosineintegrals(x, out si, out ci);
564        Evaluate(interpreter, ds, "(cosint " + x + ")", 0, ci);
565        Evaluate(interpreter, ds, "(sinint " + x + ")", 0, si);
566      };
567      Action<double> checkHypSinCosIntegrals = (x) => {
568        double shi, chi;
569        alglib.hyperbolicsinecosineintegrals(x, out shi, out chi);
570        Evaluate(interpreter, ds, "(hypcosint " + x + ")", 0, chi);
571        Evaluate(interpreter, ds, "(hypsinint " + x + ")", 0, shi);
572      };
573      Action<double> checkFresnelSinCosIntegrals = (x) => {
574        double c = 0, s = 0;
575        alglib.fresnelintegral(x, ref c, ref s);
576        Evaluate(interpreter, ds, "(fresnelcosint " + x + ")", 0, c);
577        Evaluate(interpreter, ds, "(fresnelsinint " + x + ")", 0, s);
578      };
579      Action<double> checkNormErf = (x) => {
580        Evaluate(interpreter, ds, "(norm " + x + ")", 0, alglib.normaldistribution(x));
581        Evaluate(interpreter, ds, "(erf " + x + ")", 0, alglib.errorfunction(x));
582      };
583
584      Action<double> checkGamma = (x) => {
585        Evaluate(interpreter, ds, "(gamma " + x + ")", 0, alglib.gammafunction(x));
586      };
587      Action<double> checkPsi = (x) => {
588        try {
589          Evaluate(interpreter, ds, "(psi " + x + ")", 0, alglib.psi(x));
590        } catch (alglib.alglibexception) { // ignore cases where alglib throws an exception
591        }
592      };
593      Action<double> checkDawson = (x) => {
594        Evaluate(interpreter, ds, "(dawson " + x + ")", 0, alglib.dawsonintegral(x));
595      };
596      Action<double> checkExpInt = (x) => {
597        Evaluate(interpreter, ds, "(expint " + x + ")", 0, alglib.exponentialintegralei(x));
598      };
599
600      foreach (var e in new[] { -2.0, -1.0, 0.0, 1.0, 2.0 }) {
601        checkAiry(e);
602        checkBessel(e);
603        checkSinCosIntegrals(e);
604        checkGamma(e);
605        checkExpInt(e);
606        checkDawson(e);
607        checkPsi(e);
608        checkNormErf(e);
609        checkFresnelSinCosIntegrals(e);
610        checkHypSinCosIntegrals(e);
611      }
612    }
613
614    private void EvaluateLaggedOperations(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds) {
615      // lag
616      Evaluate(interpreter, ds, "(lagVariable 1.0 a -1) ", 1, ds.GetDoubleValue("A", 0));
617      Evaluate(interpreter, ds, "(lagVariable 1.0 a -1) ", 2, ds.GetDoubleValue("A", 1));
618      Evaluate(interpreter, ds, "(lagVariable 1.0 a 0) ", 2, ds.GetDoubleValue("A", 2));
619      Evaluate(interpreter, ds, "(lagVariable 1.0 a 1) ", 0, ds.GetDoubleValue("A", 1));
620
621      // integral
622      Evaluate(interpreter, ds, "(integral -1.0 (variable 1.0 a)) ", 1, ds.GetDoubleValue("A", 0) + ds.GetDoubleValue("A", 1));
623      Evaluate(interpreter, ds, "(integral -1.0 (lagVariable 1.0 a 1)) ", 1, ds.GetDoubleValue("A", 1) + ds.GetDoubleValue("A", 2));
624      Evaluate(interpreter, ds, "(integral -2.0 (variable 1.0 a)) ", 2, ds.GetDoubleValue("A", 0) + ds.GetDoubleValue("A", 1) + ds.GetDoubleValue("A", 2));
625      Evaluate(interpreter, ds, "(integral -1.0 (* (variable 1.0 a) (variable 1.0 b)))", 1, ds.GetDoubleValue("A", 0) * ds.GetDoubleValue("B", 0) + ds.GetDoubleValue("A", 1) * ds.GetDoubleValue("B", 1));
626      Evaluate(interpreter, ds, "(integral -2.0 3.0)", 1, 9.0);
627
628      // derivative
629      // (f_0 + 2 * f_1 - 2 * f_3 - f_4) / 8; // h = 1
630      Evaluate(interpreter, ds, "(diff (variable 1.0 a)) ", 5, (ds.GetDoubleValue("A", 5) + 2 * ds.GetDoubleValue("A", 4) - 2 * ds.GetDoubleValue("A", 2) - ds.GetDoubleValue("A", 1)) / 8.0);
631      Evaluate(interpreter, ds, "(diff (variable 1.0 b)) ", 5, (ds.GetDoubleValue("B", 5) + 2 * ds.GetDoubleValue("B", 4) - 2 * ds.GetDoubleValue("B", 2) - ds.GetDoubleValue("B", 1)) / 8.0);
632      Evaluate(interpreter, ds, "(diff (* (variable 1.0 a) (variable 1.0 b)))", 5, +
633        (ds.GetDoubleValue("A", 5) * ds.GetDoubleValue("B", 5) +
634        2 * ds.GetDoubleValue("A", 4) * ds.GetDoubleValue("B", 4) -
635        2 * ds.GetDoubleValue("A", 2) * ds.GetDoubleValue("B", 2) -
636        ds.GetDoubleValue("A", 1) * ds.GetDoubleValue("B", 1)) / 8.0);
637      Evaluate(interpreter, ds, "(diff -2.0 3.0)", 5, 0.0);
638
639      // timelag
640      Evaluate(interpreter, ds, "(lag -1.0 (lagVariable 1.0 a 2)) ", 1, ds.GetDoubleValue("A", 2));
641      Evaluate(interpreter, ds, "(lag -2.0 (lagVariable 1.0 a 2)) ", 2, ds.GetDoubleValue("A", 2));
642      Evaluate(interpreter, ds, "(lag -1.0 (* (lagVariable 1.0 a 1) (lagVariable 1.0 b 2)))", 1, ds.GetDoubleValue("A", 1) * ds.GetDoubleValue("B", 2));
643      Evaluate(interpreter, ds, "(lag -2.0 3.0)", 1, 3.0);
644    }
645
646    private void EvaluateOperations(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds) {
647      // addition
648      Evaluate(interpreter, ds, "(+ (variable 2.0 a ))", 1, 4.0);
649      Evaluate(interpreter, ds, "(+ (variable 2.0 a ) (variable 3.0 b ))", 0, 5.0);
650      Evaluate(interpreter, ds, "(+ (variable 2.0 a ) (variable 3.0 b ))", 1, 10.0);
651      Evaluate(interpreter, ds, "(+ (variable 2.0 a) (variable 3.0 b ))", 2, 8.0);
652      Evaluate(interpreter, ds, "(+ 8.0 2.0 2.0)", 0, 12.0);
653
654      // subtraction
655      Evaluate(interpreter, ds, "(- (variable 2.0 a ))", 1, -4.0);
656      Evaluate(interpreter, ds, "(- (variable 2.0 a ) (variable 3.0 b))", 0, -1.0);
657      Evaluate(interpreter, ds, "(- (variable 2.0 a ) (variable 3.0 b ))", 1, -2.0);
658      Evaluate(interpreter, ds, "(- (variable 2.0 a ) (variable 3.0 b ))", 2, -4.0);
659      Evaluate(interpreter, ds, "(- 8.0 2.0 2.0)", 0, 4.0);
660
661      // multiplication
662      Evaluate(interpreter, ds, "(* (variable 2.0 a ))", 0, 2.0);
663      Evaluate(interpreter, ds, "(* (variable 2.0 a ) (variable 3.0 b ))", 0, 6.0);
664      Evaluate(interpreter, ds, "(* (variable 2.0 a ) (variable 3.0 b ))", 1, 24.0);
665      Evaluate(interpreter, ds, "(* (variable 2.0 a ) (variable 3.0 b ))", 2, 12.0);
666      Evaluate(interpreter, ds, "(* 8.0 2.0 2.0)", 0, 32.0);
667
668      // division
669      Evaluate(interpreter, ds, "(/ (variable 2.0 a ))", 1, 1.0 / 4.0);
670      Evaluate(interpreter, ds, "(/ (variable 2.0 a ) 2.0)", 0, 1.0);
671      Evaluate(interpreter, ds, "(/ (variable 2.0 a ) 2.0)", 1, 2.0);
672      Evaluate(interpreter, ds, "(/ (variable 3.0 b ) 2.0)", 2, 3.0);
673      Evaluate(interpreter, ds, "(/ 8.0 2.0 2.0)", 0, 2.0);
674
675      // gt
676      Evaluate(interpreter, ds, "(> (variable 2.0 a) 2.0)", 0, -1.0);
677      Evaluate(interpreter, ds, "(> 2.0 (variable 2.0 a))", 0, -1.0);
678      Evaluate(interpreter, ds, "(> (variable 2.0 a) 1.9)", 0, 1.0);
679      Evaluate(interpreter, ds, "(> 1.9 (variable 2.0 a))", 0, -1.0);
680      Evaluate(interpreter, ds, "(> (log -1.0) (log -1.0))", 0, -1.0); // (> nan nan) should be false
681
682      // lt
683      Evaluate(interpreter, ds, "(< (variable 2.0 a) 2.0)", 0, -1.0);
684      Evaluate(interpreter, ds, "(< 2.0 (variable 2.0 a))", 0, -1.0);
685      Evaluate(interpreter, ds, "(< (variable 2.0 a) 1.9)", 0, -1.0);
686      Evaluate(interpreter, ds, "(< 1.9 (variable 2.0 a))", 0, 1.0);
687      Evaluate(interpreter, ds, "(< (log -1.0) (log -1.0))", 0, -1.0); // (< nan nan) should be false
688
689      // If
690      Evaluate(interpreter, ds, "(if -10.0 2.0 3.0)", 0, 3.0);
691      Evaluate(interpreter, ds, "(if -1.0 2.0 3.0)", 0, 3.0);
692      Evaluate(interpreter, ds, "(if 0.0 2.0 3.0)", 0, 3.0);
693      Evaluate(interpreter, ds, "(if 1.0 2.0 3.0)", 0, 2.0);
694      Evaluate(interpreter, ds, "(if 10.0 2.0 3.0)", 0, 2.0);
695      Evaluate(interpreter, ds, "(if (log -1.0) 2.0 3.0)", 0, 3.0); // if(nan) should return the else branch
696
697      // NOT
698      Evaluate(interpreter, ds, "(not -1.0)", 0, 1.0);
699      Evaluate(interpreter, ds, "(not -2.0)", 0, 1.0);
700      Evaluate(interpreter, ds, "(not 1.0)", 0, -1.0);
701      Evaluate(interpreter, ds, "(not 2.0)", 0, -1.0);
702      Evaluate(interpreter, ds, "(not 0.0)", 0, 1.0);
703      Evaluate(interpreter, ds, "(not (log -1.0))", 0, 1.0);
704
705      // AND
706      Evaluate(interpreter, ds, "(and -1.0 -2.0)", 0, -1.0);
707      Evaluate(interpreter, ds, "(and -1.0 2.0)", 0, -1.0);
708      Evaluate(interpreter, ds, "(and 1.0 -2.0)", 0, -1.0);
709      Evaluate(interpreter, ds, "(and 1.0 0.0)", 0, -1.0);
710      Evaluate(interpreter, ds, "(and 0.0 0.0)", 0, -1.0);
711      Evaluate(interpreter, ds, "(and 1.0 2.0)", 0, 1.0);
712      Evaluate(interpreter, ds, "(and 1.0 2.0 3.0)", 0, 1.0);
713      Evaluate(interpreter, ds, "(and 1.0 -2.0 3.0)", 0, -1.0);
714      Evaluate(interpreter, ds, "(and (log -1.0))", 0, -1.0); // (and NaN)
715      Evaluate(interpreter, ds, "(and (log -1.0)  1.0)", 0, -1.0); // (and NaN 1.0)
716
717      // OR
718      Evaluate(interpreter, ds, "(or -1.0 -2.0)", 0, -1.0);
719      Evaluate(interpreter, ds, "(or -1.0 2.0)", 0, 1.0);
720      Evaluate(interpreter, ds, "(or 1.0 -2.0)", 0, 1.0);
721      Evaluate(interpreter, ds, "(or 1.0 2.0)", 0, 1.0);
722      Evaluate(interpreter, ds, "(or 0.0 0.0)", 0, -1.0);
723      Evaluate(interpreter, ds, "(or -1.0 -2.0 -3.0)", 0, -1.0);
724      Evaluate(interpreter, ds, "(or -1.0 -2.0 3.0)", 0, 1.0);
725      Evaluate(interpreter, ds, "(or (log -1.0))", 0, -1.0); // (or NaN)
726      Evaluate(interpreter, ds, "(or (log -1.0)  1.0)", 0, -1.0); // (or NaN 1.0)
727
728      // XOR
729      Evaluate(interpreter, ds, "(xor -1.0 -2.0)", 0, -1.0);
730      Evaluate(interpreter, ds, "(xor -1.0 2.0)", 0, 1.0);
731      Evaluate(interpreter, ds, "(xor 1.0 -2.0)", 0, 1.0);
732      Evaluate(interpreter, ds, "(xor 1.0 2.0)", 0, -1.0);
733      Evaluate(interpreter, ds, "(xor 0.0 0.0)", 0, -1.0);
734      Evaluate(interpreter, ds, "(xor -1.0 -2.0 -3.0)", 0, -1.0);
735      Evaluate(interpreter, ds, "(xor -1.0 -2.0 3.0)", 0, 1.0);
736      Evaluate(interpreter, ds, "(xor -1.0 2.0 3.0)", 0, -1.0);
737      Evaluate(interpreter, ds, "(xor 1.0 2.0 3.0)", 0, 1.0);
738      Evaluate(interpreter, ds, "(xor (log -1.0))", 0, -1.0);
739      Evaluate(interpreter, ds, "(xor (log -1.0)  1.0)", 0, 1.0);
740
741      // sin, cos, tan
742      Evaluate(interpreter, ds, "(sin " + Math.PI.ToString(NumberFormatInfo.InvariantInfo) + ")", 0, 0.0);
743      Evaluate(interpreter, ds, "(sin 0.0)", 0, 0.0);
744      Evaluate(interpreter, ds, "(cos " + Math.PI.ToString(NumberFormatInfo.InvariantInfo) + ")", 0, -1.0);
745      Evaluate(interpreter, ds, "(cos 0.0)", 0, 1.0);
746      Evaluate(interpreter, ds, "(tan " + Math.PI.ToString(NumberFormatInfo.InvariantInfo) + ")", 0, Math.Tan(Math.PI));
747      Evaluate(interpreter, ds, "(tan 0.0)", 0, Math.Tan(Math.PI));
748
749      // exp, log
750      Evaluate(interpreter, ds, "(log (exp 7.0))", 0, Math.Log(Math.Exp(7)));
751      Evaluate(interpreter, ds, "(exp (log 7.0))", 0, Math.Exp(Math.Log(7)));
752      Evaluate(interpreter, ds, "(log -3.0)", 0, Math.Log(-3));
753
754      // power
755      Evaluate(interpreter, ds, "(pow 2.0 3.0)", 0, 8.0);
756      Evaluate(interpreter, ds, "(pow 4.0 0.5)", 0, 1.0); // interpreter should round to the nearest integer value value (.5 is rounded to the even number)
757      Evaluate(interpreter, ds, "(pow 4.0 2.5)", 0, 16.0); // interpreter should round to the nearest integer value value (.5 is rounded to the even number)
758      Evaluate(interpreter, ds, "(pow -2.0 3.0)", 0, -8.0);
759      Evaluate(interpreter, ds, "(pow 2.0 -3.0)", 0, 1.0 / 8.0);
760      Evaluate(interpreter, ds, "(pow -2.0 -3.0)", 0, -1.0 / 8.0);
761
762      // root
763      Evaluate(interpreter, ds, "(root 9.0 2.0)", 0, 3.0);
764      Evaluate(interpreter, ds, "(root 27.0 3.0)", 0, 3.0);
765      Evaluate(interpreter, ds, "(root 2.0 -3.0)", 0, Math.Pow(2.0, -1.0 / 3.0));
766
767      // mean
768      Evaluate(interpreter, ds, "(mean -1.0 1.0 -1.0)", 0, -1.0 / 3.0);
769    }
770
771    private void Evaluate(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataset ds, string expr, int index, double expected) {
772      var importer = new SymbolicExpressionImporter();
773      ISymbolicExpressionTree tree = importer.Import(expr);
774
775      double actual = interpreter.GetSymbolicExpressionTreeValues(tree, ds, Enumerable.Range(index, 1)).First();
776
777      Assert.IsFalse(double.IsNaN(actual) && !double.IsNaN(expected));
778      Assert.IsFalse(!double.IsNaN(actual) && double.IsNaN(expected));
779      if (!double.IsNaN(actual) && !double.IsNaN(expected))
780        Assert.AreEqual(expected, actual, 1.0E-12, expr);
781    }
782  }
783}
Note: See TracBrowser for help on using the repository browser.