Free cookie consent management tool by TermsFeed Policy Generator

# Changeset 17626

Ignore:
Timestamp:
06/23/20 10:05:13 (4 years ago)
Message:

#3040 Unified simplification rules for vector aggregation functions.

File:
1 edited

Unmodified
Added
Removed
• ## branches/3040_VectorBasedGP/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/Converters/VectorTreeSimplifier.cs

 r17604 } private ISymbolicExpressionTreeNode MakeSumAggregation(ISymbolicExpressionTreeNode node) { delegate ISymbolicExpressionTreeNode ScalarSimplifier(ISymbolicExpressionTreeNode scalar); delegate ISymbolicExpressionTreeNode VectorSimplifier(ISymbolicExpressionTreeNode vectorPart, ISymbolicExpressionTreeNode scalarPart); private ISymbolicExpressionTreeNode MakeAggregation(ISymbolicExpressionTreeNode node, Symbol aggregationSymbol, ScalarSimplifier simplifyScalar, VectorSimplifier simplifyAdditiveTerms, VectorSimplifier simplifyMultiplicativeFactors) { ISymbolicExpressionTreeNode MakeAggregationNode(ISymbolicExpressionTreeNode remainingNode) { var aggregationNode = aggregationSymbol.CreateTreeNode(); aggregationNode.AddSubtree(GetSimplifiedTree(remainingNode)); return aggregationNode; } ISymbolicExpressionTreeNode SimplifyArithmeticOperation(IEnumerable terms, Func aggregation, VectorSimplifier simplifyTerms) { var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); var scalarNode = scalarTerms.Any() ? scalarTerms.Aggregate(aggregation) : null; var remainingNode = remainingTerms.Aggregate(aggregation); // at least one term is remaining, otherwise "node" would have been scalar and earlier rule had applied if (scalarTerms.Any()) { return simplifyTerms(remainingNode, scalarNode); } else { return MakeAggregationNode(remainingNode); } } if (IsScalarNode(node)) { return node; return simplifyScalar(GetSimplifiedTree(node)); } else if (IsAddition(node) || IsSubtraction(node)) { var terms = node.Subtrees; if (IsSubtraction(node)) terms = InvertNodes(terms, Negate); var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); if (scalarTerms.Any() && remainingTerms.Any()) { var scalarNode = scalarTerms.Aggregate(MakeSum); var vectorNode = remainingTerms.Aggregate(MakeSum); var lengthNode = MakeLengthAggregation((ISymbolicExpressionTreeNode)vectorNode.Clone()); var scalarMulNode = MakeProduct(scalarNode, lengthNode); var sumNode = MakeSumAggregation(vectorNode); return MakeSum(scalarMulNode, sumNode); } else if (scalarTerms.Any()) { var scalarNode = scalarTerms.Aggregate(MakeSum); return scalarNode; } else if (remainingTerms.Any()) { var vectorNode = remainingTerms.Aggregate(MakeSum); var sumNode = sumSymbol.CreateTreeNode(); sumNode.AddSubtree(vectorNode); return sumNode; } else throw new InvalidOperationException("Addition does not contain any terms to simplify."); return SimplifyArithmeticOperation(terms, MakeSum, simplifyAdditiveTerms); } else if (IsMultiplication(node) || IsDivision(node)) { var factors = node.Subtrees; if (IsDivision(node)) factors = InvertNodes(factors, Invert); var scalarFactors = factors.Where(IsScalarNode).ToList(); var remainingFactors = factors.Except(scalarFactors).ToList(); if (scalarFactors.Any() && remainingFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); var vectorNode = remainingFactors.Aggregate(MakeProduct); var sumNode = MakeSumAggregation(vectorNode); return MakeProduct(scalarNode, sumNode); } else if (scalarFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); return scalarNode; } else if (remainingFactors.Any()) { var vectorNode = remainingFactors.Aggregate(MakeProduct); var sumNode = sumSymbol.CreateTreeNode(); sumNode.AddSubtree(vectorNode); return sumNode; } else throw new InvalidOperationException("Multiplication does not contain any terms to simplify."); } else if (IsVariableBase(node)) { // weight is like multiplication var variableNode = (VariableTreeNodeBase)node; return SimplifyArithmeticOperation(factors, MakeProduct, simplifyMultiplicativeFactors); } else if (node is VariableTreeNodeBase variableNode && !variableNode.Weight.IsAlmost(1.0)) { // weight is like multiplication var weight = variableNode.Weight; variableNode.Weight = 1.0; var sumNode = sumSymbol.CreateTreeNode(); sumNode.AddSubtree(node); return MakeProduct(MakeConstant(weight), sumNode); } else { var sumNode = sumSymbol.CreateTreeNode(); sumNode.AddSubtree(node); return sumNode; } var factors = new[] { variableNode, MakeConstant(weight) }; return SimplifyArithmeticOperation(factors, MakeProduct, simplifyMultiplicativeFactors); } else { return MakeAggregationNode(node); } } private ISymbolicExpressionTreeNode MakeSumAggregation(ISymbolicExpressionTreeNode node) { return MakeAggregation(node, sumSymbol, simplifyScalar: n => n, simplifyAdditiveTerms: (vectorNode, scalarNode) => { var lengthNode = MakeLengthAggregation(vectorNode); var scalarMulNode = MakeProduct(scalarNode, lengthNode); var sumNode = MakeSumAggregation(vectorNode); return MakeSum(scalarMulNode, sumNode); }, simplifyMultiplicativeFactors: (vectorNode, scalarNode) => { var sumNode = MakeSumAggregation(vectorNode); return MakeProduct(scalarNode, sumNode); }); } private ISymbolicExpressionTreeNode MakeMeanAggregation(ISymbolicExpressionTreeNode node) { if (IsScalarNode(node)) { return node; } else if (IsAddition(node) || IsSubtraction(node)) { var terms = node.Subtrees; if (IsSubtraction(node)) terms = InvertNodes(terms, Negate); var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); if (scalarTerms.Any() && remainingTerms.Any()) { var scalarNode = scalarTerms.Aggregate(MakeSum); var vectorNode = remainingTerms.Aggregate(MakeSum); return MakeAggregation(node, meanSymbol, simplifyScalar: n => n, simplifyAdditiveTerms: (vectorNode, scalarNode) => { var meanNode = MakeMeanAggregation(vectorNode); return MakeSum(scalarNode, meanNode); } else if (scalarTerms.Any()) { var scalarNode = scalarTerms.Aggregate(MakeSum); return scalarNode; } else if (remainingTerms.Any()) { var vectorNode = remainingTerms.Aggregate(MakeSum); var meanNode = meanSymbol.CreateTreeNode(); meanNode.AddSubtree(vectorNode); return meanNode; } else throw new InvalidOperationException("Addition does not contain any terms to simplify."); } else if (IsMultiplication(node) || IsDivision(node)) { var factors = node.Subtrees; if (IsDivision(node)) factors = InvertNodes(factors, Invert); var scalarFactors = factors.Where(IsScalarNode).ToList(); var remainingFactors = factors.Except(scalarFactors).ToList(); if (scalarFactors.Any() && remainingFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); var vectorNode = remainingFactors.Aggregate(MakeProduct); }, simplifyMultiplicativeFactors: (vectorNode, scalarNode) => { var meanNode = MakeMeanAggregation(vectorNode); return MakeProduct(scalarNode, meanNode); } else if (scalarFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); return scalarNode; } else if (remainingFactors.Any()) { var vectorNode = remainingFactors.Aggregate(MakeProduct); var meanNode = meanSymbol.CreateTreeNode(); meanNode.AddSubtree(vectorNode); return meanNode; } else throw new InvalidOperationException("Multiplication does not contain any terms to simplify."); } else if (IsVariableBase(node)) { // weight is like multiplication var variableNode = (VariableTreeNodeBase)node; var weight = variableNode.Weight; variableNode.Weight = 1.0; var meanNode = meanSymbol.CreateTreeNode(); meanNode.AddSubtree(node); return MakeProduct(MakeConstant(weight), meanNode); } else { var meanNode = meanSymbol.CreateTreeNode(); meanNode.AddSubtree(node); return meanNode; } }); } private ISymbolicExpressionTreeNode MakeLengthAggregation(ISymbolicExpressionTreeNode node) { if (IsScalarNode(node)) { return MakeConstant(1.0); } else if (IsAddition(node) || IsSubtraction(node)) { var terms = node.Subtrees; if (IsSubtraction(node)) terms = InvertNodes(terms, Negate); var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); if (remainingTerms.Any()) { var vectorNode = remainingTerms.Aggregate(MakeSum); var lengthNode = lengthSymbol.CreateTreeNode(); lengthNode.AddSubtree(vectorNode); return lengthNode; } else if (scalarTerms.Any()) { return MakeConstant(1.0); } else throw new InvalidOperationException("Addition does not contain any terms to simplify."); } else if (IsMultiplication(node) || IsDivision(node)) { var factors = node.Subtrees; if (IsDivision(node)) factors = InvertNodes(factors, Invert); var scalarFactors = factors.Where(IsScalarNode).ToList(); var remainingFactors = factors.Except(scalarFactors).ToList(); if (remainingFactors.Any()) { var vectorNode = remainingFactors.Aggregate(MakeProduct); var lengthNode = lengthSymbol.CreateTreeNode(); lengthNode.AddSubtree(vectorNode); return lengthNode; } else if (scalarFactors.Any()) { return MakeConstant(1.0); } else throw new InvalidOperationException("Multiplication does not contain any terms to simplify."); } else if (IsVariableBase(node)) { // weight is like multiplication var variableNode = (VariableTreeNodeBase)node; variableNode.Weight = 1.0; var lengthNode = lengthSymbol.CreateTreeNode(); lengthNode.AddSubtree(node); return lengthNode; } else { var lengthNode = lengthSymbol.CreateTreeNode(); lengthNode.AddSubtree(node); return lengthNode; } return MakeAggregation(node, lengthSymbol, simplifyScalar: _ => MakeConstant(1.0), simplifyAdditiveTerms: (vectorNode, _) => { return MakeLengthAggregation(vectorNode); }, simplifyMultiplicativeFactors: (vectorNode, _) => { return MakeLengthAggregation(vectorNode); }); } private ISymbolicExpressionTreeNode MakeStandardDeviationAggregation(ISymbolicExpressionTreeNode node) { if (IsScalarNode(node)) { return MakeConstant(0.0); } else if (IsAddition(node) || IsSubtraction(node)) { // scalars drop out var terms = node.Subtrees; var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); if (remainingTerms.Any()) { var vectorNode = remainingTerms.Aggregate(MakeSum); var stdevNode = standardDeviationSymbol.CreateTreeNode(); stdevNode.AddSubtree(vectorNode); return stdevNode; } else if (scalarTerms.Any()) { return MakeConstant(0.0); } else throw new InvalidOperationException("Addition does not contain any terms to simplify."); } else if (IsMultiplication(node) || IsDivision(node)) { var factors = node.Subtrees; if (IsDivision(node)) factors = InvertNodes(factors, Invert); var scalarFactors = factors.Where(IsScalarNode).ToList(); var remainingFactors = factors.Except(scalarFactors).ToList(); if (scalarFactors.Any() && remainingFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); var vectorNode = remainingFactors.Aggregate(MakeProduct); return MakeAggregation(node, standardDeviationSymbol, simplifyScalar: _ => MakeConstant(0.0), simplifyAdditiveTerms: (vectorNode, _) => { return MakeStandardDeviationAggregation(vectorNode); }, simplifyMultiplicativeFactors: (vectorNode, scalarNode) => { var stdevNode = MakeStandardDeviationAggregation(vectorNode); return MakeProduct(scalarNode, stdevNode); } else if (scalarFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); return scalarNode; } else if (remainingFactors.Any()) { var vectorNode = remainingFactors.Aggregate(MakeProduct); var stdevNode = standardDeviationSymbol.CreateTreeNode(); stdevNode.AddSubtree(vectorNode); return stdevNode; } else throw new InvalidOperationException("Multiplication does not contain any terms to simplify."); } else if (IsVariableBase(node)) { // weight is like multiplication var variableNode = (VariableTreeNodeBase)node; var weight = variableNode.Weight; variableNode.Weight = 1.0; var stdevNode = standardDeviationSymbol.CreateTreeNode(); stdevNode.AddSubtree(node); return MakeProduct(MakeConstant(weight), stdevNode); } else { var stdevNode = standardDeviationSymbol.CreateTreeNode(); stdevNode.AddSubtree(node); return stdevNode; } }); } private ISymbolicExpressionTreeNode MakeVarianceAggregation(ISymbolicExpressionTreeNode node) { if (IsScalarNode(node)) { return MakeConstant(0.0); } else if (IsAddition(node) || IsSubtraction(node)) { // scalars drop out var terms = node.Subtrees; var scalarTerms = terms.Where(IsScalarNode).ToList(); var remainingTerms = terms.Except(scalarTerms).ToList(); if (remainingTerms.Any()) { var vectorNode = remainingTerms.Aggregate(MakeSum); var varNode = varianceSymbol.CreateTreeNode(); varNode.AddSubtree(vectorNode); return varNode; } else if (scalarTerms.Any()) { return MakeConstant(0.0); } else throw new InvalidOperationException("Addition does not contain any terms to simplify."); } else if (IsMultiplication(node) || IsDivision(node)) { var factors = node.Subtrees; if (IsDivision(node)) factors = InvertNodes(factors, Invert); var scalarFactors = factors.Where(IsScalarNode).ToList(); var remainingFactors = factors.Except(scalarFactors).ToList(); if (scalarFactors.Any() && remainingFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); var vectorNode = remainingFactors.Aggregate(MakeProduct); return MakeAggregation(node, varianceSymbol, simplifyScalar: _ => MakeConstant(0.0), simplifyAdditiveTerms: (vectorNode, _) => { return MakeVarianceAggregation(vectorNode); }, simplifyMultiplicativeFactors: (vectorNode, scalarNode) => { var varNode = MakeVarianceAggregation(vectorNode); return MakeProduct(MakeSquare(scalarNode), varNode); } else if (scalarFactors.Any()) { var scalarNode = scalarFactors.Aggregate(MakeProduct); return MakeSquare(scalarNode); } else if (remainingFactors.Any()) { var vectorNode = remainingFactors.Aggregate(MakeProduct); var varNode = varianceSymbol.CreateTreeNode(); varNode.AddSubtree(vectorNode); return varNode; } else throw new InvalidOperationException("Multiplication does not contain any terms to simplify."); } else if (IsVariableBase(node)) { // weight is like multiplication var variableNode = (VariableTreeNodeBase)node; var weight = variableNode.Weight; variableNode.Weight = 1.0; var varNode = varianceSymbol.CreateTreeNode(); varNode.AddSubtree(node); return MakeProduct(MakeSquare(MakeConstant(weight)), varNode); } else { var varNode = varianceSymbol.CreateTreeNode(); varNode.AddSubtree(node); return varNode; } }); } #endregion
Note: See TracChangeset for help on using the changeset viewer.