18a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 28a1ca2cdSRiver Riddle // 38a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 48a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 58a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 68a1ca2cdSRiver Riddle // 78a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 88a1ca2cdSRiver Riddle 98a1ca2cdSRiver Riddle #include "PredicateTree.h" 108a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h" 118a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 128a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 13*65fcddffSRiver Riddle #include "mlir/IR/BuiltinOps.h" 148a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h" 158a1ca2cdSRiver Riddle 168a1ca2cdSRiver Riddle using namespace mlir; 178a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp; 188a1ca2cdSRiver Riddle 198a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 208a1ca2cdSRiver Riddle // Predicate List Building 218a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 228a1ca2cdSRiver Riddle 238a1ca2cdSRiver Riddle /// Compares the depths of two positions. 248a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) { 258a1ca2cdSRiver Riddle return lhs->getIndex().size() < rhs->getIndex().size(); 268a1ca2cdSRiver Riddle } 278a1ca2cdSRiver Riddle 288a1ca2cdSRiver Riddle /// Collect the tree predicates anchored at the given value. 298a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 308a1ca2cdSRiver Riddle Value val, PredicateBuilder &builder, 318a1ca2cdSRiver Riddle DenseMap<Value, Position *> &inputs, 328a1ca2cdSRiver Riddle Position *pos) { 338a1ca2cdSRiver Riddle // Make sure this input value is accessible to the rewrite. 348a1ca2cdSRiver Riddle auto it = inputs.try_emplace(val, pos); 358a1ca2cdSRiver Riddle 368a1ca2cdSRiver Riddle // If this is an input value that has been visited in the tree, add a 378a1ca2cdSRiver Riddle // constraint to ensure that both instances refer to the same value. 388a1ca2cdSRiver Riddle if (!it.second && 398a1ca2cdSRiver Riddle isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) { 408a1ca2cdSRiver Riddle auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth); 418a1ca2cdSRiver Riddle predList.emplace_back(minMaxPositions.second, 428a1ca2cdSRiver Riddle builder.getEqualTo(minMaxPositions.first)); 438a1ca2cdSRiver Riddle return; 448a1ca2cdSRiver Riddle } 458a1ca2cdSRiver Riddle 468a1ca2cdSRiver Riddle // Check for a per-position predicate to apply. 478a1ca2cdSRiver Riddle switch (pos->getKind()) { 488a1ca2cdSRiver Riddle case Predicates::AttributePos: { 498a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::AttributeType>() && 508a1ca2cdSRiver Riddle "expected attribute type"); 518a1ca2cdSRiver Riddle pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 528a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 538a1ca2cdSRiver Riddle 548a1ca2cdSRiver Riddle // If the attribute has a type, add a type constraint. 558a1ca2cdSRiver Riddle if (Value type = attr.type()) { 568a1ca2cdSRiver Riddle getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 578a1ca2cdSRiver Riddle 588a1ca2cdSRiver Riddle // Check for a constant value of the attribute. 598a1ca2cdSRiver Riddle } else if (Optional<Attribute> value = attr.value()) { 608a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getAttributeConstraint(*value)); 618a1ca2cdSRiver Riddle } 628a1ca2cdSRiver Riddle break; 638a1ca2cdSRiver Riddle } 648a1ca2cdSRiver Riddle case Predicates::OperandPos: { 658a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 668a1ca2cdSRiver Riddle 678a1ca2cdSRiver Riddle // Prevent traversal into a null value. 688a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 698a1ca2cdSRiver Riddle 708a1ca2cdSRiver Riddle // If this is a typed input, add a type constraint. 718a1ca2cdSRiver Riddle if (auto in = val.getDefiningOp<pdl::InputOp>()) { 728a1ca2cdSRiver Riddle if (Value type = in.type()) { 738a1ca2cdSRiver Riddle getTreePredicates(predList, type, builder, inputs, 748a1ca2cdSRiver Riddle builder.getType(pos)); 758a1ca2cdSRiver Riddle } 768a1ca2cdSRiver Riddle 778a1ca2cdSRiver Riddle // Otherwise, recurse into the parent node. 788a1ca2cdSRiver Riddle } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) { 798a1ca2cdSRiver Riddle getTreePredicates(predList, parentOp.op(), builder, inputs, 808a1ca2cdSRiver Riddle builder.getParent(cast<OperandPosition>(pos))); 818a1ca2cdSRiver Riddle } 828a1ca2cdSRiver Riddle break; 838a1ca2cdSRiver Riddle } 848a1ca2cdSRiver Riddle case Predicates::OperationPos: { 858a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 868a1ca2cdSRiver Riddle pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 878a1ca2cdSRiver Riddle OperationPosition *opPos = cast<OperationPosition>(pos); 888a1ca2cdSRiver Riddle 898a1ca2cdSRiver Riddle // Ensure getDefiningOp returns a non-null operation. 908a1ca2cdSRiver Riddle if (!opPos->isRoot()) 918a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 928a1ca2cdSRiver Riddle 938a1ca2cdSRiver Riddle // Check that this is the correct root operation. 948a1ca2cdSRiver Riddle if (Optional<StringRef> opName = op.name()) 958a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getOperationName(*opName)); 968a1ca2cdSRiver Riddle 978a1ca2cdSRiver Riddle // Check that the operation has the proper number of operands and results. 988a1ca2cdSRiver Riddle OperandRange operands = op.operands(); 998a1ca2cdSRiver Riddle ResultRange results = op.results(); 1008a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getOperandCount(operands.size())); 1018a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getResultCount(results.size())); 1028a1ca2cdSRiver Riddle 1038a1ca2cdSRiver Riddle // Recurse into any attributes, operands, or results. 1048a1ca2cdSRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 1058a1ca2cdSRiver Riddle getTreePredicates( 1068a1ca2cdSRiver Riddle predList, std::get<1>(it), builder, inputs, 1078a1ca2cdSRiver Riddle builder.getAttribute(opPos, 1088a1ca2cdSRiver Riddle std::get<0>(it).cast<StringAttr>().getValue())); 1098a1ca2cdSRiver Riddle } 1108a1ca2cdSRiver Riddle for (auto operandIt : llvm::enumerate(operands)) 1118a1ca2cdSRiver Riddle getTreePredicates(predList, operandIt.value(), builder, inputs, 1128a1ca2cdSRiver Riddle builder.getOperand(opPos, operandIt.index())); 1138a1ca2cdSRiver Riddle 1148a1ca2cdSRiver Riddle // Only recurse into results that are not referenced in the source tree. 1158a1ca2cdSRiver Riddle for (auto resultIt : llvm::enumerate(results)) { 1168a1ca2cdSRiver Riddle getTreePredicates(predList, resultIt.value(), builder, inputs, 1178a1ca2cdSRiver Riddle builder.getResult(opPos, resultIt.index())); 1188a1ca2cdSRiver Riddle } 1198a1ca2cdSRiver Riddle break; 1208a1ca2cdSRiver Riddle } 1218a1ca2cdSRiver Riddle case Predicates::ResultPos: { 1228a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 1238a1ca2cdSRiver Riddle pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp()); 1248a1ca2cdSRiver Riddle 1258a1ca2cdSRiver Riddle // Prevent traversing a null value. 1268a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 1278a1ca2cdSRiver Riddle 1288a1ca2cdSRiver Riddle // Traverse the type constraint. 1298a1ca2cdSRiver Riddle unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber(); 1308a1ca2cdSRiver Riddle getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs, 1318a1ca2cdSRiver Riddle builder.getType(pos)); 1328a1ca2cdSRiver Riddle break; 1338a1ca2cdSRiver Riddle } 1348a1ca2cdSRiver Riddle case Predicates::TypePos: { 1358a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::TypeType>() && "expected value type"); 1368a1ca2cdSRiver Riddle pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp()); 1378a1ca2cdSRiver Riddle 1388a1ca2cdSRiver Riddle // Check for a constraint on a constant type. 1398a1ca2cdSRiver Riddle if (Optional<Type> type = typeOp.type()) 1408a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getTypeConstraint(*type)); 1418a1ca2cdSRiver Riddle break; 1428a1ca2cdSRiver Riddle } 1438a1ca2cdSRiver Riddle default: 1448a1ca2cdSRiver Riddle llvm_unreachable("unknown position kind"); 1458a1ca2cdSRiver Riddle } 1468a1ca2cdSRiver Riddle } 1478a1ca2cdSRiver Riddle 1488a1ca2cdSRiver Riddle /// Collect all of the predicates related to constraints within the given 1498a1ca2cdSRiver Riddle /// pattern operation. 1508a1ca2cdSRiver Riddle static void collectConstraintPredicates( 1518a1ca2cdSRiver Riddle pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList, 1528a1ca2cdSRiver Riddle PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) { 1538a1ca2cdSRiver Riddle for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) { 1548a1ca2cdSRiver Riddle OperandRange arguments = op.args(); 1558a1ca2cdSRiver Riddle ArrayAttr parameters = op.constParamsAttr(); 1568a1ca2cdSRiver Riddle 1578a1ca2cdSRiver Riddle std::vector<Position *> allPositions; 1588a1ca2cdSRiver Riddle allPositions.reserve(arguments.size()); 1598a1ca2cdSRiver Riddle for (Value arg : arguments) 1608a1ca2cdSRiver Riddle allPositions.push_back(inputs.lookup(arg)); 1618a1ca2cdSRiver Riddle 1628a1ca2cdSRiver Riddle // Push the constraint to the furthest position. 1638a1ca2cdSRiver Riddle Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 1648a1ca2cdSRiver Riddle comparePosDepth); 1658a1ca2cdSRiver Riddle PredicateBuilder::Predicate pred = 1668a1ca2cdSRiver Riddle builder.getConstraint(op.name(), std::move(allPositions), parameters); 1678a1ca2cdSRiver Riddle predList.emplace_back(pos, pred); 1688a1ca2cdSRiver Riddle } 1698a1ca2cdSRiver Riddle } 1708a1ca2cdSRiver Riddle 1718a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to 1728a1ca2cdSRiver Riddle /// match this pattern. 1738a1ca2cdSRiver Riddle static void buildPredicateList(pdl::PatternOp pattern, 1748a1ca2cdSRiver Riddle PredicateBuilder &builder, 1758a1ca2cdSRiver Riddle std::vector<PositionalPredicate> &predList, 1768a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 1778a1ca2cdSRiver Riddle getTreePredicates(predList, pattern.getRewriter().root(), builder, 1788a1ca2cdSRiver Riddle valueToPosition, builder.getRoot()); 1798a1ca2cdSRiver Riddle collectConstraintPredicates(pattern, predList, builder, valueToPosition); 1808a1ca2cdSRiver Riddle } 1818a1ca2cdSRiver Riddle 1828a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1838a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging 1848a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 1858a1ca2cdSRiver Riddle 1868a1ca2cdSRiver Riddle namespace { 1878a1ca2cdSRiver Riddle 1888a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and 1898a1ca2cdSRiver Riddle /// provides hashing and ordering operators. This class allows for computing a 1908a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model. 1918a1ca2cdSRiver Riddle struct OrderedPredicate { 1928a1ca2cdSRiver Riddle OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 1938a1ca2cdSRiver Riddle : position(ip.first), question(ip.second) {} 1948a1ca2cdSRiver Riddle OrderedPredicate(const PositionalPredicate &ip) 1958a1ca2cdSRiver Riddle : position(ip.position), question(ip.question) {} 1968a1ca2cdSRiver Riddle 1978a1ca2cdSRiver Riddle /// The position this predicate is applied to. 1988a1ca2cdSRiver Riddle Position *position; 1998a1ca2cdSRiver Riddle 2008a1ca2cdSRiver Riddle /// The question that is applied by this predicate onto the position. 2018a1ca2cdSRiver Riddle Qualifier *question; 2028a1ca2cdSRiver Riddle 2038a1ca2cdSRiver Riddle /// The first and second order benefit sums. 2048a1ca2cdSRiver Riddle /// The primary sum is the number of occurrences of this predicate among all 2058a1ca2cdSRiver Riddle /// of the patterns. 2068a1ca2cdSRiver Riddle unsigned primary = 0; 2078a1ca2cdSRiver Riddle /// The secondary sum is a squared summation of the primary sum of all of the 2088a1ca2cdSRiver Riddle /// predicates within each pattern that contains this predicate. This allows 2098a1ca2cdSRiver Riddle /// for favoring predicates that are more commonly shared within a pattern, as 2108a1ca2cdSRiver Riddle /// opposed to those shared across patterns. 2118a1ca2cdSRiver Riddle unsigned secondary = 0; 2128a1ca2cdSRiver Riddle 2138a1ca2cdSRiver Riddle /// A map between a pattern operation and the answer to the predicate question 2148a1ca2cdSRiver Riddle /// within that pattern. 2158a1ca2cdSRiver Riddle DenseMap<Operation *, Qualifier *> patternToAnswer; 2168a1ca2cdSRiver Riddle 2178a1ca2cdSRiver Riddle /// Returns true if this predicate is ordered before `other`, based on the 2188a1ca2cdSRiver Riddle /// cost model. 2198a1ca2cdSRiver Riddle bool operator<(const OrderedPredicate &other) const { 2208a1ca2cdSRiver Riddle // Sort by: 2218a1ca2cdSRiver Riddle // * first and secondary order sums 2228a1ca2cdSRiver Riddle // * lower depth 2238a1ca2cdSRiver Riddle // * position dependency 2248a1ca2cdSRiver Riddle // * predicate dependency. 2258a1ca2cdSRiver Riddle auto *otherPos = other.position; 2268a1ca2cdSRiver Riddle return std::make_tuple(other.primary, other.secondary, 2278a1ca2cdSRiver Riddle otherPos->getIndex().size(), otherPos->getKind(), 2288a1ca2cdSRiver Riddle other.question->getKind()) > 2298a1ca2cdSRiver Riddle std::make_tuple(primary, secondary, position->getIndex().size(), 2308a1ca2cdSRiver Riddle position->getKind(), question->getKind()); 2318a1ca2cdSRiver Riddle } 2328a1ca2cdSRiver Riddle }; 2338a1ca2cdSRiver Riddle 2348a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and 2358a1ca2cdSRiver Riddle /// question. 2368a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo { 2378a1ca2cdSRiver Riddle using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 2388a1ca2cdSRiver Riddle 2398a1ca2cdSRiver Riddle static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 2408a1ca2cdSRiver Riddle static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 2418a1ca2cdSRiver Riddle static bool isEqual(const OrderedPredicate &lhs, 2428a1ca2cdSRiver Riddle const OrderedPredicate &rhs) { 2438a1ca2cdSRiver Riddle return lhs.position == rhs.position && lhs.question == rhs.question; 2448a1ca2cdSRiver Riddle } 2458a1ca2cdSRiver Riddle static unsigned getHashValue(const OrderedPredicate &p) { 2468a1ca2cdSRiver Riddle return llvm::hash_combine(p.position, p.question); 2478a1ca2cdSRiver Riddle } 2488a1ca2cdSRiver Riddle }; 2498a1ca2cdSRiver Riddle 2508a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific 2518a1ca2cdSRiver Riddle /// pattern operation. 2528a1ca2cdSRiver Riddle struct OrderedPredicateList { 2538a1ca2cdSRiver Riddle OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} 2548a1ca2cdSRiver Riddle 2558a1ca2cdSRiver Riddle pdl::PatternOp pattern; 2568a1ca2cdSRiver Riddle DenseSet<OrderedPredicate *> predicates; 2578a1ca2cdSRiver Riddle }; 2588a1ca2cdSRiver Riddle } // end anonymous namespace 2598a1ca2cdSRiver Riddle 2608a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given 2618a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two 2628a1ca2cdSRiver Riddle /// match. 2638a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 2648a1ca2cdSRiver Riddle return node->getPosition() == predicate->position && 2658a1ca2cdSRiver Riddle node->getQuestion() == predicate->question; 2668a1ca2cdSRiver Riddle } 2678a1ca2cdSRiver Riddle 2688a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a 2698a1ca2cdSRiver Riddle /// predicate and parent pattern. 2708a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 2718a1ca2cdSRiver Riddle OrderedPredicate *predicate, 2728a1ca2cdSRiver Riddle pdl::PatternOp pattern) { 2738a1ca2cdSRiver Riddle assert(isSamePredicate(node, predicate) && 2748a1ca2cdSRiver Riddle "expected matcher to equal the given predicate"); 2758a1ca2cdSRiver Riddle 2768a1ca2cdSRiver Riddle auto it = predicate->patternToAnswer.find(pattern); 2778a1ca2cdSRiver Riddle assert(it != predicate->patternToAnswer.end() && 2788a1ca2cdSRiver Riddle "expected pattern to exist in predicate"); 2798a1ca2cdSRiver Riddle return node->getChildren().insert({it->second, nullptr}).first->second; 2808a1ca2cdSRiver Riddle } 2818a1ca2cdSRiver Riddle 2828a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate 2838a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates 2848a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start 2858a1ca2cdSRiver Riddle /// creating new nodes. 2868a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node, 2878a1ca2cdSRiver Riddle OrderedPredicateList &list, 2888a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator current, 2898a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator end) { 2908a1ca2cdSRiver Riddle if (current == end) { 2918a1ca2cdSRiver Riddle // We've hit the end of a pattern, so create a successful result node. 2928a1ca2cdSRiver Riddle node = std::make_unique<SuccessNode>(list.pattern, std::move(node)); 2938a1ca2cdSRiver Riddle 2948a1ca2cdSRiver Riddle // If the pattern doesn't contain this predicate, ignore it. 2958a1ca2cdSRiver Riddle } else if (list.predicates.find(*current) == list.predicates.end()) { 2968a1ca2cdSRiver Riddle propagatePattern(node, list, std::next(current), end); 2978a1ca2cdSRiver Riddle 2988a1ca2cdSRiver Riddle // If the current matcher node is invalid, create a new one for this 2998a1ca2cdSRiver Riddle // position and continue propagation. 3008a1ca2cdSRiver Riddle } else if (!node) { 3018a1ca2cdSRiver Riddle // Create a new node at this position and continue 3028a1ca2cdSRiver Riddle node = std::make_unique<SwitchNode>((*current)->position, 3038a1ca2cdSRiver Riddle (*current)->question); 3048a1ca2cdSRiver Riddle propagatePattern( 3058a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 3068a1ca2cdSRiver Riddle list, std::next(current), end); 3078a1ca2cdSRiver Riddle 3088a1ca2cdSRiver Riddle // If the matcher has already been created, and it is for this predicate we 3098a1ca2cdSRiver Riddle // continue propagation to the child. 3108a1ca2cdSRiver Riddle } else if (isSamePredicate(node.get(), *current)) { 3118a1ca2cdSRiver Riddle propagatePattern( 3128a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 3138a1ca2cdSRiver Riddle list, std::next(current), end); 3148a1ca2cdSRiver Riddle 3158a1ca2cdSRiver Riddle // If the matcher doesn't match the current predicate, insert a branch as 3168a1ca2cdSRiver Riddle // the common set of matchers has diverged. 3178a1ca2cdSRiver Riddle } else { 3188a1ca2cdSRiver Riddle propagatePattern(node->getFailureNode(), list, current, end); 3198a1ca2cdSRiver Riddle } 3208a1ca2cdSRiver Riddle } 3218a1ca2cdSRiver Riddle 3228a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible. 3238a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch. 3248a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 3258a1ca2cdSRiver Riddle if (!node) 3268a1ca2cdSRiver Riddle return; 3278a1ca2cdSRiver Riddle 3288a1ca2cdSRiver Riddle if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 3298a1ca2cdSRiver Riddle SwitchNode::ChildMapT &children = switchNode->getChildren(); 3308a1ca2cdSRiver Riddle for (auto &it : children) 3318a1ca2cdSRiver Riddle foldSwitchToBool(it.second); 3328a1ca2cdSRiver Riddle 3338a1ca2cdSRiver Riddle // If the node only contains one child, collapse it into a boolean predicate 3348a1ca2cdSRiver Riddle // node. 3358a1ca2cdSRiver Riddle if (children.size() == 1) { 3368a1ca2cdSRiver Riddle auto childIt = children.begin(); 3378a1ca2cdSRiver Riddle node = std::make_unique<BoolNode>( 3388a1ca2cdSRiver Riddle node->getPosition(), node->getQuestion(), childIt->first, 3398a1ca2cdSRiver Riddle std::move(childIt->second), std::move(node->getFailureNode())); 3408a1ca2cdSRiver Riddle } 3418a1ca2cdSRiver Riddle } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 3428a1ca2cdSRiver Riddle foldSwitchToBool(boolNode->getSuccessNode()); 3438a1ca2cdSRiver Riddle } 3448a1ca2cdSRiver Riddle 3458a1ca2cdSRiver Riddle foldSwitchToBool(node->getFailureNode()); 3468a1ca2cdSRiver Riddle } 3478a1ca2cdSRiver Riddle 3488a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`. 3498a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 3508a1ca2cdSRiver Riddle while (*root) 3518a1ca2cdSRiver Riddle root = &(*root)->getFailureNode(); 3528a1ca2cdSRiver Riddle *root = std::make_unique<ExitNode>(); 3538a1ca2cdSRiver Riddle } 3548a1ca2cdSRiver Riddle 3558a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree 3568a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node. 3578a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> 3588a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 3598a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 3608a1ca2cdSRiver Riddle // Collect the set of predicates contained within the pattern operations of 3618a1ca2cdSRiver Riddle // the module. 3628a1ca2cdSRiver Riddle SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16> 3638a1ca2cdSRiver Riddle patternsAndPredicates; 3648a1ca2cdSRiver Riddle for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 3658a1ca2cdSRiver Riddle std::vector<PositionalPredicate> predicateList; 3668a1ca2cdSRiver Riddle buildPredicateList(pattern, builder, predicateList, valueToPosition); 3678a1ca2cdSRiver Riddle patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); 3688a1ca2cdSRiver Riddle } 3698a1ca2cdSRiver Riddle 3708a1ca2cdSRiver Riddle // Associate a pattern result with each unique predicate. 3718a1ca2cdSRiver Riddle DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 3728a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 3738a1ca2cdSRiver Riddle for (auto &predicate : patternAndPredList.second) { 3748a1ca2cdSRiver Riddle auto it = uniqued.insert(predicate); 3758a1ca2cdSRiver Riddle it.first->patternToAnswer.try_emplace(patternAndPredList.first, 3768a1ca2cdSRiver Riddle predicate.answer); 3778a1ca2cdSRiver Riddle } 3788a1ca2cdSRiver Riddle } 3798a1ca2cdSRiver Riddle 3808a1ca2cdSRiver Riddle // Associate each pattern to a set of its ordered predicates for later lookup. 3818a1ca2cdSRiver Riddle std::vector<OrderedPredicateList> lists; 3828a1ca2cdSRiver Riddle lists.reserve(patternsAndPredicates.size()); 3838a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 3848a1ca2cdSRiver Riddle OrderedPredicateList list(patternAndPredList.first); 3858a1ca2cdSRiver Riddle for (auto &predicate : patternAndPredList.second) { 3868a1ca2cdSRiver Riddle OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 3878a1ca2cdSRiver Riddle list.predicates.insert(orderedPredicate); 3888a1ca2cdSRiver Riddle 3898a1ca2cdSRiver Riddle // Increment the primary sum for each reference to a particular predicate. 3908a1ca2cdSRiver Riddle ++orderedPredicate->primary; 3918a1ca2cdSRiver Riddle } 3928a1ca2cdSRiver Riddle lists.push_back(std::move(list)); 3938a1ca2cdSRiver Riddle } 3948a1ca2cdSRiver Riddle 3958a1ca2cdSRiver Riddle // For a particular pattern, get the total primary sum and add it to the 3968a1ca2cdSRiver Riddle // secondary sum of each predicate. Square the primary sums to emphasize 3978a1ca2cdSRiver Riddle // shared predicates within rather than across patterns. 3988a1ca2cdSRiver Riddle for (auto &list : lists) { 3998a1ca2cdSRiver Riddle unsigned total = 0; 4008a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 4018a1ca2cdSRiver Riddle total += predicate->primary * predicate->primary; 4028a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 4038a1ca2cdSRiver Riddle predicate->secondary += total; 4048a1ca2cdSRiver Riddle } 4058a1ca2cdSRiver Riddle 4068a1ca2cdSRiver Riddle // Sort the set of predicates now that the cost primary and secondary sums 4078a1ca2cdSRiver Riddle // have been computed. 4088a1ca2cdSRiver Riddle std::vector<OrderedPredicate *> ordered; 4098a1ca2cdSRiver Riddle ordered.reserve(uniqued.size()); 4108a1ca2cdSRiver Riddle for (auto &ip : uniqued) 4118a1ca2cdSRiver Riddle ordered.push_back(&ip); 4128a1ca2cdSRiver Riddle std::stable_sort( 4138a1ca2cdSRiver Riddle ordered.begin(), ordered.end(), 4148a1ca2cdSRiver Riddle [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); 4158a1ca2cdSRiver Riddle 4168a1ca2cdSRiver Riddle // Build the matchers for each of the pattern predicate lists. 4178a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> root; 4188a1ca2cdSRiver Riddle for (OrderedPredicateList &list : lists) 4198a1ca2cdSRiver Riddle propagatePattern(root, list, ordered.begin(), ordered.end()); 4208a1ca2cdSRiver Riddle 4218a1ca2cdSRiver Riddle // Collapse the graph and insert the exit node. 4228a1ca2cdSRiver Riddle foldSwitchToBool(root); 4238a1ca2cdSRiver Riddle insertExitNode(&root); 4248a1ca2cdSRiver Riddle return root; 4258a1ca2cdSRiver Riddle } 4268a1ca2cdSRiver Riddle 4278a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4288a1ca2cdSRiver Riddle // MatcherNode 4298a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4308a1ca2cdSRiver Riddle 4318a1ca2cdSRiver Riddle MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 4328a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 4338a1ca2cdSRiver Riddle : position(p), question(q), failureNode(std::move(failureNode)), 4348a1ca2cdSRiver Riddle matcherTypeID(matcherTypeID) {} 4358a1ca2cdSRiver Riddle 4368a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4378a1ca2cdSRiver Riddle // BoolNode 4388a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4398a1ca2cdSRiver Riddle 4408a1ca2cdSRiver Riddle BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 4418a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> successNode, 4428a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 4438a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<BoolNode>(), position, question, 4448a1ca2cdSRiver Riddle std::move(failureNode)), 4458a1ca2cdSRiver Riddle answer(answer), successNode(std::move(successNode)) {} 4468a1ca2cdSRiver Riddle 4478a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4488a1ca2cdSRiver Riddle // SuccessNode 4498a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4508a1ca2cdSRiver Riddle 4518a1ca2cdSRiver Riddle SuccessNode::SuccessNode(pdl::PatternOp pattern, 4528a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 4538a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 4548a1ca2cdSRiver Riddle /*question=*/nullptr, std::move(failureNode)), 4558a1ca2cdSRiver Riddle pattern(pattern) {} 4568a1ca2cdSRiver Riddle 4578a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4588a1ca2cdSRiver Riddle // SwitchNode 4598a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 4608a1ca2cdSRiver Riddle 4618a1ca2cdSRiver Riddle SwitchNode::SwitchNode(Position *position, Qualifier *question) 4628a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 463