source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Hashing/SymbolicExpressionTreeHash.cs @ 17142

Last change on this file since 17142 was 17142, checked in by abeham, 2 months ago

#3015: merged 17132 to stable

File size: 15.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
26using static HeuristicLab.Problems.DataAnalysis.Symbolic.SymbolicExpressionHashExtensions;
27
28namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
29  public static class SymbolicExpressionTreeHash {
30    private static readonly Addition add = new Addition();
31    private static readonly Subtraction sub = new Subtraction();
32    private static readonly Multiplication mul = new Multiplication();
33    private static readonly Division div = new Division();
34    private static readonly Logarithm log = new Logarithm();
35    private static readonly Exponential exp = new Exponential();
36    private static readonly Sine sin = new Sine();
37    private static readonly Cosine cos = new Cosine();
38    private static readonly Constant constant = new Constant();
39
40    private static ISymbolicExpressionTreeNode ActualRoot(this ISymbolicExpressionTree tree) => tree.Root.GetSubtree(0).GetSubtree(0);
41    public static ulong HashFunction(byte[] input) => HashUtil.DJBHash(input);
42
43    #region tree hashing
44    public static ulong[] Hash(this ISymbolicExpressionTree tree, bool simplify = false, bool strict = false) {
45      return tree.ActualRoot().Hash(simplify, strict);
46    }
47
48    public static ulong[] Hash(this ISymbolicExpressionTreeNode node, bool simplify = false, bool strict = false) {
49      var hashNodes = simplify ? node.MakeNodes(strict).Simplify(HashFunction) : node.MakeNodes(strict).Sort(HashFunction);
50      var hashes = new ulong[hashNodes.Length];
51      for (int i = 0; i < hashes.Length; ++i) {
52        hashes[i] = hashNodes[i].CalculatedHashValue;
53      }
54      return hashes;
55    }
56
57    public static ulong ComputeHash(this ISymbolicExpressionTree tree, bool simplify = false, bool strict = false) {
58      return ComputeHash(tree.ActualRoot(), simplify, strict);
59    }
60
61    public static ulong ComputeHash(this ISymbolicExpressionTreeNode treeNode, bool simplify = false, bool strict = false) {
62      return treeNode.Hash(simplify, strict).Last();
63    }
64
65    public static HashNode<ISymbolicExpressionTreeNode> ToHashNode(this ISymbolicExpressionTreeNode node, bool strict = false) {
66      var symbol = node.Symbol;
67      var name = symbol.Name;
68      if (node is ConstantTreeNode constantNode) {
69        name = strict ? constantNode.Value.ToString() : symbol.Name;
70      } else if (node is VariableTreeNode variableNode) {
71        name = strict ? variableNode.Weight.ToString() + variableNode.VariableName : variableNode.VariableName;
72      }
73      var hash = (ulong)name.GetHashCode();
74      var hashNode = new HashNode<ISymbolicExpressionTreeNode> {
75        Data = node,
76        Arity = node.SubtreeCount,
77        Size = node.SubtreeCount,
78        IsCommutative = node.Symbol is Addition || node.Symbol is Multiplication,
79        Enabled = true,
80        HashValue = hash,
81        CalculatedHashValue = hash
82      };
83      if (symbol is Addition) {
84        hashNode.Simplify = SimplifyAddition;
85      } else if (symbol is Multiplication) {
86        hashNode.Simplify = SimplifyMultiplication;
87      } else if (symbol is Division) {
88        hashNode.Simplify = SimplifyDivision;
89      } else if (symbol is Logarithm || symbol is Exponential || symbol is Sine || symbol is Cosine) {
90        hashNode.Simplify = SimplifyUnaryNode;
91      } else if (symbol is Subtraction) {
92        hashNode.Simplify = SimplifyBinaryNode;
93      }
94      return hashNode;
95    }
96
97    public static HashNode<ISymbolicExpressionTreeNode>[] MakeNodes(this ISymbolicExpressionTree tree, bool strict = false) {
98      return MakeNodes(tree.ActualRoot(), strict);
99    }
100
101    public static HashNode<ISymbolicExpressionTreeNode>[] MakeNodes(this ISymbolicExpressionTreeNode node, bool strict = false) {
102      return node.IterateNodesPostfix().Select(x => x.ToHashNode(strict)).ToArray().UpdateNodeSizes();
103    }
104    #endregion
105
106    #region tree similarity
107    public static double ComputeSimilarity(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2, bool simplify = false, bool strict = false) {
108      return ComputeSimilarity(t1.ActualRoot(), t2.ActualRoot(), simplify, strict);
109    }
110
111    public static double ComputeSimilarity(ISymbolicExpressionTreeNode t1, ISymbolicExpressionTreeNode t2, bool simplify = false, bool strict = false) {
112      var lh = t1.Hash(simplify, strict);
113      var rh = t2.Hash(simplify, strict);
114
115      Array.Sort(lh);
116      Array.Sort(rh);
117
118      return ComputeSimilarity(lh, rh);
119    }
120
121    // requires lhs and rhs to be sorted
122    public static int IntersectCount(this ulong[] lh, ulong[] rh) {
123      int count = 0;
124      for (int i = 0, j = 0; i < lh.Length && j < rh.Length;) {
125        var h1 = lh[i];
126        var h2 = rh[j];
127        if (h1 == h2) {
128          ++count;
129          ++i;
130          ++j;
131        } else if (h1 < h2) {
132          ++i;
133        } else if (h1 > h2) {
134          ++j;
135        }
136      }
137      return count;
138    }
139
140    public static IEnumerable<ulong> Intersect(this ulong[] lh, ulong[] rh) {
141      for (int i = 0, j = 0; i < lh.Length && j < rh.Length;) {
142        var h1 = lh[i];
143        var h2 = rh[j];
144        if (h1 == h2) {
145          yield return h1;
146          ++i;
147          ++j;
148        } else if (h1 < h2) {
149          ++i;
150        } else if (h1 > h2) {
151          ++j;
152        }
153      }
154    }
155
156    // this will only work if lh and rh are sorted
157    public static double ComputeSimilarity(ulong[] lh, ulong[] rh) {
158      return 2d * IntersectCount(lh, rh) / (lh.Length + rh.Length);
159    }
160
161    public static double ComputeAverageSimilarity(IList<ISymbolicExpressionTree> trees, bool simplify = false, bool strict = false) {
162      var total = trees.Count * (trees.Count - 1) / 2;
163      double avg = 0;
164      var hashes = new ulong[trees.Count][];
165      // build hash arrays
166      for (int i = 0; i < trees.Count; ++i) {
167        var nodes = trees[i].MakeNodes(strict);
168        hashes[i] = (simplify ? nodes.Simplify(HashFunction) : nodes.Sort(HashFunction)).Select(x => x.CalculatedHashValue).ToArray();
169        Array.Sort(hashes[i]);
170      }
171      // compute similarity matrix
172      for (int i = 0; i < trees.Count - 1; ++i) {
173        for (int j = i + 1; j < trees.Count; ++j) {
174          avg += ComputeSimilarity(hashes[i], hashes[j]);
175        }
176      }
177      return avg / total;
178    }
179
180    public static double[,] ComputeSimilarityMatrix(IList<ulong[]> hashes) {
181      // compute similarity matrix
182      var n = hashes.Count;
183      var sim = new double[n, n];
184      for (int i = 0; i < n - 1; ++i) {
185        for (int j = i + 1; j < n; ++j) {
186          sim[i, j] = sim[j, i] = ComputeSimilarity(hashes[i], hashes[j]);
187        }
188      }
189      return sim;
190    }
191
192    public static double[,] ComputeSimilarityMatrix(IList<ISymbolicExpressionTree> trees, bool simplify = false, bool strict = false) {
193      var hashes = new ulong[trees.Count][];
194      // build hash arrays
195      for (int i = 0; i < trees.Count; ++i) {
196        var nodes = trees[i].MakeNodes(strict);
197        hashes[i] = (simplify ? nodes.Simplify(HashFunction) : nodes.Sort(HashFunction)).Select(x => x.CalculatedHashValue).ToArray();
198        Array.Sort(hashes[i]);
199      }
200      return ComputeSimilarityMatrix(hashes);
201    }
202    #endregion
203
204    #region parse a nodes array back into a tree
205    public static ISymbolicExpressionTree ToTree(this HashNode<ISymbolicExpressionTreeNode>[] nodes) {
206      var root = new ProgramRootSymbol().CreateTreeNode();
207      var start = new StartSymbol().CreateTreeNode();
208      root.AddSubtree(start);
209      start.AddSubtree(nodes.ToSubtree());
210      return new SymbolicExpressionTree(root);
211    }
212
213    public static ISymbolicExpressionTreeNode ToSubtree(this HashNode<ISymbolicExpressionTreeNode>[] nodes) {
214      var treeNodes = nodes.Select(x => x.Data.Symbol.CreateTreeNode()).ToArray();
215
216      for (int i = nodes.Length - 1; i >= 0; --i) {
217        var node = nodes[i];
218
219        if (node.IsLeaf) {
220          if (node.Data is VariableTreeNode variable) {
221            var variableTreeNode = (VariableTreeNode)treeNodes[i];
222            variableTreeNode.VariableName = variable.VariableName;
223            variableTreeNode.Weight = variable.Weight;
224          } else if (node.Data is ConstantTreeNode @const) {
225            var constantTreeNode = (ConstantTreeNode)treeNodes[i];
226            constantTreeNode.Value = @const.Value;
227          }
228          continue;
229        }
230
231        var treeNode = treeNodes[i];
232
233        foreach (var j in nodes.IterateChildren(i)) {
234          treeNode.AddSubtree(treeNodes[j]);
235        }
236      }
237
238      return treeNodes.Last();
239    }
240
241    private static T CreateTreeNode<T>(this ISymbol symbol) where T : class, ISymbolicExpressionTreeNode {
242      return (T)symbol.CreateTreeNode();
243    }
244    #endregion
245
246    #region tree simplification
247    // these simplification methods rely on the assumption that child nodes of the current node have already been simplified
248    // (in other words simplification should be applied in a bottom-up fashion)
249    public static ISymbolicExpressionTree Simplify(ISymbolicExpressionTree tree) {
250      return tree.MakeNodes().Simplify(HashFunction).ToTree();
251    }
252
253    public static void SimplifyAddition(ref HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
254      // simplify additions of terms by eliminating terms with the same symbol and hash
255      var children = nodes.IterateChildren(i);
256
257      // we always assume the child nodes are sorted
258      var curr = children[0];
259      var node = nodes[i];
260
261      foreach (var j in children.Skip(1)) {
262        if (nodes[j] == nodes[curr]) {
263          nodes.SetEnabled(j, false);
264          node.Arity--;
265        } else {
266          curr = j;
267        }
268      }
269      if (node.Arity == 1) { // if the arity is 1 we don't need the addition node at all
270        node.Enabled = false;
271      }
272    }
273
274    // simplify multiplications by reducing constants and div terms
275    public static void SimplifyMultiplication(ref HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
276      var node = nodes[i];
277      var children = nodes.IterateChildren(i);
278
279      for (int j = 0; j < children.Length; ++j) {
280        var c = children[j];
281        var child = nodes[c];
282
283        if (!child.Enabled)
284          continue;
285
286        var symbol = child.Data.Symbol;
287        if (child.Data is ConstantTreeNode firstConst) {
288          // fold sibling constant nodes into the first constant
289          for (int k = j + 1; k < children.Length; ++k) {
290            var sibling = nodes[children[k]];
291            if (sibling.Data is ConstantTreeNode otherConst) {
292              sibling.Enabled = false;
293              node.Arity--;
294              firstConst.Value *= otherConst.Value;
295            } else {
296              break;
297            }
298          }
299        } else if (child.Data is VariableTreeNode variable) {
300          // fold sibling constant nodes into the variable weight
301          for (int k = j + 1; k < children.Length; ++k) {
302            var sibling = nodes[children[k]];
303            if (sibling.Data is ConstantTreeNode constantNode) {
304              sibling.Enabled = false;
305              node.Arity--;
306              variable.Weight *= constantNode.Value;
307            } else {
308              break;
309            }
310          }
311        } else if (symbol is Division) {
312          // 1/x is expressed as div(x) (with a single child)
313          // we assume division always has arity 1 or 2
314          var d = child.Arity == 1 ? c - 1 : c - nodes[c - 1].Size - 2;
315          var denominator = nodes[d];
316
317          // iterate children of node i to see if any of them matches the denominator of div node c
318          for (int k = 0; k < children.Length; ++k) {
319            var sibling = nodes[children[k]];
320            if (sibling.Enabled && sibling == denominator) {
321              nodes.SetEnabled(children[j], false); // disable the div subtree
322              nodes.SetEnabled(children[k], false); // disable the sibling matching the denominator
323
324              node.Arity -= 2; // matching child + division node
325              break;
326            }
327          }
328        }
329
330        if (node.Arity == 0) { // if everything is simplified this node becomes constant
331          var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
332          constantTreeNode.Value = 1;
333          nodes[i] = constantTreeNode.ToHashNode();
334        } else if (node.Arity == 1) { // when i have only 1 arg left i can skip this node
335          node.Enabled = false;
336        }
337      }
338    }
339
340    public static void SimplifyDivision(ref HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
341      var node = nodes[i];
342      var children = nodes.IterateChildren(i);
343
344      var tmp = nodes;
345
346      if (children.All(x => tmp[x].Data.Symbol is Constant)) {
347        var v = ((ConstantTreeNode)nodes[children.First()].Data).Value;
348        if (node.Arity == 1) {
349          v = 1 / v;
350        } else if (node.Arity > 1) {
351          foreach (var j in children.Skip(1)) {
352            v /= ((ConstantTreeNode)nodes[j].Data).Value;
353          }
354        }
355        var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
356        constantTreeNode.Value = v;
357        nodes[i] = constantTreeNode.ToHashNode();
358        return;
359      }
360
361      var nominator = nodes[children[0]];
362      foreach (var j in children.Skip(1)) {
363        var denominator = nodes[j];
364        if (nominator == denominator) {
365          // disable all the children of the division node (nominator and children + denominator and children)
366          nominator.Enabled = denominator.Enabled = false;
367          node.Arity -= 2; // nominator + denominator
368        }
369        if (node.Arity == 0) {
370          var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
371          constantTreeNode.Value = 1; // x / x = 1
372          nodes[i] = constantTreeNode.ToHashNode();
373        }
374      }
375    }
376
377    public static void SimplifyUnaryNode(ref HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
378      // check if the child of the unary node is a constant, then the whole node can be simplified
379      var parent = nodes[i];
380      var child = nodes[i - 1];
381
382      var parentSymbol = parent.Data.Symbol;
383      var childSymbol = child.Data.Symbol;
384
385      if (childSymbol is Constant) {
386        nodes[i].Enabled = false;
387      } else if ((parentSymbol is Exponential && childSymbol is Logarithm) || (parentSymbol is Logarithm && childSymbol is Exponential)) {
388        child.Enabled = parent.Enabled = false;
389      }
390    }
391
392    public static void SimplifyBinaryNode(ref HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
393      var children = nodes.IterateChildren(i);
394      var tmp = nodes;
395      if (children.All(x => tmp[x].Data.Symbol is Constant)) {
396        foreach (var j in children) {
397          nodes[j].Enabled = false;
398        }
399        nodes[i] = constant.CreateTreeNode().ToHashNode();
400      }
401    }
402    #endregion
403  }
404}
Note: See TracBrowser for help on using the repository browser.