Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2994-AutoDiffForIntervals/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis-3.4/AutoDiffInterpreterTest.cs @ 17308

Last change on this file since 17308 was 17308, checked in by gkronber, 5 years ago

#2994 fix compile error

File size: 7.2 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Linq;
4using System.Text;
5using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
6using HeuristicLab.Problems.DataAnalysis.Symbolic;
7using HeuristicLab.Random;
8using Microsoft.VisualStudio.TestTools.UnitTesting;
9
10namespace HeuristicLab.Problems.DataAnalysis.Tests {
11  [TestClass]
12  public class AutoDiffInterpreterTest {
13
14    [TestMethod]
15    [TestCategory("Problems.DataAnalysis")]
16    [TestProperty("Time", "short")]
17    public void TestAutoDiffUsingNumericDifferences() {
18
19      // create random trees and evaluate on random data
20      // calc gradient for all parameters
21      // use numeric differences for approximate gradient calculation
22      // compare gradients
23
24      var grammar = new TypeCoherentExpressionGrammar();
25      grammar.ConfigureAsDefaultRegressionGrammar();
26      // activate supported symbols
27      grammar.Symbols.First(s => s is Square).Enabled = true;
28      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
29      grammar.Symbols.First(s => s is Cube).Enabled = true;
30      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
31      grammar.Symbols.First(s => s is Sine).Enabled = true;
32      grammar.Symbols.First(s => s is Cosine).Enabled = true;
33      grammar.Symbols.First(s => s is Exponential).Enabled = true;
34      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
35      grammar.Symbols.First(s => s is Absolute).Enabled = false; // XXX not yet supported by old interval calculator
36      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator
37
38      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
39      varSy.AllVariableNames = new string[] { "x", "y" };
40      varSy.VariableNames = varSy.AllVariableNames;
41      varSy.WeightMu = 1.0;
42      varSy.WeightSigma = 1.0;
43      var rand = new FastRandom(1234);
44
45      // random data
46      var values = new double[100, 2];
47      for (int i = 0; i < 100; i++)
48        for (int j = 0; j < 2; j++) {
49          values[i, j] = rand.NextDouble() * 2 - 1;
50        }
51      var ds = new Dataset(varSy.AllVariableNames, values);
52      // buffers
53      var fi = new double[100];
54      var rows = Enumerable.Range(0, 100).ToArray();
55
56      var eval = new VectorAutoDiffEvaluator();
57      var refEval = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
58
59      var formatter = new InfixExpressionFormatter();
60      var sb = new StringBuilder();
61      int N = 10000;
62      int iter = 0;
63      while (iter < N) {
64        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
65        var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
66
67        var jac = new double[100, parameterNodes.Length];
68
69        eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);
70
71        var refJac = ApproximateGradient(t, ds, rows, parameterNodes, refEval);
72
73        for (int k = 0; k < rows.Length; k++) {
74          if (double.IsNaN(fi[k]) || double.IsInfinity(fi[k])) continue; // skip outputs where we expect problematic gradients
75
76          // check partial derivatives
77          for (int p = 0; p < parameterNodes.Length; p++) {
78            if (double.IsNaN(jac[k, p]) && double.IsNaN(refJac[k, p])) continue; // both NaN
79            if (jac[k, p] == refJac[k, p]) continue; // equal
80            if (Math.Abs(jac[k, p]) <= 1e-12 && Math.Abs(refJac[k, p]) <= 1e-12) continue; // both very small
81
82            // check relative error using the larger value as reference
83            var refVal = Math.Max(Math.Abs(jac[k, p]), Math.Abs(refJac[k, p]));
84            if (Math.Abs(jac[k, p] - refJac[k, p]) > refVal * 1e-4)
85              sb.AppendLine($"{jac[k, p]} <> {refJac[k, p]} for {parameterNodes[p]} in {formatter.Format(t)} x={ds.GetDoubleValue("x", k)} y={ds.GetDoubleValue("y", k)}");
86          }
87        }
88
89        iter++;
90      }
91      if (sb.Length > 0) {
92        Console.WriteLine(sb.ToString());
93        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
94      }
95    }
96
97
98    [TestMethod]
99    [TestCategory("Problems.DataAnalysis")]
100    [TestProperty("Time", "short")]
101    public void TestVectorAutoDiffInterpreter() {
102      var ds = new Dataset(new string[] { "x", "y" }, new double[,] { { 1, 0 }, { 2, 1 } });
103
104      Assert.AreEqual(0.25, CalculateGradient("sqrt(4)", ds)[0]);
105      Assert.AreEqual((1.0 / 3.0) * (1.0 / 4.0) , CalculateGradient("cuberoot(8)", ds)[0]);
106      Assert.AreEqual((1.0 / 4.0), CalculateGradient("1.0 / 4.0", ds)[0]);
107      Assert.AreEqual(-1.0 / 16.0, CalculateGradient("1.0 / 4.0", ds)[1]);
108      Assert.AreEqual(1.0 / 16.0, CalculateGradient("1.0 / (-4.0)", ds)[1]);
109    }
110
111    #region helper
112
113    private double[] CalculateGradient(string expr, IDataset ds) {
114      var eval = new VectorAutoDiffEvaluator();
115      var parser = new InfixExpressionParser();
116
117      var rows = new int[1];
118      var fi = new double[1];
119
120      var t = parser.Parse(expr);
121      var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
122      var jac = new double[1, parameterNodes.Length];
123      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);
124
125      var g = new double[parameterNodes.Length];
126      for (int i = 0; i < g.Length; i++) g[i] = jac[0, i];
127      return g;
128    }
129
130
131    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
132      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
133      var jac = new double[rows.Length, parameterNodes.Length];
134      for (int p = 0; p < parameterNodes.Length; p++) {
135
136        var x = GetValue(parameterNodes[p]);
137        var x_diff = x * 1e-4; // relative change
138
139        // calculate output for increased parameter value
140        SetValue(parameterNodes[p], x + x_diff / 2);
141        var f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
142        for (int i = 0; i < rows.Length; i++) {
143          jac[i, p] = f[i];
144        }
145
146        // calculate output for decreased parameter value
147        SetValue(parameterNodes[p], x - x_diff / 2);
148        f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
149        for (int i = 0; i < rows.Length; i++) {
150          jac[i, p] -= f[i]; // calc difference (and scale for x_diff)
151          jac[i, p] /= x_diff;
152        }
153
154        // restore original value
155        SetValue(parameterNodes[p], x);
156      }
157      return jac;
158    }
159
160    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
161      var varNode = node as VariableTreeNode;
162      var constNode = node as ConstantTreeNode;
163      if (varNode != null) varNode.Weight = v;
164      else if (constNode != null) constNode.Value = v;
165      else throw new InvalidProgramException();
166    }
167
168    private double GetValue(ISymbolicExpressionTreeNode node) {
169      var varNode = node as VariableTreeNode;
170      var constNode = node as ConstantTreeNode;
171      if (varNode != null) return varNode.Weight;
172      else if (constNode != null) return constNode.Value;
173      throw new InvalidProgramException();
174    }
175    #endregion
176  }
177}
Note: See TracBrowser for help on using the repository browser.