1 | using System;
|
---|
2 | using System.Collections.Generic;
|
---|
3 | using System.Linq;
|
---|
4 | using System.Text;
|
---|
5 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
|
---|
6 | using HeuristicLab.Problems.DataAnalysis.Symbolic;
|
---|
7 | using HeuristicLab.Random;
|
---|
8 | using Microsoft.VisualStudio.TestTools.UnitTesting;
|
---|
9 |
|
---|
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Tests {
  [TestClass]
  public class VectorAutoDiffEvaluatorTest {
    // Relative step used by the central finite-difference approximations (step = x * RelativeStepSize).
    private const double RelativeStepSize = 1e-4;

    // Absolute step used when the relative step would be zero (parameter value 0 or underflow);
    // a zero step would turn the finite difference into 0/0 = NaN.
    private const double MinimumStepSize = 1e-4;

    // Absolute tolerance for the hand-computed gradients in VectorAutoDiffEvaluatorExamplesTest.
    private const double ExampleTolerance = 1e-12;

    /// <summary>
    /// Compares the gradients produced by <see cref="VectorAutoDiffEvaluator"/> with a central
    /// finite-difference approximation computed via
    /// <see cref="SymbolicDataAnalysisExpressionTreeLinearInterpreter"/> for many random trees.
    /// All mismatches are collected and reported together at the end.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void VectorAutoDiffEvaluatorCompareWithNumericDifferencesTest() {

      // create random trees and evaluate on random data
      // calc gradient for all parameters
      // use numeric differences for approximate gradient calculation
      // compare gradients

      var grammar = new TypeCoherentExpressionGrammar();
      grammar.ConfigureAsDefaultRegressionGrammar();
      // activate supported symbols
      grammar.Symbols.First(s => s is Square).Enabled = true;
      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
      grammar.Symbols.First(s => s is Cube).Enabled = true;
      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
      grammar.Symbols.First(s => s is Sine).Enabled = true;
      grammar.Symbols.First(s => s is Cosine).Enabled = true;
      grammar.Symbols.First(s => s is Exponential).Enabled = true;
      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
      grammar.Symbols.First(s => s is Absolute).Enabled = true;
      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = false; // not yet supported by old interval calculator

      var varSy = (Variable)grammar.Symbols.First(s => s is Variable);
      varSy.AllVariableNames = new string[] { "x", "y" };
      varSy.VariableNames = varSy.AllVariableNames;
      varSy.WeightMu = 1.0;
      varSy.WeightSigma = 1.0;
      var rand = new FastRandom(1234);

      // random data (two columns: x and y)
      const int numRows = 100;
      var values = new double[numRows, 2];
      for (int i = 0; i < numRows; i++) {
        for (int j = 0; j < 2; j++) {
          values[i, j] = rand.NextDouble() * 2 - 1;
        }
      }
      var ds = new Dataset(varSy.AllVariableNames, values);
      // buffers
      var fi = new double[numRows];
      var rows = Enumerable.Range(0, numRows).ToArray();

      var eval = new VectorAutoDiffEvaluator();
      var refEval = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();

      var formatter = new InfixExpressionFormatter();
      var sb = new StringBuilder();
      const int N = 10000;
      for (int iter = 0; iter < N; iter++) {
        var t = ProbabilisticTreeCreator.Create(rand, grammar, maxTreeLength: 5, maxTreeDepth: 5);
        // leaf nodes are used as parameters (SetValue/GetValue support variable and constant nodes)
        var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();

        var jac = new double[numRows, parameterNodes.Length];

        eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);

        var refJac = ApproximateGradient(t, ds, rows, parameterNodes, refEval);

        for (int k = 0; k < rows.Length; k++) {
          if (double.IsNaN(fi[k]) || double.IsInfinity(fi[k])) continue; // skip outputs where we expect problematic gradients

          // check partial derivatives
          for (int p = 0; p < parameterNodes.Length; p++) {
            if (double.IsNaN(jac[k, p]) && double.IsNaN(refJac[k, p])) continue; // both NaN
            if (jac[k, p] == refJac[k, p]) continue; // equal
            if (Math.Abs(jac[k, p]) <= 1e-12 && Math.Abs(refJac[k, p]) <= 1e-12) continue; // both very small

            // check relative error using the larger value as reference
            // NOTE: if exactly one of the two values is NaN, the comparison below evaluates to false,
            // so such an entry is not reported.
            var refVal = Math.Max(Math.Abs(jac[k, p]), Math.Abs(refJac[k, p]));
            if (Math.Abs(jac[k, p] - refJac[k, p]) > refVal * 1e-4) {
              sb.AppendLine($"{jac[k, p]} <> {refJac[k, p]} for {parameterNodes[p]} in {formatter.Format(t)} x={ds.GetDoubleValue("x", k)} y={ds.GetDoubleValue("y", k)}");
            }
          }
        }
      }
      if (sb.Length > 0) {
        Console.WriteLine(sb.ToString());
        Assert.Fail("There were differences when validating AutoDiff using numeric differences");
      }
    }


    /// <summary>
    /// Checks <see cref="VectorAutoDiffEvaluator"/> gradients of simple expressions against
    /// analytically derived values. Comparisons use a small absolute tolerance because the
    /// evaluator's floating-point operations need not reproduce the literal expected values bit-for-bit.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis")]
    [TestProperty("Time", "short")]
    public void VectorAutoDiffEvaluatorExamplesTest() {
      var ds = new Dataset(new string[] { "x", "y" }, new double[,] { { 1, 0 }, { 2, 1 } });

      // d/dc sqrt(c) = 1 / (2 * sqrt(c)) = 1/4 at c = 4
      Assert.AreEqual(0.25, CalculateGradient("sqrt(4)", ds)[0], ExampleTolerance);
      // d/dc cbrt(c) = 1 / (3 * cbrt(c)^2) = 1/12 at c = 8
      Assert.AreEqual((1.0 / 3.0) * (1.0 / 4.0), CalculateGradient("cuberoot(8)", ds)[0], ExampleTolerance);

      // f = a / b with a = 1, b = 4: df/da = 1/b, df/db = -a/b^2
      var quotientGradient = CalculateGradient("1.0 / 4.0", ds);
      Assert.AreEqual((1.0 / 4.0), quotientGradient[0], ExampleTolerance);
      Assert.AreEqual(-1.0 / 16.0, quotientGradient[1], ExampleTolerance);

      Assert.AreEqual(1.0 / 16.0, CalculateGradient("1.0 / (-4.0)", ds)[1], ExampleTolerance);
    }


    /// <summary>
    /// Checks that <see cref="SymbolicDataAnalysisExpressionTreeLinearInterpreter"/>,
    /// <see cref="VectorEvaluator"/> and <see cref="VectorAutoDiffEvaluator"/> (without parameters)
    /// produce the same values (within a relative tolerance of 1e-5) for many random trees.
    /// </summary>
    [TestMethod]
    [TestCategory("Problems.DataAnalysis.Symbolic")]
    [TestProperty("Time", "long")]
    public void VectorEvaluatorsEstimatedValuesConsistencyTest() {
      var twister = new MersenneTwister();
      twister.Seed(31415);
      const int numRows = 100;
      const int Columns = 50;
      const int N = 10000;
      var dataset = Util.CreateRandomDataset(twister, numRows, Columns);

      var grammar = new TypeCoherentExpressionGrammar();
      grammar.ConfigureAsDefaultRegressionGrammar();
      grammar.Symbols.First(s => s is Square).Enabled = true;
      grammar.Symbols.First(s => s is SquareRoot).Enabled = true;
      grammar.Symbols.First(s => s is Cube).Enabled = true;
      grammar.Symbols.First(s => s is CubeRoot).Enabled = true;
      grammar.Symbols.First(s => s is Exponential).Enabled = true;
      grammar.Symbols.First(s => s is Logarithm).Enabled = true;
      grammar.Symbols.First(s => s is Sine).Enabled = true;
      grammar.Symbols.First(s => s is Cosine).Enabled = true;
      grammar.Symbols.First(s => s is Absolute).Enabled = true;
      grammar.Symbols.First(s => s is AnalyticQuotient).Enabled = true;

      var refInterpreter = new SymbolicDataAnalysisExpressionTreeLinearInterpreter();
      var newEvaluator = new VectorEvaluator();
      var newAutoDiffEvaluator = new VectorAutoDiffEvaluator();
      var formatter = new InfixExpressionFormatter();

      var rows = Enumerable.Range(0, numRows).ToList();
      // loop-invariant inputs, created once instead of once per tree
      var rowArray = rows.ToArray();
      var noParameters = new ISymbolicExpressionTreeNode[0]; // only values are compared, no gradient is requested

      var randomTrees = Util.CreateRandomTrees(twister, dataset, grammar, N, 1, 10, 0, 0);
      foreach (ISymbolicExpressionTree tree in randomTrees) {
        Util.InitTree(tree, twister, new List<string>(dataset.VariableNames));
      }

      for (int i = 0; i < randomTrees.Length; ++i) {
        var tree = randomTrees[i];
        var refValues = refInterpreter.GetSymbolicExpressionTreeValues(tree, dataset, rows).ToArray();
        var newValues = newEvaluator.Evaluate(tree, dataset, rowArray).ToArray();
        var newAutoDiffValues = new double[numRows];
        newAutoDiffEvaluator.Evaluate(tree, dataset, rowArray, noParameters, newAutoDiffValues, null);

        // the message only depends on the tree, so format it once per tree rather than once per row
        string errorMessage = string.Format("Interpreters do not agree on tree {0} {1}.", i, formatter.Format(tree));

        for (int j = 0; j < rows.Count; j++) {
          if (double.IsNaN(refValues[j]) && double.IsNaN(newValues[j]) && double.IsNaN(newAutoDiffValues[j])) continue;

          var relDelta = Math.Abs(refValues[j]) * 1e-5;
          Assert.AreEqual(refValues[j], newValues[j], relDelta, errorMessage);
          Assert.AreEqual(newValues[j], newAutoDiffValues[j], relDelta, errorMessage);
        }
      }
    }

    #region helper

    // Parses expr, treats every leaf node as a parameter, evaluates the tree for row 0 of ds
    // with VectorAutoDiffEvaluator and returns the gradient with respect to the leaves
    // (in postfix order).
    private double[] CalculateGradient(string expr, IDataset ds) {
      var eval = new VectorAutoDiffEvaluator();
      var parser = new InfixExpressionParser();

      var rows = new int[1]; // evaluates only row 0
      var fi = new double[1];

      var t = parser.Parse(expr);
      var parameterNodes = t.IterateNodesPostfix().Where(n => n.SubtreeCount == 0).ToArray();
      var jac = new double[1, parameterNodes.Length];
      eval.Evaluate(t, ds, rows, parameterNodes, fi, jac);

      var g = new double[parameterNodes.Length];
      for (int i = 0; i < g.Length; i++) g[i] = jac[0, i];
      return g;
    }

    // Returns the central-difference step for a parameter with value x: x * RelativeStepSize,
    // or MinimumStepSize when that product is zero (x == 0 or underflow), because a zero step
    // would make the difference quotient 0/0 = NaN.
    private static double GetStepSize(double x) {
      var step = x * RelativeStepSize;
      return step == 0.0 ? MinimumStepSize : step;
    }

    // Approximates d f(row) / d parameter for every row and parameter with central differences:
    // (f(x + h/2) - f(x - h/2)) / h, where h is given by GetStepSize.
    // Each parameter value is temporarily modified and then restored.
    // Returns a [rows.Length, parameterNodes.Length] matrix.
    private double[,] ApproximateGradient(ISymbolicExpressionTree t, Dataset ds, int[] rows, ISymbolicExpressionTreeNode[] parameterNodes,
      SymbolicDataAnalysisExpressionTreeLinearInterpreter eval) {
      var jac = new double[rows.Length, parameterNodes.Length];
      for (int p = 0; p < parameterNodes.Length; p++) {

        var x = GetValue(parameterNodes[p]);
        var x_diff = GetStepSize(x);

        // calculate output for increased parameter value
        SetValue(parameterNodes[p], x + x_diff / 2);
        var f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
        for (int i = 0; i < rows.Length; i++) {
          jac[i, p] = f[i];
        }

        // calculate output for decreased parameter value
        SetValue(parameterNodes[p], x - x_diff / 2);
        f = eval.GetSymbolicExpressionTreeValues(t, ds, rows).ToArray();
        for (int i = 0; i < rows.Length; i++) {
          jac[i, p] -= f[i]; // calc difference (and scale for x_diff)
          jac[i, p] /= x_diff;
        }

        // restore original value
        SetValue(parameterNodes[p], x);
      }
      return jac;
    }

    // Central-difference approximation of the derivatives of the lower and upper bound of the
    // interval returned by eval.Evaluate(t, intervals) with respect to each parameter node.
    // Each parameter value is temporarily modified and then restored.
    private void ApproximateIntervalGradient(ISymbolicExpressionTree t, Dictionary<string, Interval> intervals, ISymbolicExpressionTreeNode[] parameterNodes, IntervalEvaluator eval, out double[] lowerGradient, out double[] upperGradient) {
      lowerGradient = new double[parameterNodes.Length];
      upperGradient = new double[parameterNodes.Length];

      for (int p = 0; p < parameterNodes.Length; p++) {
        var x = GetValue(parameterNodes[p]);
        var x_diff = GetStepSize(x);

        // calculate output for increased parameter value
        SetValue(parameterNodes[p], x + x_diff / 2);
        var r1 = eval.Evaluate(t, intervals);
        lowerGradient[p] = r1.LowerBound;
        upperGradient[p] = r1.UpperBound;

        // calculate output for decreased parameter value
        SetValue(parameterNodes[p], x - x_diff / 2);
        var r2 = eval.Evaluate(t, intervals);
        lowerGradient[p] -= r2.LowerBound;
        upperGradient[p] -= r2.UpperBound;

        lowerGradient[p] /= x_diff;
        upperGradient[p] /= x_diff;

        // restore original value
        SetValue(parameterNodes[p], x);
      }
    }

    // Sets the parameter value of a leaf: the weight of a variable node or the value of a constant node.
    // Throws InvalidProgramException for any other node type.
    private void SetValue(ISymbolicExpressionTreeNode node, double v) {
      var varNode = node as VariableTreeNode;
      var constNode = node as ConstantTreeNode;
      if (varNode != null) varNode.Weight = v;
      else if (constNode != null) constNode.Value = v;
      else throw new InvalidProgramException();
    }

    // Reads the parameter value of a leaf: the weight of a variable node or the value of a constant node.
    // Throws InvalidProgramException for any other node type.
    private double GetValue(ISymbolicExpressionTreeNode node) {
      var varNode = node as VariableTreeNode;
      var constNode = node as ConstantTreeNode;
      if (varNode != null) return varNode.Weight;
      else if (constNode != null) return constNode.Value;
      throw new InvalidProgramException();
    }
    #endregion
  }
}
|
---|