1 | using System.Collections.Generic;
2 | using System.Linq;
3 | using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
4 | using static HeuristicLab.Problems.DataAnalysis.Symbolic.SymbolicExpressionHashExtensions;
5 |
namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
  /// <summary>
  /// Computes structural hash values for symbolic expression trees and provides the hash-driven
  /// simplification rules that are applied before hashing.
  /// </summary>
  /// <remarks>
  /// A node's own hash is derived only from its symbol name (or, for variables, the variable name),
  /// so constant values and variable weights do not influence the hash.
  /// NOTE: <see cref="string.GetHashCode()"/> is not guaranteed to be stable across processes or runtime
  /// versions, so the resulting hash values should not be persisted.
  /// </remarks>
  public static class SymbolicExpressionTreeHash {
    private static readonly Constant constant = new Constant();

    private static readonly ISymbolicExpressionTreeNodeComparer comparer = new SymbolicExpressionTreeNodeComparer();

    /// <summary>
    /// Computes the (simplified) hash of the expression stored below the tree's start node.
    /// </summary>
    /// <param name="tree">Tree whose <c>Root.GetSubtree(0).GetSubtree(0)</c> is hashed
    /// (presumably ProgramRoot → Start → expression).</param>
    /// <returns>The hash of the simplified expression.</returns>
    public static int ComputeHash(this ISymbolicExpressionTree tree) {
      return ComputeHash(tree.Root.GetSubtree(0).GetSubtree(0));
    }

    /// <summary>
    /// Computes a hash for every node of the (unsimplified) expression below the tree's start node.
    /// </summary>
    /// <param name="tree">Tree whose <c>Root.GetSubtree(0).GetSubtree(0)</c> is hashed.</param>
    /// <returns>A map from each tree node to its calculated hash value.</returns>
    public static Dictionary<ISymbolicExpressionTreeNode, int> ComputeNodeHashes(this ISymbolicExpressionTree tree) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var nodes = root.MakeNodes();
      nodes.UpdateNodeSizes();

      for (int i = 0; i < nodes.Length; ++i) {
        // nodes flagged IsChild keep the hash assigned in ToHashNode
        if (nodes[i].IsChild)
          continue;
        nodes[i].CalculatedHashValue = nodes.ComputeHash(i);
      }
      return nodes.ToDictionary(x => x.Data, x => x.CalculatedHashValue);
    }

    /// <summary>
    /// Builds the postfix hash-node array for <paramref name="treeNode"/>, simplifies it and hashes the result.
    /// </summary>
    /// <remarks>
    /// The simplification rules in this class store folded constant values in freshly created constant
    /// nodes instead of writing them into <paramref name="treeNode"/>'s constants.
    /// </remarks>
    /// <param name="treeNode">Root of the expression to hash.</param>
    /// <returns>The hash of the simplified expression.</returns>
    public static int ComputeHash(this ISymbolicExpressionTreeNode treeNode) {
      var hashNodes = treeNode.MakeNodes();
      var simplified = hashNodes.Simplify();
      return ComputeHash(simplified);
    }

    /// <summary>
    /// Folds the <c>CalculatedHashValue</c> of every node, in array order, into a single hash using
    /// shift/xor mixing seeded with 1315423911.
    /// </summary>
    /// <param name="nodes">Hash nodes to combine.</param>
    /// <returns>The combined hash (1315423911 for an empty array).</returns>
    public static int ComputeHash(this HashNode<ISymbolicExpressionTreeNode>[] nodes) {
      int hash = 1315423911;
      // the mixing relies on wrap-around, so keep it unchecked even in /checked builds
      unchecked {
        foreach (var node in nodes)
          hash ^= (hash << 5) + node.CalculatedHashValue + (hash >> 2);
      }
      return hash;
    }

    /// <summary>
    /// Converts a single tree node into its hash-node representation.
    /// </summary>
    /// <remarks>
    /// The initial hash is the hash code of the symbol name, or of the variable name for
    /// <see cref="Variable"/> symbols. Addition and multiplication are marked commutative, and a
    /// simplification rule is attached for addition, multiplication, division, subtraction, logarithm,
    /// exponential, sine and cosine.
    /// </remarks>
    /// <param name="node">Tree node to wrap; it becomes the hash node's <c>Data</c>.</param>
    /// <returns>A new, enabled hash node.</returns>
    public static HashNode<ISymbolicExpressionTreeNode> ToHashNode(this ISymbolicExpressionTreeNode node) {
      var symbol = node.Symbol;
      var name = symbol.Name;
      if (symbol is Variable) {
        var variableTreeNode = (VariableTreeNode)node;
        name = variableTreeNode.VariableName;
      }
      var hash = name.GetHashCode();
      var hashNode = new HashNode<ISymbolicExpressionTreeNode>(comparer) {
        Data = node,
        Arity = node.SubtreeCount,
        Size = node.SubtreeCount, // provisional; presumably recomputed as the descendant count by UpdateNodeSizes
        IsCommutative = node.Symbol is Addition || node.Symbol is Multiplication,
        Enabled = true,
        HashValue = hash,
        CalculatedHashValue = hash
      };
      if (symbol is Addition) {
        hashNode.Simplify = SimplifyAddition;
      } else if (symbol is Multiplication) {
        hashNode.Simplify = SimplifyMultiplication;
      } else if (symbol is Division) {
        hashNode.Simplify = SimplifyDivision;
      } else if (symbol is Logarithm || symbol is Exponential || symbol is Sine || symbol is Cosine) {
        hashNode.Simplify = SimplifyUnaryNode;
      } else if (symbol is Subtraction) {
        hashNode.Simplify = SimplifyBinaryNode;
      }
      return hashNode;
    }

    /// <summary>
    /// Converts the subtree rooted at <paramref name="node"/> into hash nodes in postfix order
    /// (children before their parent, root last).
    /// </summary>
    public static HashNode<ISymbolicExpressionTreeNode>[] MakeNodes(this ISymbolicExpressionTreeNode node) {
      return node.IterateNodesPostfix().Select(ToHashNode).ToArray();
    }

    #region parse a nodes array back into a tree
    /// <summary>
    /// Rebuilds a complete tree (ProgramRoot → Start → expression) from a postfix hash-node array.
    /// </summary>
    public static ISymbolicExpressionTree ToTree(this HashNode<ISymbolicExpressionTreeNode>[] nodes) {
      var root = new ProgramRootSymbol().CreateTreeNode();
      var start = new StartSymbol().CreateTreeNode();
      root.AddSubtree(start);
      start.AddSubtree(nodes.ToSubtree());
      return new SymbolicExpressionTree(root);
    }

    /// <summary>
    /// Rebuilds an expression subtree from a postfix hash-node array; the last array entry becomes the root.
    /// </summary>
    /// <remarks>
    /// New tree nodes are created from each entry's symbol. For nodes flagged <c>IsChild</c>, the variable
    /// name or constant value is copied and variable weights are set to 1 (weights are not part of the hash).
    /// Throws <see cref="System.InvalidOperationException"/> for an empty array.
    /// </remarks>
    public static ISymbolicExpressionTreeNode ToSubtree(this HashNode<ISymbolicExpressionTreeNode>[] nodes) {
      var treeNodes = nodes.Select(x => x.Data.Symbol.CreateTreeNode()).ToArray();

      // walk from the root (last in postfix order) towards the front, attaching children to their parents
      for (int i = nodes.Length - 1; i >= 0; --i) {
        var node = nodes[i];

        if (node.IsChild) {
          if (node.Data is VariableTreeNode variable) {
            var variableTreeNode = (VariableTreeNode)treeNodes[i];
            variableTreeNode.VariableName = variable.VariableName;
            variableTreeNode.Weight = 1;
          } else if (node.Data is ConstantTreeNode @const) {
            var constantTreeNode = (ConstantTreeNode)treeNodes[i];
            constantTreeNode.Value = @const.Value;
          }
          continue;
        }

        var treeNode = treeNodes[i];
        foreach (var j in nodes.IterateChildren(i)) {
          treeNode.AddSubtree(treeNodes[j]);
        }
      }

      return treeNodes.Last();
    }

    // Creates a tree node for the symbol and casts it to the requested node type.
    private static T CreateTreeNode<T>(this ISymbol symbol) where T : class, ISymbolicExpressionTreeNode {
      return (T)symbol.CreateTreeNode();
    }
    #endregion

    #region tree simplification
    // These simplification methods rely on the assumption that child nodes of the current node have already
    // been simplified (in other words, simplification should be applied in a bottom-up fashion).
    // Nodes are stored in postfix order, so a node's subtree occupies the index range [i - nodes[i].Size, i]
    // and its last child sits directly before it, at i - 1.

    /// <summary>
    /// Simplifies the expression below the tree's start node and rebuilds it as a new tree.
    /// </summary>
    public static ISymbolicExpressionTree Simplify(ISymbolicExpressionTree tree) {
      var root = tree.Root.GetSubtree(0).GetSubtree(0);
      var nodes = root.MakeNodes();
      var simplified = nodes.Simplify();
      return simplified.ToTree();
    }

    /// <summary>
    /// Eliminates repeated addends: a child equal (by hash) to the preceding kept child is disabled
    /// together with its whole subtree.
    /// </summary>
    /// <remarks>
    /// This is a hashing normalization, not an algebraic identity (x + x becomes x). Only duplicates that
    /// are adjacent in child order are detected. If a single addend remains, the addition node is disabled.
    /// </remarks>
    public static void SimplifyAddition(HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
      var children = nodes.IterateChildren(i);
      if (children.Length == 0)
        return;

      var node = nodes[i];
      var curr = children[0];
      foreach (var j in children.Skip(1)) {
        if (nodes[j] == nodes[curr]) {
          DisableSubtree(nodes, j);
          node.Arity--;
        } else {
          curr = j;
        }
      }
      if (node.Arity == 1) { // a single remaining term makes the addition node redundant
        node.Enabled = false;
      }
    }

    /// <summary>
    /// Simplifies a multiplication by folding runs of constant factors and by cancelling a factor against a
    /// matching denominator of a division child (x * (1 / x) and x * (a / x)).
    /// </summary>
    /// <remarks>
    /// Folded products are stored in a fresh constant node, so the tree being simplified is not mutated.
    /// Constant folding only absorbs constants that directly follow each other in child order. If no factor
    /// remains the node is replaced by the constant 1; if exactly one remains the node is disabled.
    /// </remarks>
    public static void SimplifyMultiplication(HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
      var node = nodes[i];
      var children = nodes.IterateChildren(i);

      for (int j = 0; j < children.Length; ++j) {
        var c = children[j];
        var child = nodes[c];

        if (!child.Enabled)
          continue;

        var symbol = child.Data.Symbol;
        if (symbol is Constant) {
          var product = ((ConstantTreeNode)child.Data).Value;
          var absorbed = false;
          for (int k = j + 1; k < children.Length; ++k) {
            var other = nodes[children[k]];
            if (!(other.Data.Symbol is Constant))
              break;
            if (!other.Enabled)
              continue; // already removed, must not be multiplied in or counted twice
            product *= ((ConstantTreeNode)other.Data).Value;
            other.Enabled = false;
            node.Arity--;
            absorbed = true;
          }
          if (absorbed) {
            var folded = constant.CreateTreeNode<ConstantTreeNode>();
            folded.Value = product;
            child.Data = folded;
          }
        } else if (symbol is Division && (child.Arity == 1 || child.Arity == 2)) {
          // the last child of the division directly precedes it, so c - 1 is the denominator
          // both for div(x) = 1 / x and for div(a, x) = a / x
          var denominatorIndex = c - 1;
          var denominator = nodes[denominatorIndex];
          if (denominator.Enabled) {
            foreach (var d in children) {
              if (d != c && nodes[d].Enabled && nodes[d] == denominator) {
                DisableSubtree(nodes, d);
                DisableSubtree(nodes, denominatorIndex);
                child.Enabled = false;
                // div(x): the factor and the division both disappear;
                // div(a, x): the numerator a takes the place of the division
                node.Arity -= child.Arity == 1 ? 2 : 1;
                break;
              }
            }
          }
        }
      }

      // decided once, after every factor has been processed
      if (node.Arity == 0) { // every factor cancelled, so the product is 1
        var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
        constantTreeNode.Value = 1;
        nodes[i] = constantTreeNode.ToHashNode();
      } else if (node.Arity == 1) { // a single remaining factor makes the multiplication node redundant
        node.Enabled = false;
      }
    }

    /// <summary>
    /// Simplifies a division. An all-constant division is folded into a single constant (div(x) is taken as
    /// 1 / x). Otherwise, the first denominator equal (by hash) to the numerator cancels against it, so
    /// x / x becomes 1 and x / y / x becomes div(y).
    /// </summary>
    public static void SimplifyDivision(HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
      var node = nodes[i];
      var children = nodes.IterateChildren(i);
      if (children.Length == 0)
        return;

      if (children.All(x => nodes[x].Data.Symbol is Constant)) {
        var v = ((ConstantTreeNode)nodes[children[0]].Data).Value;
        if (node.Arity == 1) {
          v = 1 / v;
        } else if (node.Arity > 1) {
          foreach (var j in children.Skip(1)) {
            v /= ((ConstantTreeNode)nodes[j].Data).Value;
          }
        }
        // the folded constant replaces the division, so its operands must not survive as loose leaves
        foreach (var j in children) {
          nodes[j].Enabled = false;
        }
        var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
        constantTreeNode.Value = v;
        nodes[i] = constantTreeNode.ToHashNode();
        return;
      }

      var numeratorIndex = children[0];
      var numerator = nodes[numeratorIndex];
      foreach (var j in children.Skip(1)) {
        if (nodes[j] == numerator) {
          // remove numerator and denominator together with their subtrees
          DisableSubtree(nodes, numeratorIndex);
          DisableSubtree(nodes, j);
          node.Arity -= 2;
          break; // the numerator can cancel only once
        }
      }

      if (node.Arity == 0) {
        var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
        constantTreeNode.Value = 1; // x / x = 1
        nodes[i] = constantTreeNode.ToHashNode();
      }
    }

    /// <summary>
    /// Simplifies a unary node (log, exp, sin, cos). For a constant argument c, f(c) is evaluated into a fresh
    /// constant and f is dropped. exp(log(x)) and log(exp(x)) collapse to x; the exp(log(x)) case implicitly
    /// assumes x &gt; 0.
    /// </summary>
    public static void SimplifyUnaryNode(HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
      var parent = nodes[i];
      var child = nodes[i - 1]; // the only child directly precedes its parent in postfix order

      var parentSymbol = parent.Data.Symbol;
      var childSymbol = child.Data.Symbol;

      if (childSymbol is Constant) {
        if (TryEvaluateUnary(parentSymbol, ((ConstantTreeNode)child.Data).Value, out var value)) {
          var folded = constant.CreateTreeNode<ConstantTreeNode>();
          folded.Value = value;
          child.Data = folded;
        }
        parent.Enabled = false;
      } else if ((parentSymbol is Exponential && childSymbol is Logarithm) || (parentSymbol is Logarithm && childSymbol is Exponential)) {
        child.Enabled = parent.Enabled = false;
      }
    }

    // Evaluates the unary functions handled by SimplifyUnaryNode; returns false for any other symbol.
    private static bool TryEvaluateUnary(ISymbol symbol, double x, out double result) {
      if (symbol is Logarithm) {
        result = System.Math.Log(x);
      } else if (symbol is Exponential) {
        result = System.Math.Exp(x);
      } else if (symbol is Sine) {
        result = System.Math.Sin(x);
      } else if (symbol is Cosine) {
        result = System.Math.Cos(x);
      } else {
        result = x;
        return false;
      }
      return true;
    }

    /// <summary>
    /// Folds a node whose children are all constants into a single constant; the children are disabled.
    /// </summary>
    /// <remarks>
    /// Currently attached to subtraction only. For subtraction the value is c0 - c1 - ..., and a single
    /// argument is negated. This mirrors the div(x) = 1 / x convention;
    /// NOTE(review): confirm against the interpreter. For any other symbol the new constant keeps the
    /// default value produced by <c>CreateTreeNode</c>.
    /// </remarks>
    public static void SimplifyBinaryNode(HashNode<ISymbolicExpressionTreeNode>[] nodes, int i) {
      var children = nodes.IterateChildren(i);
      if (children.Length == 0 || !children.All(x => nodes[x].Data.Symbol is Constant))
        return;

      var symbol = nodes[i].Data.Symbol;
      var constantTreeNode = constant.CreateTreeNode<ConstantTreeNode>();
      if (symbol is Subtraction) {
        var value = ((ConstantTreeNode)nodes[children[0]].Data).Value;
        if (children.Length == 1) {
          value = -value;
        } else {
          for (int k = 1; k < children.Length; ++k) {
            value -= ((ConstantTreeNode)nodes[children[k]].Data).Value;
          }
        }
        constantTreeNode.Value = value;
      }

      foreach (var j in children) {
        nodes[j].Enabled = false;
      }
      nodes[i] = constantTreeNode.ToHashNode();
    }

    // Disables the node at index together with its nodes[index].Size descendants, which precede it in postfix order.
    private static void DisableSubtree(HashNode<ISymbolicExpressionTreeNode>[] nodes, int index) {
      for (int k = index - nodes[index].Size; k <= index; ++k) {
        nodes[k].Enabled = false;
      }
    }
    #endregion
  }
}