source: branches/2886_SymRegGrammarEnumeration/Test/TreeHashingTest.cs @ 16193

Last change on this file since 16193 was 16193, checked in by bburlacu, 23 months ago

#2886: Implement new hasher (faster & less collision prone) and update unit tests

File size: 15.5 KB
Line 
1using System;
2using System.Collections.Generic;
3using System.Diagnostics;
4using System.Linq;
5using System.Security.Cryptography;
6using HeuristicLab.Algorithms.DataAnalysis.SymRegGrammarEnumeration;
7using Microsoft.VisualStudio.TestTools.UnitTesting;
8using HierarchicalFormatter = HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.SymbolicExpressionTreeHierarchicalFormatter;
9
namespace Test {
  /// <summary>
  /// Unit tests for the hashing of grammar sentences. Sentences are given as symbol lists in
  /// postfix order (operands first, then the operator). The tests check that sentences which
  /// are expected to be equivalent (commutativity, associativity, repeated terms, cancelled
  /// inverse factors, merged constants) receive the same hash, and that sentences which are
  /// expected to differ receive different hashes.
  /// </summary>
  [TestClass]
  public class TreeHashingTest {

    // Grammar over the variables "a", "b" and "c"; recreated before every test by InitTest.
    private Grammar grammar;
    private TerminalSymbol varA;
    private TerminalSymbol varB;
    private TerminalSymbol varC;
    // The constant terminal of the grammar.
    private TerminalSymbol c;

    // Hash function under test for the equality/inequality tests; no hash algorithm is passed (null).
    Func<Grammar, SymbolList, int> ComputeHash = (grammar, sentence) => grammar.ComputeHash(null, sentence);

    /// <summary>
    /// Creates a fresh grammar and looks up the terminals used by the tests.
    /// </summary>
    [TestInitialize]
    public void InitTest() {
      grammar = new Grammar(new[] { "a", "b", "c" });

      varA = grammar.VarTerminals.First(s => s.StringRepresentation == "a");
      varB = grammar.VarTerminals.First(s => s.StringRepresentation == "b");
      varC = grammar.VarTerminals.First(s => s.StringRepresentation == "c");
      c = grammar.Const;
    }

    /// <summary>
    /// Two identical sentences must produce the same hash.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void SimpleEqualityAddition() {
      // (a + b) + c, twice
      SymbolList s1 = new SymbolList(new[] { varA, varB, grammar.Addition, varC, grammar.Addition });
      SymbolList s2 = new SymbolList(new[] { varA, varB, grammar.Addition, varC, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// Sums over different variables must produce different hashes.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void SimpleInequalityAddition() {
      // (a + b) + c
      SymbolList s1 = new SymbolList(new[] { varA, varB, grammar.Addition, varC, grammar.Addition });
      // (b + b) + b
      SymbolList s2 = new SymbolList(new[] { varB, varB, grammar.Addition, varB, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreNotEqual(hash1, hash2);
    }

    /// <summary>
    /// The operand order of an addition must not influence the hash.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void CommutativityAddition() {
      // a + b
      SymbolList s1 = new SymbolList(new[] { varA, varB, grammar.Addition });
      // b + a
      SymbolList s2 = new SymbolList(new[] { varB, varA, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// The grouping of nested additions must not influence the hash.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void AssociativityAddition() {
      // (a + b) + a
      SymbolList s1 = new SymbolList(new[] { varA, varB, grammar.Addition, varA, grammar.Addition });
      // a + (b + a)
      SymbolList s2 = new SymbolList(new[] { varA, varB, varA, grammar.Addition, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// A sum that repeats the same variable must hash like the variable alone.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void RepeatedAddition() {
      // (a + a) + a
      SymbolList s1 = new SymbolList(new[] { varA, varA, grammar.Addition, varA, grammar.Addition });
      // a
      SymbolList s2 = new SymbolList(new[] { varA });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// Sentences that differ only in their outermost operator must produce different hashes.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void ComplexInequality() {
      // a * (a * a)
      SymbolList s1 = new SymbolList(new[] { varA, varA, varA, grammar.Multiplication, grammar.Multiplication });
      // a + (a * a)
      SymbolList s2 = new SymbolList(new[] { varA, varA, varA, grammar.Multiplication, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreNotEqual(hash1, hash2);
    }

    /// <summary>
    /// Phrases that still contain a nonterminal can be hashed; a repeated variable term next to
    /// the nonterminal must not change the hash.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void NonterminalHashing() {
      // a + (a + Expr)
      SymbolList s1 = new SymbolList(new Symbol[] { varA, varA, grammar.Expr, grammar.Addition, grammar.Addition });
      // a + Expr
      SymbolList s2 = new SymbolList(new Symbol[] { varA, grammar.Expr, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// An inverse factor must cancel against a matching factor of the same product.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void InverseFactorCancelationSimple() {
      // 1/a * b * a * a
      SymbolList s1 = new SymbolList(new Symbol[] { varA, grammar.Inv, varB, grammar.Multiplication, varA, grammar.Multiplication, varA, grammar.Multiplication });
      // a * b
      SymbolList s2 = new SymbolList(new Symbol[] { varA, varB, grammar.Multiplication });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// Inverse factor cancellation inside a larger expression. Both sentences, their trees,
    /// their simplified forms and their hashes are written to the console to ease debugging.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void InverseFactorCancelationComplex() {
      // sin(a) + sin(a * a + a)
      SymbolList s1 = new SymbolList(new Symbol[] { varA, grammar.Sin, varA, varA, grammar.Multiplication, varA, grammar.Addition, grammar.Sin, grammar.Addition });
      // sin(a + a * a) + 1/a * (sin(a) * a)
      SymbolList s2 = new SymbolList(new Symbol[] { varA, varA, varA, grammar.Multiplication, grammar.Addition, grammar.Sin, varA, grammar.Inv, varA, grammar.Sin, varA, grammar.Multiplication, grammar.Multiplication, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Console.WriteLine(s1);
      Console.WriteLine(PrintTree(s1));
      Console.WriteLine(grammar.Simplify(null, s1));
      Console.WriteLine(hash1);
      Console.WriteLine();
      Console.WriteLine(s2);
      Console.WriteLine(PrintTree(s2));
      Console.WriteLine(grammar.Simplify(null, s2));
      Console.WriteLine(hash2);


      Assert.AreEqual(hash1, hash2);
    }

    // Constants
    /// <summary>
    /// Terms that differ only by constants must hash equally.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void SimpleConst() {
      // c * a + c
      SymbolList s1 = new SymbolList(new Symbol[] { c, varA, grammar.Multiplication, c, grammar.Addition });
      // (c * a + c * a) + c
      SymbolList s2 = new SymbolList(new Symbol[] { c, varA, grammar.Multiplication, c, varA, grammar.Multiplication, grammar.Addition, c, grammar.Addition });

      int hash1 = ComputeHash(grammar, s1);
      int hash2 = ComputeHash(grammar, s2);

      Assert.AreEqual(hash1, hash2);
    }

    /// <summary>
    /// Enumerates the full grammar over two variables depth-first, once without and once with
    /// pruning of already visited phrases, and reports how many sentences were generated and
    /// how many of them are distinct according to <c>grammar.Hasher.CalcHashCode</c>.
    /// This test only writes to the console; it contains no assertions.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void EnumerateGrammarTest() {
      //const int nvars = 1;
      //var variables = Enumerable.Range(1, nvars).Select(x => $"x{x}").ToArray();
      var variables = new[] { "a", "b" };
      var grammar = new Grammar(variables, Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>());

      Func<SymbolList, int> hash = s => grammar.Hasher.CalcHashCode(s);
      int length = 100;
      int depth = 20;

      //List<SymbolList> sentences = EnumerateGrammarBreadth(grammar, length: length, hashPhrases: false).ToList();
      //Console.WriteLine("Breadth: {0} generated, {1} distinct sentences", sentences.Count, sentences.GroupBy(hash).Count());

      //var sentences = EnumerateGrammarBreadth(grammar, length: length, hashPhrases: true).ToList();
      //Console.WriteLine("Breadth (hashed): {0} generated, {1} distinct sentences", sentences.Count, sentences.GroupBy(hash).Count());

      var sentences = EnumerateGrammarDepth(grammar, length: length, depth: depth, hashPhrases: false).ToList();
      Console.WriteLine("Depth: {0} generated, {1} distinct sentences", sentences.Count, sentences.GroupBy(hash).Count());

      sentences = EnumerateGrammarDepth(grammar, length: length, depth: depth, hashPhrases: true).ToList();
      Console.WriteLine("Depth (hashed): {0} generated, {1} distinct sentences", sentences.Count, sentences.GroupBy(hash).Count());
    }

    /// <summary>
    /// Simplifies and hashes a fixed list of sentences with both hash implementations and
    /// writes the results to the console. This test contains no assertions.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void HashExtensionsTest() {
      var variables = new[] { "x", "y", "z" };
      var rules = Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>();
      var grammar = new Grammar(variables, rules);
      var add = grammar.Addition;
      var mul = grammar.Multiplication;
      var exp = grammar.Exp;
      var log = grammar.Log;
      var inv = grammar.Inv;
      var sin = grammar.Sin;

      var c = grammar.Const;
      var x = grammar.VarTerminals.Single(v => v.StringRepresentation == "x");
      var y = grammar.VarTerminals.Single(v => v.StringRepresentation == "y");
      var z = grammar.VarTerminals.Single(v => v.StringRepresentation == "z");

      var sentences = new[] {
        new SymbolList(c, c, x, mul, c, add, inv, mul, c, add),
        new SymbolList(c, c, x, mul, c, add, log, mul, c, add),
        new SymbolList(x, x, add),
        new SymbolList(x, x, add, x, add),
        new SymbolList(x, x, add, y, add),
        new SymbolList(x, x, add, x, add, x, add),
        new SymbolList(x, x, add, y, add, x, add),
        new SymbolList(x, y, mul, x, y, mul, add),
        new SymbolList(x, y, mul, y, x, mul, add),
        new SymbolList(x, y, mul, z, y, mul, add),
        new SymbolList(x, x, add, y, mul),
        new SymbolList(x, x, add, y, mul, y, mul),
        new SymbolList(x, x, inv, mul),
        new SymbolList(x, inv, x, inv, mul, x, mul, x, mul),
        new SymbolList(c, x, x, x, x, x, x, mul, mul, mul, mul, mul, mul, c, x, mul, c, add, add),
        new SymbolList(c, x, x, x, x, x, x, mul, mul, mul, mul, mul, mul, c, x, mul, c, add, add),
        new SymbolList(c, x, mul, c, x, x, x, x, x, x, mul, mul, mul, mul, mul, mul, c, add, add),
        new SymbolList(c, x, x, x, x, mul, mul, mul, mul, c, x, mul, c, x, mul, c, add, add, add),
        new SymbolList(c, x, mul, c, x, x, x, x, mul, mul, mul, mul, c, x, mul, c, add, add, add)
      };

      // HashAlgorithm is IDisposable; release it as soon as all sentences are processed.
      using (var ha = SHA512.Create()) {
        foreach (var sentence in sentences) {
          var simplified = grammar.Simplify(ha, sentence);
          Console.WriteLine($"{sentence} -> {simplified} {grammar.Hasher.CalcHashCode(sentence)} {grammar.ComputeHash(ha, sentence)}");
          Console.WriteLine();
        }
      }
    }

    /// <summary>
    /// Formats the expression tree of <paramref name="s"/> using the grammar of the test fixture.
    /// </summary>
    /// <param name="s">The sentence to format.</param>
    /// <returns>The hierarchical text representation of the parsed tree.</returns>
    string PrintTree(SymbolList s) {
      return PrintTree(grammar, s);
    }

    /// <summary>
    /// Parses <paramref name="s"/> with <paramref name="g"/> and formats the resulting tree.
    /// </summary>
    /// <param name="g">The grammar used to parse the sentence.</param>
    /// <param name="s">The sentence to format.</param>
    /// <returns>The hierarchical text representation of the parsed tree.</returns>
    private static string PrintTree(Grammar g, SymbolList s) {
      var tree = g.ParseSymbolicExpressionTree(s);
      // Skip the two outermost nodes of the parsed tree
      // (presumably the program root and start nodes - TODO confirm).
      return HierarchicalFormatter.Format(tree.Root.GetSubtree(0).GetSubtree(0));
    }

    /// <summary>
    /// Hashes every sentence of the grammar (up to a maximum length) with the old hasher
    /// (<c>grammar.Hasher.CalcHashCode</c>) and the new one (<c>grammar.ComputeHash</c>),
    /// reports timings and the number of distinct hashes, and lists all sentences that the
    /// new hash tells apart but the old hash maps to the same value.
    /// This test only writes to the console; it contains no assertions.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void HashCollisionsTest() {
      var variables = new[] { "x", "y", "z", "w" };
      //var rules = Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>().Except(new[] { GrammarRule.InverseTerm, GrammarRule.Exponentiation, GrammarRule.Logarithm, GrammarRule.Sine });
      var rules = Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>();
      var grammar = new Grammar(variables, rules);

      int maxLength = 20, maxDepth = int.MaxValue;
      var sentences = EnumerateGrammarDepth(grammar, length: maxLength, depth: maxDepth, hashPhrases: false).ToList();
      var count = sentences.Count;

      var sw = new Stopwatch();
      sw.Start();
      var hashes = sentences.Select(grammar.Hasher.CalcHashCode).ToList();
      sw.Stop();
      Console.WriteLine($"Old hash: {sentences.Count} ({hashes.Distinct().Count()}) hashed in {sw.ElapsedMilliseconds / 1000.0} seconds ({1000d * sentences.Count / sw.ElapsedMilliseconds} hashes/s)");

      // HashAlgorithm is IDisposable; it is needed until the last Simplify call below.
      using (var ha = SHA512.Create()) {
        sw.Restart();
        var hashes_new = sentences.Select(x => grammar.ComputeHash(ha, x)).ToList();
        sw.Stop();
        Console.WriteLine($"New hash: {sentences.Count} ({hashes_new.Distinct().Count()}) hashed in {sw.ElapsedMilliseconds / 1000.0} seconds ({1000.0 * sentences.Count / sw.ElapsedMilliseconds} hashes/s)");

        // One representative (the lowest index) per distinct new hash value.
        var distinct = Enumerable.Range(0, count).GroupBy(x => hashes_new[x]).Select(g => g.OrderBy(x => x).First()).ToList();
        // Group the representatives by their old hash; a group with more than one entry
        // holds sentences the new hash distinguishes although their old hashes are equal.
        var collisions = distinct.ToLookup(x => hashes[x], x => Tuple.Create(hashes_new[x], sentences[x]));

        foreach (var pair in collisions.Where(g => g.Count() > 1)) {
          Console.WriteLine(pair.Key);
          foreach (var t in pair) {
            Console.WriteLine($"\t{t}");
            Console.WriteLine($"\t{grammar.ToInfixString(t.Item2)}");
            var simplified = grammar.Simplify(ha, t.Item2);
            Console.WriteLine($"\t{simplified}");
            Console.Write($"\t");
            // Write the formatted tree (the returned string was previously discarded) and
            // parse with the grammar that generated the sentence, not the fixture grammar.
            Console.WriteLine(PrintTree(grammar, t.Item2));
          }
        }
      }
    }

    /// <summary>
    /// Measures how long it takes to hash every sentence of the grammar (up to a maximum
    /// length) with the new hash (<c>grammar.ComputeHash</c>) and the old hasher
    /// (<c>grammar.Hasher.CalcHashCode</c>) and writes the results to the console.
    /// This test contains no assertions.
    /// </summary>
    [TestMethod]
    [TestCategory("TreeHashing")]
    public void HashPerformance() {
      var nvars = 4;
      var variables = Enumerable.Range(0, nvars).Select(x => $"X{x}").ToArray();
      //var rules = Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>().Except(new[] { GrammarRule.InverseTerm, GrammarRule.Exponentiation, GrammarRule.Logarithm, GrammarRule.Sine });
      var rules = Enum.GetValues(typeof(GrammarRule)).Cast<GrammarRule>();
      var grammar = new Grammar(variables, rules);

      int maxLength = 20, maxDepth = int.MaxValue;
      var sentences = EnumerateGrammarDepth(grammar, length: maxLength, depth: maxDepth, hashPhrases: false).ToList();
      var count = sentences.Count;

      var sw = new Stopwatch();

      // HashAlgorithm is IDisposable; only the new hash needs it.
      using (var ha = SHA512.Create()) {
        sw.Start();
        var hashes = sentences.Select(x => grammar.ComputeHash(ha, x)).ToList();
        sw.Stop();

        Console.WriteLine($"New: {sentences.Count} ({hashes.Distinct().Count()}) hashed in {sw.ElapsedMilliseconds / 1000.0} seconds ({1000d * sentences.Count / sw.ElapsedMilliseconds} hashes/s)");
      }

      sw.Restart();
      var hashes_old = sentences.Select(x => grammar.Hasher.CalcHashCode(x)).ToList();
      sw.Stop();

      Console.WriteLine($"Old: {sentences.Count} ({hashes_old.Distinct().Count()}) hashed in {sw.ElapsedMilliseconds / 1000.0} seconds ({1000d * sentences.Count / sw.ElapsedMilliseconds} hashes/s)");
    }

    #region enumerate the grammar
    /// <summary>
    /// Enumerates the sentences of <paramref name="grammar"/> breadth-first, starting from the
    /// start symbol and always expanding the next nonterminal of a phrase.
    /// </summary>
    /// <param name="grammar">The grammar to enumerate.</param>
    /// <param name="length">Maximum number of symbols; longer phrases are dropped.</param>
    /// <param name="hashPhrases">
    /// If true, a phrase whose hash (<c>grammar.Hasher.CalcHashCode</c>) was already seen is not expanded again.
    /// </param>
    /// <returns>All sentences found, in the order in which they were dequeued.</returns>
    private static IEnumerable<SymbolList> EnumerateGrammarBreadth(Grammar grammar, int length, bool hashPhrases = true) {
      var phrases = new Queue<SymbolList>();
      phrases.Enqueue(new SymbolList(grammar.StartSymbol));
      var sentences = new List<SymbolList>();
      var archive = new HashSet<int>();

      while (phrases.Count > 0) {
        var phrase = phrases.Dequeue();

        if (phrase.Count > length)
          continue;

        if (phrase.IsSentence()) {
          sentences.Add(phrase);
          continue;
        }

        // HashSet.Add returns false if the hash is already archived.
        if (hashPhrases && !archive.Add(grammar.Hasher.CalcHashCode(phrase))) {
          continue;
        }

        var idx = phrase.NextNonterminalIndex();
        var productions = grammar.Productions[phrase[idx]];
        var derived = productions.Select(p => phrase.DerivePhrase(idx, p)).Where(p => p.Count <= length);
        foreach (var d in derived)
          phrases.Enqueue(d);
      }
      return sentences;
    }

    /// <summary>
    /// Lazily enumerates the sentences of <paramref name="grammar"/> depth-first, starting
    /// from the start symbol.
    /// </summary>
    /// <param name="grammar">The grammar to enumerate.</param>
    /// <param name="length">Maximum number of symbols; longer phrases are dropped.</param>
    /// <param name="depth">Maximum number of derivation steps from the start symbol.</param>
    /// <param name="hashPhrases">If true, phrases whose hash was already seen are not expanded again.</param>
    /// <returns>A lazily evaluated sequence of sentences.</returns>
    private static IEnumerable<SymbolList> EnumerateGrammarDepth(Grammar grammar, int length, int depth, bool hashPhrases = true) {
      return Expand(new SymbolList(grammar.StartSymbol), grammar, length, 0, depth, hashPhrases ? new HashSet<int>() : null);
    }

    /// <summary>
    /// Recursively derives all sentences reachable from <paramref name="phrase"/> by expanding
    /// its next nonterminal with every production of the grammar.
    /// </summary>
    /// <param name="phrase">The phrase to expand.</param>
    /// <param name="grammar">The grammar that provides the productions.</param>
    /// <param name="maxLength">Phrases with more symbols than this are dropped.</param>
    /// <param name="depth">Number of derivation steps that led to <paramref name="phrase"/>.</param>
    /// <param name="maxDepth">Phrases deeper than this are dropped.</param>
    /// <param name="visited">
    /// Hashes of the phrases expanded so far, or null to disable pruning of repeated phrases.
    /// </param>
    /// <returns>A lazily evaluated sequence of sentences.</returns>
    private static IEnumerable<SymbolList> Expand(SymbolList phrase, Grammar grammar, int maxLength, int depth, int maxDepth, HashSet<int> visited) {
      if (maxDepth < depth || maxLength < phrase.Count) {
        yield break;
      }

      if (phrase.IsSentence()) {
        yield return phrase;
        yield break;
      }

      // HashSet.Add returns false if the hash of this phrase was already recorded.
      if (visited != null && !visited.Add(HashExtensions.ComputeHash(grammar, null, phrase))) {
        yield break;
      }

      var i = phrase.NextNonterminalIndex();
      var productions = grammar.Productions[phrase[i]];

      foreach (var production in productions) {
        foreach (var sentence in Expand(phrase.DerivePhrase(i, production), grammar, maxLength, depth + 1, maxDepth, visited)) {
          yield return sentence;
        }
      }
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.