Changeset 17091


Ignore:
Timestamp:
07/05/19 15:46:10 (2 weeks ago)
Author:
abeham
Message:

#2959: merged revisions 16278, 16279, 16283 to stable

Location:
stable
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/TreeMatching/SymbolicExpressionTreeBottomUpSimilarityCalculator.cs

    r15584 r17091  
    2222using System;
    2323using System.Collections.Generic;
    24 using System.Diagnostics;
    2524using System.Globalization;
    2625using System.Linq;
     
    3130using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3231
     32using NodeMap = System.Collections.Generic.Dictionary<HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.ISymbolicExpressionTreeNode, HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.ISymbolicExpressionTreeNode>;
     33
    3334namespace HeuristicLab.Problems.DataAnalysis.Symbolic {
    3435  [StorableClass]
     
    4041    protected override bool IsCommutative { get { return true; } }
    4142
     43    public bool MatchConstantValues { get; set; }
     44    public bool MatchVariableWeights { get; set; }
     45
    4246    [StorableConstructor]
    4347    protected SymbolicExpressionTreeBottomUpSimilarityCalculator(bool deserializing)
     
    5357    }
    5458
     59    #region static methods
     60    private static ISymbolicExpressionTreeNode ActualRoot(ISymbolicExpressionTree tree) {
             // Returns the first subtree of the first subtree of tree.Root, i.e. it steps over the
             // two topmost nodes (presumably the program-root and start symbols -- confirm against the grammar).
     61      return tree.Root.GetSubtree(0).GetSubtree(0);
     62    }
     63
     64    public static double CalculateSimilarity(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2, bool strict = false) {
             // Tree-level convenience overload: compares the nodes returned by ActualRoot for each tree
             // and forwards 'strict' unchanged to the node-level static overload.
     65      return CalculateSimilarity(ActualRoot(t1), ActualRoot(t2), strict);
     66    }
     67
     68    public static double CalculateSimilarity(ISymbolicExpressionTreeNode n1, ISymbolicExpressionTreeNode n2, bool strict = false) {
     69      var calculator = new SymbolicExpressionTreeBottomUpSimilarityCalculator { MatchConstantValues = strict, MatchVariableWeights = strict };
     70      return CalculateSimilarity(n1, n2, strict);
     71    }
     72
     73    public static Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> ComputeBottomUpMapping(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2, bool strict = false) {
             // Tree-level convenience overload: maps between the nodes returned by ActualRoot for each tree
             // and forwards 'strict' unchanged to the node-level static overload.
     74      return ComputeBottomUpMapping(ActualRoot(t1), ActualRoot(t2), strict);
     75    }
     76
     77    public static Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> ComputeBottomUpMapping(ISymbolicExpressionTreeNode n1, ISymbolicExpressionTreeNode n2, bool strict = false) {
             // Creates a throw-away calculator with MatchConstantValues and MatchVariableWeights both set to 'strict'
             // and returns the mapping produced by its instance ComputeBottomUpMapping(n1, n2).
     78      var calculator = new SymbolicExpressionTreeBottomUpSimilarityCalculator { MatchConstantValues = strict, MatchVariableWeights = strict };
     79      return calculator.ComputeBottomUpMapping(n1, n2);
     80    }
     81    #endregion
     82
    5583    public double CalculateSimilarity(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2) {
    56       if (t1 == t2)
     84      return CalculateSimilarity(t1, t2, out Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> map);
     85    }
     86
     87    public double CalculateSimilarity(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2, out NodeMap map) {
     88      if (t1 == t2) {
     89        map = null;
    5790        return 1;
    58 
    59       var map = ComputeBottomUpMapping(t1.Root, t2.Root);
    60       return 2.0 * map.Count / (t1.Length + t2.Length);
     91      }
     92      map = ComputeBottomUpMapping(t1, t2);
     93      return 2.0 * map.Count / (t1.Length + t2.Length - 4); // -4 for skipping root and start symbols in the two trees
    6194    }
    6295
     
    78111    }
    79112
     113    public Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> ComputeBottomUpMapping(ISymbolicExpressionTree t1, ISymbolicExpressionTree t2) {
              // Tree-level overload: starts the mapping at Root.GetSubtree(0).GetSubtree(0) of each tree,
              // which is the same node the static ActualRoot helper returns.
     114      return ComputeBottomUpMapping(t1.Root.GetSubtree(0).GetSubtree(0), t2.Root.GetSubtree(0).GetSubtree(0));
     115    }
     116
    80117    public Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> ComputeBottomUpMapping(ISymbolicExpressionTreeNode n1, ISymbolicExpressionTreeNode n2) {
    81       var comparer = new SymbolicExpressionTreeNodeComparer(); // use a node comparer because it's faster than calling node.ToString() (strings are expensive) and comparing strings
    82118      var compactedGraph = Compact(n1, n2);
    83119
    84       var forwardMap = new Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode>(); // nodes of t1 => nodes of t2
    85       var reverseMap = new Dictionary<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode>(); // nodes of t2 => nodes of t1
    86 
    87       // visit nodes in order of decreasing height to ensure correct mapping
    88       var nodes1 = n1.IterateNodesPrefix().OrderByDescending(x => x.GetDepth()).ToList();
    89       var nodes2 = n2.IterateNodesPrefix().ToList();
    90       for (int i = 0; i < nodes1.Count; ++i) {
    91         var v = nodes1[i];
    92         if (forwardMap.ContainsKey(v))
     120      IEnumerable<ISymbolicExpressionTreeNode> Subtrees(ISymbolicExpressionTreeNode node, bool commutative) {
     121        var subtrees = node.IterateNodesPrefix();
     122        return commutative ? subtrees.OrderBy(x => compactedGraph[x].Hash) : subtrees;
     123      }
     124
     125      var nodes1 = n1.IterateNodesPostfix().OrderByDescending(x => x.GetLength()); // by descending length so that largest subtrees are mapped first
     126      var nodes2 = (List<ISymbolicExpressionTreeNode>)n2.IterateNodesPostfix();
     127
     128      var forward = new NodeMap();
     129      var reverse = new NodeMap();
     130
     131      foreach (ISymbolicExpressionTreeNode v in nodes1) {
     132        if (forward.ContainsKey(v))
    93133          continue;
     134
    94135        var kv = compactedGraph[v];
    95         ISymbolicExpressionTreeNode w = null;
    96         for (int j = 0; j < nodes2.Count; ++j) {
    97           var t = nodes2[j];
    98           if (reverseMap.ContainsKey(t) || compactedGraph[t] != kv)
     136        var commutative = v.SubtreeCount > 1 && commutativeSymbols.Contains(kv.Label);
     137
     138        foreach (ISymbolicExpressionTreeNode w in nodes2) {
     139          if (w.GetLength() != kv.Length || w.GetDepth() != kv.Depth || reverse.ContainsKey(w) || compactedGraph[w] != kv)
    99140            continue;
    100           w = t;
     141
     142          // map one whole subtree to the other
     143          foreach (var t in Subtrees(v, commutative).Zip(Subtrees(w, commutative), Tuple.Create)) {
     144            forward[t.Item1] = t.Item2;
     145            reverse[t.Item2] = t.Item1;
     146          }
     147
    101148          break;
    102149        }
    103         if (w == null) continue;
    104 
    105         // at this point we know that v and w are isomorphic, however, the mapping cannot be done directly
    106         // (as in the paper) because the trees are unordered (subtree order might differ). the solution is
    107         // to sort subtrees from under commutative labels (this will work because the subtrees are isomorphic!)
    108         // while iterating over the two subtrees
    109         var vv = IterateBreadthOrdered(v, comparer).ToList();
    110         var ww = IterateBreadthOrdered(w, comparer).ToList();
    111         int len = Math.Min(vv.Count, ww.Count);
    112         for (int j = 0; j < len; ++j) {
    113           var s = vv[j];
    114           var t = ww[j];
    115           Debug.Assert(!reverseMap.ContainsKey(t));
    116 
    117           forwardMap[s] = t;
    118           reverseMap[t] = s;
    119         }
    120       }
    121 
    122       return forwardMap;
     150      }
     151
     152      return forward;
    123153    }
    124154
     
    132162      var nodeMap = new Dictionary<ISymbolicExpressionTreeNode, GraphNode>(); // K
    133163      var labelMap = new Dictionary<string, GraphNode>(); // L
    134       var childrenCount = new Dictionary<ISymbolicExpressionTreeNode, int>(); // Children
    135164
    136165      var nodes = n1.IterateNodesPostfix().Concat(n2.IterateNodesPostfix()); // the disjoint union F
    137       var list = new List<GraphNode>();
    138       var queue = new Queue<ISymbolicExpressionTreeNode>();
    139 
    140       foreach (var n in nodes) {
    141         if (n.SubtreeCount == 0) {
    142           var label = GetLabel(n);
     166      var graph = new List<GraphNode>();
     167
     168      IEnumerable<GraphNode> Subtrees(GraphNode g, bool commutative) {
     169        var subtrees = g.SymbolicExpressionTreeNode.Subtrees.Select(x => nodeMap[x]);
     170        return commutative ? subtrees.OrderBy(x => x.Hash) : subtrees;
     171      }
     172
     173      foreach (var node in nodes) {
     174        var label = GetLabel(node);
     175
     176        if (node.SubtreeCount == 0) {
    143177          if (!labelMap.ContainsKey(label)) {
    144             var z = new GraphNode { SymbolicExpressionTreeNode = n, Label = label };
    145             labelMap[z.Label] = z;
    146           }
    147           nodeMap[n] = labelMap[label];
    148           queue.Enqueue(n);
     178            labelMap[label] = new GraphNode(node, label);
     179          }
     180          nodeMap[node] = labelMap[label];
    149181        } else {
    150           childrenCount[n] = n.SubtreeCount;
    151         }
    152       }
    153       while (queue.Any()) {
    154         var n = queue.Dequeue();
    155         if (n.SubtreeCount > 0) {
     182          var v = new GraphNode(node, label);
    156183          bool found = false;
    157           var label = n.Symbol.Name;
    158           var depth = n.GetDepth();
    159 
    160           bool sort = n.SubtreeCount > 1 && commutativeSymbols.Contains(label);
    161           var nSubtrees = n.Subtrees.Select(x => nodeMap[x]).ToList();
    162           if (sort) nSubtrees.Sort((a, b) => string.CompareOrdinal(a.Label, b.Label));
    163 
    164           for (int i = list.Count - 1; i >= 0; --i) {
    165             var w = list[i];
    166             if (!(n.SubtreeCount == w.SubtreeCount && label == w.Label && depth == w.Depth))
     184          var commutative = node.SubtreeCount > 1 && commutativeSymbols.Contains(label);
     185
     186          var vv = Subtrees(v, commutative);
     187
     188          foreach (var w in graph) {
     189            if (v.Depth != w.Depth || v.SubtreeCount != w.SubtreeCount || v.Length != w.Length || v.Label != w.Label) {
    167190              continue;
    168 
    169             // sort V and W when the symbol is commutative because we are dealing with unordered trees
    170             var m = w.SymbolicExpressionTreeNode;
    171             var mSubtrees = m.Subtrees.Select(x => nodeMap[x]).ToList();
    172             if (sort) mSubtrees.Sort((a, b) => string.CompareOrdinal(a.Label, b.Label));
    173 
    174             found = nSubtrees.SequenceEqual(mSubtrees);
     191            }
     192
     193            var ww = Subtrees(w, commutative);
     194            found = vv.SequenceEqual(ww);
     195
    175196            if (found) {
    176               nodeMap[n] = w;
     197              nodeMap[node] = w;
    177198              break;
    178199            }
    179200          }
    180 
    181201          if (!found) {
    182             var w = new GraphNode { SymbolicExpressionTreeNode = n, Label = label, Depth = depth };
    183             list.Add(w);
    184             nodeMap[n] = w;
     202            nodeMap[node] = v;
     203            graph.Add(v);
    185204          }
    186205        }
    187 
    188         if (n == n1 || n == n2)
    189           continue;
    190 
    191         var p = n.Parent;
    192         if (p == null)
    193           continue;
    194 
    195         childrenCount[p]--;
    196 
    197         if (childrenCount[p] == 0)
    198           queue.Enqueue(p);
    199       }
    200 
     206      }
    201207      return nodeMap;
    202208    }
    203209
    204     private IEnumerable<ISymbolicExpressionTreeNode> IterateBreadthOrdered(ISymbolicExpressionTreeNode node, ISymbolicExpressionTreeNodeComparer comparer) {
    205       var list = new List<ISymbolicExpressionTreeNode> { node };
    206       int i = 0;
    207       while (i < list.Count) {
    208         var n = list[i];
    209         if (n.SubtreeCount > 0) {
    210           var subtrees = commutativeSymbols.Contains(node.Symbol.Name) ? n.Subtrees.OrderBy(x => x, comparer) : n.Subtrees;
    211           list.AddRange(subtrees);
    212         }
    213         i++;
    214       }
    215       return list;
    216     }
    217 
    218     private static string GetLabel(ISymbolicExpressionTreeNode node) {
     210    private string GetLabel(ISymbolicExpressionTreeNode node) {
    219211      if (node.SubtreeCount > 0)
    220212        return node.Symbol.Name;
    221213
    222       var constant = node as ConstantTreeNode;
    223       if (constant != null)
    224         return constant.Value.ToString(CultureInfo.InvariantCulture);
    225 
    226       var variable = node as VariableTreeNode;
    227       if (variable != null)
    228         return variable.Weight + variable.VariableName;
     214      if (node is ConstantTreeNode constant)
     215        return MatchConstantValues ? constant.Value.ToString(CultureInfo.InvariantCulture) : constant.Symbol.Name;
     216
     217      if (node is VariableTreeNode variable)
     218        return MatchVariableWeights ? variable.Weight + variable.VariableName : variable.VariableName;
    229219
    230220      return node.ToString();
     
    232222
    233223    private class GraphNode {
    234       public ISymbolicExpressionTreeNode SymbolicExpressionTreeNode;
    235       public string Label;
    236       public int Depth;
     224      private GraphNode() { }
     225
     226      public GraphNode(ISymbolicExpressionTreeNode node, string label) {
     227        SymbolicExpressionTreeNode = node;
     228        Label = label;
     229        Hash = GetHashCode();
     230        Depth = node.GetDepth();
     231        Length = node.GetLength();
     232      }
     233
     234      public int Hash { get; }
     235      public ISymbolicExpressionTreeNode SymbolicExpressionTreeNode { get; }
     236      public string Label { get; }
     237      public int Depth { get; }
    237238      public int SubtreeCount { get { return SymbolicExpressionTreeNode.SubtreeCount; } }
     239      public int Length { get; }
    238240    }
    239241  }
  • stable/HeuristicLab.Tests

  • stable/HeuristicLab.Tests/HeuristicLab.Problems.DataAnalysis.Symbolic-3.4/SymbolicExpressionTreeBottomUpSimilarityCalculatorTest.cs

    r11986 r17091  
    77  [TestClass]
    88  public class BottomUpSimilarityCalculatorTest {
    9     private readonly SymbolicExpressionTreeBottomUpSimilarityCalculator busCalculator;
    10     private readonly SymbolicExpressionImporter importer;
     9    private readonly SymbolicExpressionTreeBottomUpSimilarityCalculator similarityCalculator = new SymbolicExpressionTreeBottomUpSimilarityCalculator() { MatchConstantValues = false, MatchVariableWeights = false };
     10    private readonly SymbolicExpressionImporter importer = new SymbolicExpressionImporter();
    1111
    12     private const int N = 150;
     12    private const int N = 1000;
    1313    private const int Rows = 1;
    1414    private const int Columns = 10;
    1515
    1616    public BottomUpSimilarityCalculatorTest() {
    17       busCalculator = new SymbolicExpressionTreeBottomUpSimilarityCalculator();
    18       importer = new SymbolicExpressionImporter();
     17      var parser = new InfixExpressionParser();
    1918    }
    2019
     
    2322    [TestProperty("Time", "short")]
    2423    public void BottomUpTreeSimilarityCalculatorTestMapping() {
    25       TestMatchedNodes("(+ 1 2)", "(+ 2 1)", 5);
    26       TestMatchedNodes("(- 2 1)", "(- 1 2)", 2);
    27       TestMatchedNodes("(* (variable 1 X1) (variable 1 X2))", "(* (+ (variable 1 X1) 1) (+ (variable 1 X2) 1))", 2);
     24      TestMatchedNodes("(+ 1 1)", "(+ 2 2)", 0, strict: true);
     25      TestMatchedNodes("(+ 1 1)", "(+ 2 2)", 3, strict: false);
     26      TestMatchedNodes("(+ 1 1)", "(+ 1 2)", 1, strict: true);
     27      TestMatchedNodes("(+ 2 1)", "(+ 1 2)", 3, strict: true);
    2828
    29       TestMatchedNodes("(* (variable 1 X1) (variable 1 X2))", "(* (+ (variable 1 X1) 1) (variable 1 X2))", 2);
     29      TestMatchedNodes("(- 1 1)", "(- 2 2)", 0, strict: true);
     30      TestMatchedNodes("(- 1 1)", "(- 2 2)", 3, strict: false);
    3031
    31       TestMatchedNodes("(+ (variable 1 a) (variable 1 b))", "(+ (variable 1 a) (variable 1 a))", 1);
    32       TestMatchedNodes("(+ (+ (variable 1 a) (variable 1 b)) (variable 1 b))", "(+ (* (+ (variable 1 a) (variable 1 b)) (variable 1 b)) (+ (+ (variable 1 a) (variable 1 b)) (variable 1 b)))", 5);
    33 
    34       TestMatchedNodes(
    35         "(* (+ 2.84 (exp (+ (log (/ (variable 2.0539 X5) (variable -9.2452e-1 X6))) (/ (variable 2.0539 X5) (variable -9.2452e-1 X6))))) 2.9081)",
    36         "(* (- (variable 9.581e-1 X6) (+ (- (variable 5.1491e-1 X5) 1.614e+1) (+ (/ (variable 2.0539 X5) (variable -9.2452e-1 X6)) (log (/ (variable 2.0539 X5) (variable -9.2452e-1 X6)))))) 2.9081)",
    37         9);
    38 
    39       TestMatchedNodes("(* (* (* (variable 1.68 x) (* (variable 1.68 x) (variable 2.55 x))) (variable 1.68 x)) (* (* (variable 1.68 x) (* (variable 1.68 x) (* (variable 1.68 x) (variable 2.55 x)))) (variable 2.55 x)))", "(* (variable 2.55 x) (* (variable 1.68 x) (* (variable 1.68 x) (* (variable 1.68 x) (variable 2.55 x)))))", 9);
    40 
    41       TestMatchedNodes("(+ (exp 2.1033) (/ -4.3072 (variable 2.4691 X7)))", "(/ 1 (+ (/ -4.3072 (variable 2.4691 X7)) (exp 2.1033)))", 6);
    42       TestMatchedNodes("(+ (exp 2.1033) (/ -4.3072 (variable 2.4691 X7)))", "(/ 1 (+ (/ (variable 2.4691 X7) -4.3072) (exp 2.1033)))", 4);
    43 
    44       const string expr1 = "(* (- 1.2175e+1 (+ (/ (exp -1.4134e+1) (exp 9.2013)) (exp (log (exp (/ (exp (- (* -4.2461 (variable 2.2634 X5)) (- -9.6267e-1 3.3243))) (- (/ (/ (variable 1.0883 X1) (variable 6.9620e-1 X2)) (log 1.3011e+1)) (variable -4.3098e-1 X7)))))))) (log 1.3011e+1))";
    45       const string expr2 = "(* (- 1.2175e+1 (+ (/ (/ (+ (variable 3.0140 X9) (variable 1.3430 X8)) -1.0864e+1) (exp 9.2013)) (exp (log (exp (/ (exp (- (* -4.2461 (variable 2.2634 X5)) (- -9.6267e-1 3.3243))) (- (/ (/ (variable 1.0883 X1) (variable 6.9620e-1 X2)) (log 1.3011e+1)) (variable -4.3098e-1 X7)))))))) (exp (variable 4.0899e-1 X7)))";
    46 
    47       TestMatchedNodes(expr1, expr2, 23);
    48 
     32      TestMatchedNodes("(- 2 1)", "(- 1 2)", 2, strict: true);
     33      TestMatchedNodes("(- 2 1)", "(- 1 2)", 3, strict: false);
    4934    }
    5035
    51     private void TestMatchedNodes(string expr1, string expr2, int expected) {
     36    private void TestMatchedNodes(string expr1, string expr2, int expected, bool strict) {
    5237      var t1 = importer.Import(expr1);
    5338      var t2 = importer.Import(expr2);
    5439
    55       var mapping = busCalculator.ComputeBottomUpMapping(t1.Root, t2.Root);
    56       var c = mapping.Count;
     40      var map = SymbolicExpressionTreeBottomUpSimilarityCalculator.ComputeBottomUpMapping(t1, t2, strict);
    5741
    58       if (c != expected) {
    59         throw new Exception("Match count " + c + " is different than expected value " + expected);
     42      if (map.Count != expected) {
     43        throw new Exception($"Match count {map.Count} is different than expected value {expected} for expressions:\n{expr1} and {expr2} (strict = {strict})");
    6044      }
    6145    }
     
    7761      for (int i = 0; i < trees.Length - 1; ++i) {
    7862        for (int j = i + 1; j < trees.Length; ++j) {
    79           s += busCalculator.CalculateSimilarity(trees[i], trees[j]);
     63          s += similarityCalculator.CalculateSimilarity(trees[i], trees[j]);
    8064        }
    8165      }
Note: See TracChangeset for help on using the changeset viewer.