Changeset 17626


Ignore:
Timestamp:
06/23/20 10:05:13 (3 weeks ago)
Author:
pfleck
Message:

#3040 Unified simplification rules for vector aggregation functions.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorTreeSimplifier.cs

    r17604 r17626  
    13961396    }
    13971397
    1398     private ISymbolicExpressionTreeNode MakeSumAggregation(ISymbolicExpressionTreeNode node) {
     1398    delegate ISymbolicExpressionTreeNode ScalarSimplifier(ISymbolicExpressionTreeNode scalar);
     1399    delegate ISymbolicExpressionTreeNode VectorSimplifier(ISymbolicExpressionTreeNode vectorPart, ISymbolicExpressionTreeNode scalarPart);
     1400
     1401    private ISymbolicExpressionTreeNode MakeAggregation(ISymbolicExpressionTreeNode node, Symbol aggregationSymbol,
     1402      ScalarSimplifier simplifyScalar,
     1403      VectorSimplifier simplifyAdditiveTerms,
     1404      VectorSimplifier simplifyMultiplicativeFactors) {
     1405
     1406      ISymbolicExpressionTreeNode MakeAggregationNode(ISymbolicExpressionTreeNode remainingNode) {
     1407        var aggregationNode = aggregationSymbol.CreateTreeNode();
     1408        aggregationNode.AddSubtree(GetSimplifiedTree(remainingNode));
     1409        return aggregationNode;
     1410      }
     1411
     1412      ISymbolicExpressionTreeNode SimplifyArithmeticOperation(IEnumerable<ISymbolicExpressionTreeNode> terms,
     1413        Func<ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode, ISymbolicExpressionTreeNode> aggregation,
     1414        VectorSimplifier simplifyTerms) {
     1415        var scalarTerms = terms.Where(IsScalarNode).ToList();
     1416        var remainingTerms = terms.Except(scalarTerms).ToList();
     1417
     1418        var scalarNode = scalarTerms.Any() ? scalarTerms.Aggregate(aggregation) : null;
     1419        var remainingNode = remainingTerms.Aggregate(aggregation); // at least one term is remaining, otherwise "node" would have been scalar and earlier rule had applied
     1420
     1421        if (scalarTerms.Any()) {
     1422          return simplifyTerms(remainingNode, scalarNode);
     1423        } else {
     1424          return MakeAggregationNode(remainingNode);
     1425        }
     1426      }
     1427
    13991428      if (IsScalarNode(node)) {
    1400         return node;
     1429        return simplifyScalar(GetSimplifiedTree(node));
    14011430      } else if (IsAddition(node) || IsSubtraction(node)) {
    14021431        var terms = node.Subtrees;
    14031432        if (IsSubtraction(node)) terms = InvertNodes(terms, Negate);
    1404 
    1405         var scalarTerms = terms.Where(IsScalarNode).ToList();
    1406         var remainingTerms = terms.Except(scalarTerms).ToList();
    1407 
    1408         if (scalarTerms.Any() && remainingTerms.Any()) {
    1409           var scalarNode = scalarTerms.Aggregate(MakeSum);
    1410           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1411 
    1412           var lengthNode = MakeLengthAggregation((ISymbolicExpressionTreeNode)vectorNode.Clone());
    1413           var scalarMulNode = MakeProduct(scalarNode, lengthNode);
    1414 
    1415           var sumNode = MakeSumAggregation(vectorNode);
    1416 
    1417           return MakeSum(scalarMulNode, sumNode);
    1418         } else if (scalarTerms.Any()) {
    1419           var scalarNode = scalarTerms.Aggregate(MakeSum);
    1420           return scalarNode;
    1421         } else if (remainingTerms.Any()) {
    1422           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1423           var sumNode = sumSymbol.CreateTreeNode();
    1424           sumNode.AddSubtree(vectorNode);
    1425           return sumNode;
    1426         } else
    1427           throw new InvalidOperationException("Addition does not contain any terms to simplify.");
     1433        return SimplifyArithmeticOperation(terms, MakeSum, simplifyAdditiveTerms);
    14281434      } else if (IsMultiplication(node) || IsDivision(node)) {
    14291435        var factors = node.Subtrees;
    14301436        if (IsDivision(node)) factors = InvertNodes(factors, Invert);
    1431 
    1432         var scalarFactors = factors.Where(IsScalarNode).ToList();
    1433         var remainingFactors = factors.Except(scalarFactors).ToList();
    1434 
    1435         if (scalarFactors.Any() && remainingFactors.Any()) {
    1436           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1437           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1438 
    1439           var sumNode = MakeSumAggregation(vectorNode);
    1440 
    1441           return MakeProduct(scalarNode, sumNode);
    1442         } else if (scalarFactors.Any()) {
    1443           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1444           return scalarNode;
    1445         } else if (remainingFactors.Any()) {
    1446           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1447           var sumNode = sumSymbol.CreateTreeNode();
    1448           sumNode.AddSubtree(vectorNode);
    1449           return sumNode;
    1450         } else
    1451           throw new InvalidOperationException("Multiplication does not contain any terms to simplify.");
    1452       } else if (IsVariableBase(node)) { // weight is like multiplication
    1453         var variableNode = (VariableTreeNodeBase)node;
     1437        return SimplifyArithmeticOperation(factors, MakeProduct, simplifyMultiplicativeFactors);
     1438      } else if (node is VariableTreeNodeBase variableNode && !variableNode.Weight.IsAlmost(1.0)) { // weight is like multiplication
    14541439        var weight = variableNode.Weight;
    14551440        variableNode.Weight = 1.0;
    1456         var sumNode = sumSymbol.CreateTreeNode();
    1457         sumNode.AddSubtree(node);
    1458         return MakeProduct(MakeConstant(weight), sumNode);
    1459       } else {
    1460         var sumNode = sumSymbol.CreateTreeNode();
    1461         sumNode.AddSubtree(node);
    1462         return sumNode;
    1463       }
     1441        var factors = new[] { variableNode, MakeConstant(weight) };
     1442        return SimplifyArithmeticOperation(factors, MakeProduct, simplifyMultiplicativeFactors);
     1443      } else {
     1444        return MakeAggregationNode(node);
     1445      }
     1446    }
     1447
     1448    private ISymbolicExpressionTreeNode MakeSumAggregation(ISymbolicExpressionTreeNode node) {
     1449      return MakeAggregation(node, sumSymbol,
     1450        simplifyScalar: n => n,
     1451        simplifyAdditiveTerms: (vectorNode, scalarNode) => {
     1452          var lengthNode = MakeLengthAggregation(vectorNode);
     1453          var scalarMulNode = MakeProduct(scalarNode, lengthNode);
     1454          var sumNode = MakeSumAggregation(vectorNode);
     1455          return MakeSum(scalarMulNode, sumNode);
     1456        },
     1457        simplifyMultiplicativeFactors: (vectorNode, scalarNode) => {
     1458          var sumNode = MakeSumAggregation(vectorNode);
     1459          return MakeProduct(scalarNode, sumNode);
     1460        });
    14641461    }
    14651462
    14661463    private ISymbolicExpressionTreeNode MakeMeanAggregation(ISymbolicExpressionTreeNode node) {
    1467       if (IsScalarNode(node)) {
    1468         return node;
    1469       } else if (IsAddition(node) || IsSubtraction(node)) {
    1470         var terms = node.Subtrees;
    1471         if (IsSubtraction(node)) terms = InvertNodes(terms, Negate);
    1472 
    1473         var scalarTerms = terms.Where(IsScalarNode).ToList();
    1474         var remainingTerms = terms.Except(scalarTerms).ToList();
    1475 
    1476         if (scalarTerms.Any() && remainingTerms.Any()) {
    1477           var scalarNode = scalarTerms.Aggregate(MakeSum);
    1478           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1479 
     1464      return MakeAggregation(node, meanSymbol,
     1465        simplifyScalar: n => n,
     1466        simplifyAdditiveTerms: (vectorNode, scalarNode) => {
    14801467          var meanNode = MakeMeanAggregation(vectorNode);
    1481 
    14821468          return MakeSum(scalarNode, meanNode);
    1483         } else if (scalarTerms.Any()) {
    1484           var scalarNode = scalarTerms.Aggregate(MakeSum);
    1485           return scalarNode;
    1486         } else if (remainingTerms.Any()) {
    1487           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1488           var meanNode = meanSymbol.CreateTreeNode();
    1489           meanNode.AddSubtree(vectorNode);
    1490           return meanNode;
    1491         } else
    1492           throw new InvalidOperationException("Addition does not contain any terms to simplify.");
    1493       } else if (IsMultiplication(node) || IsDivision(node)) {
    1494         var factors = node.Subtrees;
    1495         if (IsDivision(node)) factors = InvertNodes(factors, Invert);
    1496 
    1497         var scalarFactors = factors.Where(IsScalarNode).ToList();
    1498         var remainingFactors = factors.Except(scalarFactors).ToList();
    1499 
    1500         if (scalarFactors.Any() && remainingFactors.Any()) {
    1501           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1502           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1503 
     1469        },
     1470        simplifyMultiplicativeFactors: (vectorNode, scalarNode) => {
    15041471          var meanNode = MakeMeanAggregation(vectorNode);
    1505 
    15061472          return MakeProduct(scalarNode, meanNode);
    1507         } else if (scalarFactors.Any()) {
    1508           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1509           return scalarNode;
    1510         } else if (remainingFactors.Any()) {
    1511           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1512           var meanNode = meanSymbol.CreateTreeNode();
    1513           meanNode.AddSubtree(vectorNode);
    1514           return meanNode;
    1515         } else
    1516           throw new InvalidOperationException("Multiplication does not contain any terms to simplify.");
    1517       } else if (IsVariableBase(node)) { // weight is like multiplication
    1518         var variableNode = (VariableTreeNodeBase)node;
    1519         var weight = variableNode.Weight;
    1520         variableNode.Weight = 1.0;
    1521         var meanNode = meanSymbol.CreateTreeNode();
    1522         meanNode.AddSubtree(node);
    1523         return MakeProduct(MakeConstant(weight), meanNode);
    1524       } else {
    1525         var meanNode = meanSymbol.CreateTreeNode();
    1526         meanNode.AddSubtree(node);
    1527         return meanNode;
    1528       }
     1473        });
    15291474    }
    15301475
    15311476    private ISymbolicExpressionTreeNode MakeLengthAggregation(ISymbolicExpressionTreeNode node) {
    1532       if (IsScalarNode(node)) {
    1533         return MakeConstant(1.0);
    1534       } else if (IsAddition(node) || IsSubtraction(node)) {
    1535         var terms = node.Subtrees;
    1536         if (IsSubtraction(node)) terms = InvertNodes(terms, Negate);
    1537 
    1538         var scalarTerms = terms.Where(IsScalarNode).ToList();
    1539         var remainingTerms = terms.Except(scalarTerms).ToList();
    1540 
    1541         if (remainingTerms.Any()) {
    1542           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1543 
    1544           var lengthNode = lengthSymbol.CreateTreeNode();
    1545           lengthNode.AddSubtree(vectorNode);
    1546 
    1547           return lengthNode;
    1548         } else if (scalarTerms.Any()) {
    1549           return MakeConstant(1.0);
    1550         } else
    1551           throw new InvalidOperationException("Addition does not contain any terms to simplify.");
    1552       } else if (IsMultiplication(node) || IsDivision(node)) {
    1553         var factors = node.Subtrees;
    1554         if (IsDivision(node)) factors = InvertNodes(factors, Invert);
    1555 
    1556         var scalarFactors = factors.Where(IsScalarNode).ToList();
    1557         var remainingFactors = factors.Except(scalarFactors).ToList();
    1558 
    1559         if (remainingFactors.Any()) {
    1560           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1561 
    1562           var lengthNode = lengthSymbol.CreateTreeNode();
    1563           lengthNode.AddSubtree(vectorNode);
    1564 
    1565           return lengthNode;
    1566         } else if (scalarFactors.Any()) {
    1567           return MakeConstant(1.0);
    1568         } else
    1569           throw new InvalidOperationException("Multiplication does not contain any terms to simplify.");
    1570       } else if (IsVariableBase(node)) { // weight is like multiplication
    1571         var variableNode = (VariableTreeNodeBase)node;
    1572         variableNode.Weight = 1.0;
    1573         var lengthNode = lengthSymbol.CreateTreeNode();
    1574         lengthNode.AddSubtree(node);
    1575         return lengthNode;
    1576       } else {
    1577         var lengthNode = lengthSymbol.CreateTreeNode();
    1578         lengthNode.AddSubtree(node);
    1579         return lengthNode;
    1580       }
     1477      return MakeAggregation(node, lengthSymbol,
     1478        simplifyScalar: _ => MakeConstant(1.0),
     1479        simplifyAdditiveTerms: (vectorNode, _) => {
     1480          return MakeLengthAggregation(vectorNode);
     1481        },
     1482        simplifyMultiplicativeFactors: (vectorNode, _) => {
     1483          return MakeLengthAggregation(vectorNode);
     1484        });
    15811485    }
    15821486
    15831487    private ISymbolicExpressionTreeNode MakeStandardDeviationAggregation(ISymbolicExpressionTreeNode node) {
    1584       if (IsScalarNode(node)) {
    1585         return MakeConstant(0.0);
    1586       } else if (IsAddition(node) || IsSubtraction(node)) { // scalars drop out
    1587         var terms = node.Subtrees;
    1588 
    1589         var scalarTerms = terms.Where(IsScalarNode).ToList();
    1590         var remainingTerms = terms.Except(scalarTerms).ToList();
    1591 
    1592         if (remainingTerms.Any()) {
    1593           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1594 
    1595           var stdevNode = standardDeviationSymbol.CreateTreeNode();
    1596           stdevNode.AddSubtree(vectorNode);
    1597 
    1598           return stdevNode;
    1599         } else if (scalarTerms.Any()) {
    1600           return MakeConstant(0.0);
    1601         } else
    1602           throw new InvalidOperationException("Addition does not contain any terms to simplify.");
    1603       } else if (IsMultiplication(node) || IsDivision(node)) {
    1604         var factors = node.Subtrees;
    1605         if (IsDivision(node)) factors = InvertNodes(factors, Invert);
    1606 
    1607         var scalarFactors = factors.Where(IsScalarNode).ToList();
    1608         var remainingFactors = factors.Except(scalarFactors).ToList();
    1609 
    1610         if (scalarFactors.Any() && remainingFactors.Any()) {
    1611           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1612           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1613 
     1488      return MakeAggregation(node, standardDeviationSymbol,
     1489        simplifyScalar: _ => MakeConstant(0.0),
     1490        simplifyAdditiveTerms: (vectorNode, _) => {
     1491          return MakeStandardDeviationAggregation(vectorNode);
     1492        },
     1493        simplifyMultiplicativeFactors: (vectorNode, scalarNode) => {
    16141494          var stdevNode = MakeStandardDeviationAggregation(vectorNode);
    1615 
    16161495          return MakeProduct(scalarNode, stdevNode);
    1617         } else if (scalarFactors.Any()) {
    1618           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1619           return scalarNode;
    1620         } else if (remainingFactors.Any()) {
    1621           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1622           var stdevNode = standardDeviationSymbol.CreateTreeNode();
    1623           stdevNode.AddSubtree(vectorNode);
    1624           return stdevNode;
    1625         } else
    1626           throw new InvalidOperationException("Multiplication does not contain any terms to simplify.");
    1627       } else if (IsVariableBase(node)) { // weight is like multiplication
    1628         var variableNode = (VariableTreeNodeBase)node;
    1629         var weight = variableNode.Weight;
    1630         variableNode.Weight = 1.0;
    1631         var stdevNode = standardDeviationSymbol.CreateTreeNode();
    1632         stdevNode.AddSubtree(node);
    1633         return MakeProduct(MakeConstant(weight), stdevNode);
    1634       } else {
    1635         var stdevNode = standardDeviationSymbol.CreateTreeNode();
    1636         stdevNode.AddSubtree(node);
    1637         return stdevNode;
    1638       }
     1496        });
    16391497    }
    16401498
    16411499    private ISymbolicExpressionTreeNode MakeVarianceAggregation(ISymbolicExpressionTreeNode node) {
    1642       if (IsScalarNode(node)) {
    1643         return MakeConstant(0.0);
    1644       } else if (IsAddition(node) || IsSubtraction(node)) { // scalars drop out
    1645         var terms = node.Subtrees;
    1646 
    1647         var scalarTerms = terms.Where(IsScalarNode).ToList();
    1648         var remainingTerms = terms.Except(scalarTerms).ToList();
    1649 
    1650         if (remainingTerms.Any()) {
    1651           var vectorNode = remainingTerms.Aggregate(MakeSum);
    1652 
    1653           var varNode = varianceSymbol.CreateTreeNode();
    1654           varNode.AddSubtree(vectorNode);
    1655 
    1656           return varNode;
    1657         } else if (scalarTerms.Any()) {
    1658           return MakeConstant(0.0);
    1659         } else
    1660           throw new InvalidOperationException("Addition does not contain any terms to simplify.");
    1661       } else if (IsMultiplication(node) || IsDivision(node)) {
    1662         var factors = node.Subtrees;
    1663         if (IsDivision(node)) factors = InvertNodes(factors, Invert);
    1664 
    1665         var scalarFactors = factors.Where(IsScalarNode).ToList();
    1666         var remainingFactors = factors.Except(scalarFactors).ToList();
    1667 
    1668         if (scalarFactors.Any() && remainingFactors.Any()) {
    1669           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1670           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1671 
     1500      return MakeAggregation(node, varianceSymbol,
     1501        simplifyScalar: _ => MakeConstant(0.0),
     1502        simplifyAdditiveTerms: (vectorNode, _) => {
     1503          return MakeVarianceAggregation(vectorNode);
     1504        },
     1505        simplifyMultiplicativeFactors: (vectorNode, scalarNode) => {
    16721506          var varNode = MakeVarianceAggregation(vectorNode);
    1673 
    16741507          return MakeProduct(MakeSquare(scalarNode), varNode);
    1675         } else if (scalarFactors.Any()) {
    1676           var scalarNode = scalarFactors.Aggregate(MakeProduct);
    1677           return MakeSquare(scalarNode);
    1678         } else if (remainingFactors.Any()) {
    1679           var vectorNode = remainingFactors.Aggregate(MakeProduct);
    1680           var varNode = varianceSymbol.CreateTreeNode();
    1681           varNode.AddSubtree(vectorNode);
    1682           return varNode;
    1683         } else
    1684           throw new InvalidOperationException("Multiplication does not contain any terms to simplify.");
    1685       } else if (IsVariableBase(node)) { // weight is like multiplication
    1686         var variableNode = (VariableTreeNodeBase)node;
    1687         var weight = variableNode.Weight;
    1688         variableNode.Weight = 1.0;
    1689         var varNode = varianceSymbol.CreateTreeNode();
    1690         varNode.AddSubtree(node);
    1691         return MakeProduct(MakeSquare(MakeConstant(weight)), varNode);
    1692       } else {
    1693         var varNode = varianceSymbol.CreateTreeNode();
    1694         varNode.AddSubtree(node);
    1695         return varNode;
    1696       }
     1508        });
    16971509    }
    16981510    #endregion
Note: See TracChangeset for help on using the changeset viewer.