1*8a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===// 2*8a1ca2cdSRiver Riddle // 3*8a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*8a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5*8a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*8a1ca2cdSRiver Riddle // 7*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 8*8a1ca2cdSRiver Riddle 9*8a1ca2cdSRiver Riddle #include "PredicateTree.h" 10*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h" 11*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h" 12*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" 13*8a1ca2cdSRiver Riddle #include "mlir/IR/Module.h" 14*8a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h" 15*8a1ca2cdSRiver Riddle 16*8a1ca2cdSRiver Riddle using namespace mlir; 17*8a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp; 18*8a1ca2cdSRiver Riddle 19*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 20*8a1ca2cdSRiver Riddle // Predicate List Building 21*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 22*8a1ca2cdSRiver Riddle 23*8a1ca2cdSRiver Riddle /// Compares the depths of two positions. 24*8a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) { 25*8a1ca2cdSRiver Riddle return lhs->getIndex().size() < rhs->getIndex().size(); 26*8a1ca2cdSRiver Riddle } 27*8a1ca2cdSRiver Riddle 28*8a1ca2cdSRiver Riddle /// Collect the tree predicates anchored at the given value. 29*8a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList, 30*8a1ca2cdSRiver Riddle Value val, PredicateBuilder &builder, 31*8a1ca2cdSRiver Riddle DenseMap<Value, Position *> &inputs, 32*8a1ca2cdSRiver Riddle Position *pos) { 33*8a1ca2cdSRiver Riddle // Make sure this input value is accessible to the rewrite. 34*8a1ca2cdSRiver Riddle auto it = inputs.try_emplace(val, pos); 35*8a1ca2cdSRiver Riddle 36*8a1ca2cdSRiver Riddle // If this is an input value that has been visited in the tree, add a 37*8a1ca2cdSRiver Riddle // constraint to ensure that both instances refer to the same value. 38*8a1ca2cdSRiver Riddle if (!it.second && 39*8a1ca2cdSRiver Riddle isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) { 40*8a1ca2cdSRiver Riddle auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth); 41*8a1ca2cdSRiver Riddle predList.emplace_back(minMaxPositions.second, 42*8a1ca2cdSRiver Riddle builder.getEqualTo(minMaxPositions.first)); 43*8a1ca2cdSRiver Riddle return; 44*8a1ca2cdSRiver Riddle } 45*8a1ca2cdSRiver Riddle 46*8a1ca2cdSRiver Riddle // Check for a per-position predicate to apply. 47*8a1ca2cdSRiver Riddle switch (pos->getKind()) { 48*8a1ca2cdSRiver Riddle case Predicates::AttributePos: { 49*8a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::AttributeType>() && 50*8a1ca2cdSRiver Riddle "expected attribute type"); 51*8a1ca2cdSRiver Riddle pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp()); 52*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 53*8a1ca2cdSRiver Riddle 54*8a1ca2cdSRiver Riddle // If the attribute has a type, add a type constraint. 55*8a1ca2cdSRiver Riddle if (Value type = attr.type()) { 56*8a1ca2cdSRiver Riddle getTreePredicates(predList, type, builder, inputs, builder.getType(pos)); 57*8a1ca2cdSRiver Riddle 58*8a1ca2cdSRiver Riddle // Check for a constant value of the attribute. 59*8a1ca2cdSRiver Riddle } else if (Optional<Attribute> value = attr.value()) { 60*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getAttributeConstraint(*value)); 61*8a1ca2cdSRiver Riddle } 62*8a1ca2cdSRiver Riddle break; 63*8a1ca2cdSRiver Riddle } 64*8a1ca2cdSRiver Riddle case Predicates::OperandPos: { 65*8a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 66*8a1ca2cdSRiver Riddle 67*8a1ca2cdSRiver Riddle // Prevent traversal into a null value. 68*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 69*8a1ca2cdSRiver Riddle 70*8a1ca2cdSRiver Riddle // If this is a typed input, add a type constraint. 71*8a1ca2cdSRiver Riddle if (auto in = val.getDefiningOp<pdl::InputOp>()) { 72*8a1ca2cdSRiver Riddle if (Value type = in.type()) { 73*8a1ca2cdSRiver Riddle getTreePredicates(predList, type, builder, inputs, 74*8a1ca2cdSRiver Riddle builder.getType(pos)); 75*8a1ca2cdSRiver Riddle } 76*8a1ca2cdSRiver Riddle 77*8a1ca2cdSRiver Riddle // Otherwise, recurse into the parent node. 78*8a1ca2cdSRiver Riddle } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) { 79*8a1ca2cdSRiver Riddle getTreePredicates(predList, parentOp.op(), builder, inputs, 80*8a1ca2cdSRiver Riddle builder.getParent(cast<OperandPosition>(pos))); 81*8a1ca2cdSRiver Riddle } 82*8a1ca2cdSRiver Riddle break; 83*8a1ca2cdSRiver Riddle } 84*8a1ca2cdSRiver Riddle case Predicates::OperationPos: { 85*8a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::OperationType>() && "expected operation"); 86*8a1ca2cdSRiver Riddle pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp()); 87*8a1ca2cdSRiver Riddle OperationPosition *opPos = cast<OperationPosition>(pos); 88*8a1ca2cdSRiver Riddle 89*8a1ca2cdSRiver Riddle // Ensure getDefiningOp returns a non-null operation. 90*8a1ca2cdSRiver Riddle if (!opPos->isRoot()) 91*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 92*8a1ca2cdSRiver Riddle 93*8a1ca2cdSRiver Riddle // Check that this is the correct root operation. 94*8a1ca2cdSRiver Riddle if (Optional<StringRef> opName = op.name()) 95*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getOperationName(*opName)); 96*8a1ca2cdSRiver Riddle 97*8a1ca2cdSRiver Riddle // Check that the operation has the proper number of operands and results. 98*8a1ca2cdSRiver Riddle OperandRange operands = op.operands(); 99*8a1ca2cdSRiver Riddle ResultRange results = op.results(); 100*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getOperandCount(operands.size())); 101*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getResultCount(results.size())); 102*8a1ca2cdSRiver Riddle 103*8a1ca2cdSRiver Riddle // Recurse into any attributes, operands, or results. 104*8a1ca2cdSRiver Riddle for (auto it : llvm::zip(op.attributeNames(), op.attributes())) { 105*8a1ca2cdSRiver Riddle getTreePredicates( 106*8a1ca2cdSRiver Riddle predList, std::get<1>(it), builder, inputs, 107*8a1ca2cdSRiver Riddle builder.getAttribute(opPos, 108*8a1ca2cdSRiver Riddle std::get<0>(it).cast<StringAttr>().getValue())); 109*8a1ca2cdSRiver Riddle } 110*8a1ca2cdSRiver Riddle for (auto operandIt : llvm::enumerate(operands)) 111*8a1ca2cdSRiver Riddle getTreePredicates(predList, operandIt.value(), builder, inputs, 112*8a1ca2cdSRiver Riddle builder.getOperand(opPos, operandIt.index())); 113*8a1ca2cdSRiver Riddle 114*8a1ca2cdSRiver Riddle // Only recurse into results that are not referenced in the source tree. 115*8a1ca2cdSRiver Riddle for (auto resultIt : llvm::enumerate(results)) { 116*8a1ca2cdSRiver Riddle getTreePredicates(predList, resultIt.value(), builder, inputs, 117*8a1ca2cdSRiver Riddle builder.getResult(opPos, resultIt.index())); 118*8a1ca2cdSRiver Riddle } 119*8a1ca2cdSRiver Riddle break; 120*8a1ca2cdSRiver Riddle } 121*8a1ca2cdSRiver Riddle case Predicates::ResultPos: { 122*8a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::ValueType>() && "expected value type"); 123*8a1ca2cdSRiver Riddle pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp()); 124*8a1ca2cdSRiver Riddle 125*8a1ca2cdSRiver Riddle // Prevent traversing a null value. 126*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getIsNotNull()); 127*8a1ca2cdSRiver Riddle 128*8a1ca2cdSRiver Riddle // Traverse the type constraint. 129*8a1ca2cdSRiver Riddle unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber(); 130*8a1ca2cdSRiver Riddle getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs, 131*8a1ca2cdSRiver Riddle builder.getType(pos)); 132*8a1ca2cdSRiver Riddle break; 133*8a1ca2cdSRiver Riddle } 134*8a1ca2cdSRiver Riddle case Predicates::TypePos: { 135*8a1ca2cdSRiver Riddle assert(val.getType().isa<pdl::TypeType>() && "expected value type"); 136*8a1ca2cdSRiver Riddle pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp()); 137*8a1ca2cdSRiver Riddle 138*8a1ca2cdSRiver Riddle // Check for a constraint on a constant type. 139*8a1ca2cdSRiver Riddle if (Optional<Type> type = typeOp.type()) 140*8a1ca2cdSRiver Riddle predList.emplace_back(pos, builder.getTypeConstraint(*type)); 141*8a1ca2cdSRiver Riddle break; 142*8a1ca2cdSRiver Riddle } 143*8a1ca2cdSRiver Riddle default: 144*8a1ca2cdSRiver Riddle llvm_unreachable("unknown position kind"); 145*8a1ca2cdSRiver Riddle } 146*8a1ca2cdSRiver Riddle } 147*8a1ca2cdSRiver Riddle 148*8a1ca2cdSRiver Riddle /// Collect all of the predicates related to constraints within the given 149*8a1ca2cdSRiver Riddle /// pattern operation. 150*8a1ca2cdSRiver Riddle static void collectConstraintPredicates( 151*8a1ca2cdSRiver Riddle pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList, 152*8a1ca2cdSRiver Riddle PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) { 153*8a1ca2cdSRiver Riddle for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) { 154*8a1ca2cdSRiver Riddle OperandRange arguments = op.args(); 155*8a1ca2cdSRiver Riddle ArrayAttr parameters = op.constParamsAttr(); 156*8a1ca2cdSRiver Riddle 157*8a1ca2cdSRiver Riddle std::vector<Position *> allPositions; 158*8a1ca2cdSRiver Riddle allPositions.reserve(arguments.size()); 159*8a1ca2cdSRiver Riddle for (Value arg : arguments) 160*8a1ca2cdSRiver Riddle allPositions.push_back(inputs.lookup(arg)); 161*8a1ca2cdSRiver Riddle 162*8a1ca2cdSRiver Riddle // Push the constraint to the furthest position. 163*8a1ca2cdSRiver Riddle Position *pos = *std::max_element(allPositions.begin(), allPositions.end(), 164*8a1ca2cdSRiver Riddle comparePosDepth); 165*8a1ca2cdSRiver Riddle PredicateBuilder::Predicate pred = 166*8a1ca2cdSRiver Riddle builder.getConstraint(op.name(), std::move(allPositions), parameters); 167*8a1ca2cdSRiver Riddle predList.emplace_back(pos, pred); 168*8a1ca2cdSRiver Riddle } 169*8a1ca2cdSRiver Riddle } 170*8a1ca2cdSRiver Riddle 171*8a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to 172*8a1ca2cdSRiver Riddle /// match this pattern. 173*8a1ca2cdSRiver Riddle static void buildPredicateList(pdl::PatternOp pattern, 174*8a1ca2cdSRiver Riddle PredicateBuilder &builder, 175*8a1ca2cdSRiver Riddle std::vector<PositionalPredicate> &predList, 176*8a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 177*8a1ca2cdSRiver Riddle getTreePredicates(predList, pattern.getRewriter().root(), builder, 178*8a1ca2cdSRiver Riddle valueToPosition, builder.getRoot()); 179*8a1ca2cdSRiver Riddle collectConstraintPredicates(pattern, predList, builder, valueToPosition); 180*8a1ca2cdSRiver Riddle } 181*8a1ca2cdSRiver Riddle 182*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 183*8a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging 184*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 185*8a1ca2cdSRiver Riddle 186*8a1ca2cdSRiver Riddle namespace { 187*8a1ca2cdSRiver Riddle 188*8a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and 189*8a1ca2cdSRiver Riddle /// provides hashing and ordering operators. This class allows for computing a 190*8a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model. 191*8a1ca2cdSRiver Riddle struct OrderedPredicate { 192*8a1ca2cdSRiver Riddle OrderedPredicate(const std::pair<Position *, Qualifier *> &ip) 193*8a1ca2cdSRiver Riddle : position(ip.first), question(ip.second) {} 194*8a1ca2cdSRiver Riddle OrderedPredicate(const PositionalPredicate &ip) 195*8a1ca2cdSRiver Riddle : position(ip.position), question(ip.question) {} 196*8a1ca2cdSRiver Riddle 197*8a1ca2cdSRiver Riddle /// The position this predicate is applied to. 198*8a1ca2cdSRiver Riddle Position *position; 199*8a1ca2cdSRiver Riddle 200*8a1ca2cdSRiver Riddle /// The question that is applied by this predicate onto the position. 201*8a1ca2cdSRiver Riddle Qualifier *question; 202*8a1ca2cdSRiver Riddle 203*8a1ca2cdSRiver Riddle /// The first and second order benefit sums. 204*8a1ca2cdSRiver Riddle /// The primary sum is the number of occurrences of this predicate among all 205*8a1ca2cdSRiver Riddle /// of the patterns. 206*8a1ca2cdSRiver Riddle unsigned primary = 0; 207*8a1ca2cdSRiver Riddle /// The secondary sum is a squared summation of the primary sum of all of the 208*8a1ca2cdSRiver Riddle /// predicates within each pattern that contains this predicate. This allows 209*8a1ca2cdSRiver Riddle /// for favoring predicates that are more commonly shared within a pattern, as 210*8a1ca2cdSRiver Riddle /// opposed to those shared across patterns. 211*8a1ca2cdSRiver Riddle unsigned secondary = 0; 212*8a1ca2cdSRiver Riddle 213*8a1ca2cdSRiver Riddle /// A map between a pattern operation and the answer to the predicate question 214*8a1ca2cdSRiver Riddle /// within that pattern. 215*8a1ca2cdSRiver Riddle DenseMap<Operation *, Qualifier *> patternToAnswer; 216*8a1ca2cdSRiver Riddle 217*8a1ca2cdSRiver Riddle /// Returns true if this predicate is ordered before `other`, based on the 218*8a1ca2cdSRiver Riddle /// cost model. 219*8a1ca2cdSRiver Riddle bool operator<(const OrderedPredicate &other) const { 220*8a1ca2cdSRiver Riddle // Sort by: 221*8a1ca2cdSRiver Riddle // * first and secondary order sums 222*8a1ca2cdSRiver Riddle // * lower depth 223*8a1ca2cdSRiver Riddle // * position dependency 224*8a1ca2cdSRiver Riddle // * predicate dependency. 225*8a1ca2cdSRiver Riddle auto *otherPos = other.position; 226*8a1ca2cdSRiver Riddle return std::make_tuple(other.primary, other.secondary, 227*8a1ca2cdSRiver Riddle otherPos->getIndex().size(), otherPos->getKind(), 228*8a1ca2cdSRiver Riddle other.question->getKind()) > 229*8a1ca2cdSRiver Riddle std::make_tuple(primary, secondary, position->getIndex().size(), 230*8a1ca2cdSRiver Riddle position->getKind(), question->getKind()); 231*8a1ca2cdSRiver Riddle } 232*8a1ca2cdSRiver Riddle }; 233*8a1ca2cdSRiver Riddle 234*8a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and 235*8a1ca2cdSRiver Riddle /// question. 236*8a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo { 237*8a1ca2cdSRiver Riddle using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>; 238*8a1ca2cdSRiver Riddle 239*8a1ca2cdSRiver Riddle static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); } 240*8a1ca2cdSRiver Riddle static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); } 241*8a1ca2cdSRiver Riddle static bool isEqual(const OrderedPredicate &lhs, 242*8a1ca2cdSRiver Riddle const OrderedPredicate &rhs) { 243*8a1ca2cdSRiver Riddle return lhs.position == rhs.position && lhs.question == rhs.question; 244*8a1ca2cdSRiver Riddle } 245*8a1ca2cdSRiver Riddle static unsigned getHashValue(const OrderedPredicate &p) { 246*8a1ca2cdSRiver Riddle return llvm::hash_combine(p.position, p.question); 247*8a1ca2cdSRiver Riddle } 248*8a1ca2cdSRiver Riddle }; 249*8a1ca2cdSRiver Riddle 250*8a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific 251*8a1ca2cdSRiver Riddle /// pattern operation. 252*8a1ca2cdSRiver Riddle struct OrderedPredicateList { 253*8a1ca2cdSRiver Riddle OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {} 254*8a1ca2cdSRiver Riddle 255*8a1ca2cdSRiver Riddle pdl::PatternOp pattern; 256*8a1ca2cdSRiver Riddle DenseSet<OrderedPredicate *> predicates; 257*8a1ca2cdSRiver Riddle }; 258*8a1ca2cdSRiver Riddle } // end anonymous namespace 259*8a1ca2cdSRiver Riddle 260*8a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given 261*8a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two 262*8a1ca2cdSRiver Riddle /// match. 263*8a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) { 264*8a1ca2cdSRiver Riddle return node->getPosition() == predicate->position && 265*8a1ca2cdSRiver Riddle node->getQuestion() == predicate->question; 266*8a1ca2cdSRiver Riddle } 267*8a1ca2cdSRiver Riddle 268*8a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a 269*8a1ca2cdSRiver Riddle /// predicate and parent pattern. 270*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node, 271*8a1ca2cdSRiver Riddle OrderedPredicate *predicate, 272*8a1ca2cdSRiver Riddle pdl::PatternOp pattern) { 273*8a1ca2cdSRiver Riddle assert(isSamePredicate(node, predicate) && 274*8a1ca2cdSRiver Riddle "expected matcher to equal the given predicate"); 275*8a1ca2cdSRiver Riddle 276*8a1ca2cdSRiver Riddle auto it = predicate->patternToAnswer.find(pattern); 277*8a1ca2cdSRiver Riddle assert(it != predicate->patternToAnswer.end() && 278*8a1ca2cdSRiver Riddle "expected pattern to exist in predicate"); 279*8a1ca2cdSRiver Riddle return node->getChildren().insert({it->second, nullptr}).first->second; 280*8a1ca2cdSRiver Riddle } 281*8a1ca2cdSRiver Riddle 282*8a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate 283*8a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates 284*8a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start 285*8a1ca2cdSRiver Riddle /// creating new nodes. 286*8a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node, 287*8a1ca2cdSRiver Riddle OrderedPredicateList &list, 288*8a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator current, 289*8a1ca2cdSRiver Riddle std::vector<OrderedPredicate *>::iterator end) { 290*8a1ca2cdSRiver Riddle if (current == end) { 291*8a1ca2cdSRiver Riddle // We've hit the end of a pattern, so create a successful result node. 292*8a1ca2cdSRiver Riddle node = std::make_unique<SuccessNode>(list.pattern, std::move(node)); 293*8a1ca2cdSRiver Riddle 294*8a1ca2cdSRiver Riddle // If the pattern doesn't contain this predicate, ignore it. 295*8a1ca2cdSRiver Riddle } else if (list.predicates.find(*current) == list.predicates.end()) { 296*8a1ca2cdSRiver Riddle propagatePattern(node, list, std::next(current), end); 297*8a1ca2cdSRiver Riddle 298*8a1ca2cdSRiver Riddle // If the current matcher node is invalid, create a new one for this 299*8a1ca2cdSRiver Riddle // position and continue propagation. 300*8a1ca2cdSRiver Riddle } else if (!node) { 301*8a1ca2cdSRiver Riddle // Create a new node at this position and continue 302*8a1ca2cdSRiver Riddle node = std::make_unique<SwitchNode>((*current)->position, 303*8a1ca2cdSRiver Riddle (*current)->question); 304*8a1ca2cdSRiver Riddle propagatePattern( 305*8a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 306*8a1ca2cdSRiver Riddle list, std::next(current), end); 307*8a1ca2cdSRiver Riddle 308*8a1ca2cdSRiver Riddle // If the matcher has already been created, and it is for this predicate we 309*8a1ca2cdSRiver Riddle // continue propagation to the child. 310*8a1ca2cdSRiver Riddle } else if (isSamePredicate(node.get(), *current)) { 311*8a1ca2cdSRiver Riddle propagatePattern( 312*8a1ca2cdSRiver Riddle getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern), 313*8a1ca2cdSRiver Riddle list, std::next(current), end); 314*8a1ca2cdSRiver Riddle 315*8a1ca2cdSRiver Riddle // If the matcher doesn't match the current predicate, insert a branch as 316*8a1ca2cdSRiver Riddle // the common set of matchers has diverged. 317*8a1ca2cdSRiver Riddle } else { 318*8a1ca2cdSRiver Riddle propagatePattern(node->getFailureNode(), list, current, end); 319*8a1ca2cdSRiver Riddle } 320*8a1ca2cdSRiver Riddle } 321*8a1ca2cdSRiver Riddle 322*8a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible. 323*8a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch. 324*8a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) { 325*8a1ca2cdSRiver Riddle if (!node) 326*8a1ca2cdSRiver Riddle return; 327*8a1ca2cdSRiver Riddle 328*8a1ca2cdSRiver Riddle if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) { 329*8a1ca2cdSRiver Riddle SwitchNode::ChildMapT &children = switchNode->getChildren(); 330*8a1ca2cdSRiver Riddle for (auto &it : children) 331*8a1ca2cdSRiver Riddle foldSwitchToBool(it.second); 332*8a1ca2cdSRiver Riddle 333*8a1ca2cdSRiver Riddle // If the node only contains one child, collapse it into a boolean predicate 334*8a1ca2cdSRiver Riddle // node. 335*8a1ca2cdSRiver Riddle if (children.size() == 1) { 336*8a1ca2cdSRiver Riddle auto childIt = children.begin(); 337*8a1ca2cdSRiver Riddle node = std::make_unique<BoolNode>( 338*8a1ca2cdSRiver Riddle node->getPosition(), node->getQuestion(), childIt->first, 339*8a1ca2cdSRiver Riddle std::move(childIt->second), std::move(node->getFailureNode())); 340*8a1ca2cdSRiver Riddle } 341*8a1ca2cdSRiver Riddle } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) { 342*8a1ca2cdSRiver Riddle foldSwitchToBool(boolNode->getSuccessNode()); 343*8a1ca2cdSRiver Riddle } 344*8a1ca2cdSRiver Riddle 345*8a1ca2cdSRiver Riddle foldSwitchToBool(node->getFailureNode()); 346*8a1ca2cdSRiver Riddle } 347*8a1ca2cdSRiver Riddle 348*8a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`. 349*8a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) { 350*8a1ca2cdSRiver Riddle while (*root) 351*8a1ca2cdSRiver Riddle root = &(*root)->getFailureNode(); 352*8a1ca2cdSRiver Riddle *root = std::make_unique<ExitNode>(); 353*8a1ca2cdSRiver Riddle } 354*8a1ca2cdSRiver Riddle 355*8a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree 356*8a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node. 357*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> 358*8a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 359*8a1ca2cdSRiver Riddle DenseMap<Value, Position *> &valueToPosition) { 360*8a1ca2cdSRiver Riddle // Collect the set of predicates contained within the pattern operations of 361*8a1ca2cdSRiver Riddle // the module. 362*8a1ca2cdSRiver Riddle SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16> 363*8a1ca2cdSRiver Riddle patternsAndPredicates; 364*8a1ca2cdSRiver Riddle for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) { 365*8a1ca2cdSRiver Riddle std::vector<PositionalPredicate> predicateList; 366*8a1ca2cdSRiver Riddle buildPredicateList(pattern, builder, predicateList, valueToPosition); 367*8a1ca2cdSRiver Riddle patternsAndPredicates.emplace_back(pattern, std::move(predicateList)); 368*8a1ca2cdSRiver Riddle } 369*8a1ca2cdSRiver Riddle 370*8a1ca2cdSRiver Riddle // Associate a pattern result with each unique predicate. 371*8a1ca2cdSRiver Riddle DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued; 372*8a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 373*8a1ca2cdSRiver Riddle for (auto &predicate : patternAndPredList.second) { 374*8a1ca2cdSRiver Riddle auto it = uniqued.insert(predicate); 375*8a1ca2cdSRiver Riddle it.first->patternToAnswer.try_emplace(patternAndPredList.first, 376*8a1ca2cdSRiver Riddle predicate.answer); 377*8a1ca2cdSRiver Riddle } 378*8a1ca2cdSRiver Riddle } 379*8a1ca2cdSRiver Riddle 380*8a1ca2cdSRiver Riddle // Associate each pattern to a set of its ordered predicates for later lookup. 381*8a1ca2cdSRiver Riddle std::vector<OrderedPredicateList> lists; 382*8a1ca2cdSRiver Riddle lists.reserve(patternsAndPredicates.size()); 383*8a1ca2cdSRiver Riddle for (auto &patternAndPredList : patternsAndPredicates) { 384*8a1ca2cdSRiver Riddle OrderedPredicateList list(patternAndPredList.first); 385*8a1ca2cdSRiver Riddle for (auto &predicate : patternAndPredList.second) { 386*8a1ca2cdSRiver Riddle OrderedPredicate *orderedPredicate = &*uniqued.find(predicate); 387*8a1ca2cdSRiver Riddle list.predicates.insert(orderedPredicate); 388*8a1ca2cdSRiver Riddle 389*8a1ca2cdSRiver Riddle // Increment the primary sum for each reference to a particular predicate. 390*8a1ca2cdSRiver Riddle ++orderedPredicate->primary; 391*8a1ca2cdSRiver Riddle } 392*8a1ca2cdSRiver Riddle lists.push_back(std::move(list)); 393*8a1ca2cdSRiver Riddle } 394*8a1ca2cdSRiver Riddle 395*8a1ca2cdSRiver Riddle // For a particular pattern, get the total primary sum and add it to the 396*8a1ca2cdSRiver Riddle // secondary sum of each predicate. Square the primary sums to emphasize 397*8a1ca2cdSRiver Riddle // shared predicates within rather than across patterns. 398*8a1ca2cdSRiver Riddle for (auto &list : lists) { 399*8a1ca2cdSRiver Riddle unsigned total = 0; 400*8a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 401*8a1ca2cdSRiver Riddle total += predicate->primary * predicate->primary; 402*8a1ca2cdSRiver Riddle for (auto *predicate : list.predicates) 403*8a1ca2cdSRiver Riddle predicate->secondary += total; 404*8a1ca2cdSRiver Riddle } 405*8a1ca2cdSRiver Riddle 406*8a1ca2cdSRiver Riddle // Sort the set of predicates now that the cost primary and secondary sums 407*8a1ca2cdSRiver Riddle // have been computed. 408*8a1ca2cdSRiver Riddle std::vector<OrderedPredicate *> ordered; 409*8a1ca2cdSRiver Riddle ordered.reserve(uniqued.size()); 410*8a1ca2cdSRiver Riddle for (auto &ip : uniqued) 411*8a1ca2cdSRiver Riddle ordered.push_back(&ip); 412*8a1ca2cdSRiver Riddle std::stable_sort( 413*8a1ca2cdSRiver Riddle ordered.begin(), ordered.end(), 414*8a1ca2cdSRiver Riddle [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; }); 415*8a1ca2cdSRiver Riddle 416*8a1ca2cdSRiver Riddle // Build the matchers for each of the pattern predicate lists. 417*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> root; 418*8a1ca2cdSRiver Riddle for (OrderedPredicateList &list : lists) 419*8a1ca2cdSRiver Riddle propagatePattern(root, list, ordered.begin(), ordered.end()); 420*8a1ca2cdSRiver Riddle 421*8a1ca2cdSRiver Riddle // Collapse the graph and insert the exit node. 422*8a1ca2cdSRiver Riddle foldSwitchToBool(root); 423*8a1ca2cdSRiver Riddle insertExitNode(&root); 424*8a1ca2cdSRiver Riddle return root; 425*8a1ca2cdSRiver Riddle } 426*8a1ca2cdSRiver Riddle 427*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 428*8a1ca2cdSRiver Riddle // MatcherNode 429*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 430*8a1ca2cdSRiver Riddle 431*8a1ca2cdSRiver Riddle MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q, 432*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 433*8a1ca2cdSRiver Riddle : position(p), question(q), failureNode(std::move(failureNode)), 434*8a1ca2cdSRiver Riddle matcherTypeID(matcherTypeID) {} 435*8a1ca2cdSRiver Riddle 436*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 437*8a1ca2cdSRiver Riddle // BoolNode 438*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 439*8a1ca2cdSRiver Riddle 440*8a1ca2cdSRiver Riddle BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer, 441*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> successNode, 442*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 443*8a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<BoolNode>(), position, question, 444*8a1ca2cdSRiver Riddle std::move(failureNode)), 445*8a1ca2cdSRiver Riddle answer(answer), successNode(std::move(successNode)) {} 446*8a1ca2cdSRiver Riddle 447*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 448*8a1ca2cdSRiver Riddle // SuccessNode 449*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 450*8a1ca2cdSRiver Riddle 451*8a1ca2cdSRiver Riddle SuccessNode::SuccessNode(pdl::PatternOp pattern, 452*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> failureNode) 453*8a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr, 454*8a1ca2cdSRiver Riddle /*question=*/nullptr, std::move(failureNode)), 455*8a1ca2cdSRiver Riddle pattern(pattern) {} 456*8a1ca2cdSRiver Riddle 457*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 458*8a1ca2cdSRiver Riddle // SwitchNode 459*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===// 460*8a1ca2cdSRiver Riddle 461*8a1ca2cdSRiver Riddle SwitchNode::SwitchNode(Position *position, Qualifier *question) 462*8a1ca2cdSRiver Riddle : MatcherNode(TypeID::get<SwitchNode>(), position, question) {} 463