1*8a1ca2cdSRiver Riddle //===- PredicateTree.cpp - Predicate tree merging -------------------------===//
2*8a1ca2cdSRiver Riddle //
3*8a1ca2cdSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*8a1ca2cdSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5*8a1ca2cdSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*8a1ca2cdSRiver Riddle //
7*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
8*8a1ca2cdSRiver Riddle 
9*8a1ca2cdSRiver Riddle #include "PredicateTree.h"
10*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDL.h"
11*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDL/IR/PDLTypes.h"
12*8a1ca2cdSRiver Riddle #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
13*8a1ca2cdSRiver Riddle #include "mlir/IR/Module.h"
14*8a1ca2cdSRiver Riddle #include "mlir/Interfaces/InferTypeOpInterface.h"
15*8a1ca2cdSRiver Riddle 
16*8a1ca2cdSRiver Riddle using namespace mlir;
17*8a1ca2cdSRiver Riddle using namespace mlir::pdl_to_pdl_interp;
18*8a1ca2cdSRiver Riddle 
19*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
20*8a1ca2cdSRiver Riddle // Predicate List Building
21*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
22*8a1ca2cdSRiver Riddle 
23*8a1ca2cdSRiver Riddle /// Compares the depths of two positions.
24*8a1ca2cdSRiver Riddle static bool comparePosDepth(Position *lhs, Position *rhs) {
25*8a1ca2cdSRiver Riddle   return lhs->getIndex().size() < rhs->getIndex().size();
26*8a1ca2cdSRiver Riddle }
27*8a1ca2cdSRiver Riddle 
28*8a1ca2cdSRiver Riddle /// Collect the tree predicates anchored at the given value.
29*8a1ca2cdSRiver Riddle static void getTreePredicates(std::vector<PositionalPredicate> &predList,
30*8a1ca2cdSRiver Riddle                               Value val, PredicateBuilder &builder,
31*8a1ca2cdSRiver Riddle                               DenseMap<Value, Position *> &inputs,
32*8a1ca2cdSRiver Riddle                               Position *pos) {
33*8a1ca2cdSRiver Riddle   // Make sure this input value is accessible to the rewrite.
34*8a1ca2cdSRiver Riddle   auto it = inputs.try_emplace(val, pos);
35*8a1ca2cdSRiver Riddle 
36*8a1ca2cdSRiver Riddle   // If this is an input value that has been visited in the tree, add a
37*8a1ca2cdSRiver Riddle   // constraint to ensure that both instances refer to the same value.
38*8a1ca2cdSRiver Riddle   if (!it.second &&
39*8a1ca2cdSRiver Riddle       isa<pdl::AttributeOp, pdl::InputOp, pdl::TypeOp>(val.getDefiningOp())) {
40*8a1ca2cdSRiver Riddle     auto minMaxPositions = std::minmax(pos, it.first->second, comparePosDepth);
41*8a1ca2cdSRiver Riddle     predList.emplace_back(minMaxPositions.second,
42*8a1ca2cdSRiver Riddle                           builder.getEqualTo(minMaxPositions.first));
43*8a1ca2cdSRiver Riddle     return;
44*8a1ca2cdSRiver Riddle   }
45*8a1ca2cdSRiver Riddle 
46*8a1ca2cdSRiver Riddle   // Check for a per-position predicate to apply.
47*8a1ca2cdSRiver Riddle   switch (pos->getKind()) {
48*8a1ca2cdSRiver Riddle   case Predicates::AttributePos: {
49*8a1ca2cdSRiver Riddle     assert(val.getType().isa<pdl::AttributeType>() &&
50*8a1ca2cdSRiver Riddle            "expected attribute type");
51*8a1ca2cdSRiver Riddle     pdl::AttributeOp attr = cast<pdl::AttributeOp>(val.getDefiningOp());
52*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getIsNotNull());
53*8a1ca2cdSRiver Riddle 
54*8a1ca2cdSRiver Riddle     // If the attribute has a type, add a type constraint.
55*8a1ca2cdSRiver Riddle     if (Value type = attr.type()) {
56*8a1ca2cdSRiver Riddle       getTreePredicates(predList, type, builder, inputs, builder.getType(pos));
57*8a1ca2cdSRiver Riddle 
58*8a1ca2cdSRiver Riddle       // Check for a constant value of the attribute.
59*8a1ca2cdSRiver Riddle     } else if (Optional<Attribute> value = attr.value()) {
60*8a1ca2cdSRiver Riddle       predList.emplace_back(pos, builder.getAttributeConstraint(*value));
61*8a1ca2cdSRiver Riddle     }
62*8a1ca2cdSRiver Riddle     break;
63*8a1ca2cdSRiver Riddle   }
64*8a1ca2cdSRiver Riddle   case Predicates::OperandPos: {
65*8a1ca2cdSRiver Riddle     assert(val.getType().isa<pdl::ValueType>() && "expected value type");
66*8a1ca2cdSRiver Riddle 
67*8a1ca2cdSRiver Riddle     // Prevent traversal into a null value.
68*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getIsNotNull());
69*8a1ca2cdSRiver Riddle 
70*8a1ca2cdSRiver Riddle     // If this is a typed input, add a type constraint.
71*8a1ca2cdSRiver Riddle     if (auto in = val.getDefiningOp<pdl::InputOp>()) {
72*8a1ca2cdSRiver Riddle       if (Value type = in.type()) {
73*8a1ca2cdSRiver Riddle         getTreePredicates(predList, type, builder, inputs,
74*8a1ca2cdSRiver Riddle                           builder.getType(pos));
75*8a1ca2cdSRiver Riddle       }
76*8a1ca2cdSRiver Riddle 
77*8a1ca2cdSRiver Riddle       // Otherwise, recurse into the parent node.
78*8a1ca2cdSRiver Riddle     } else if (auto parentOp = val.getDefiningOp<pdl::OperationOp>()) {
79*8a1ca2cdSRiver Riddle       getTreePredicates(predList, parentOp.op(), builder, inputs,
80*8a1ca2cdSRiver Riddle                         builder.getParent(cast<OperandPosition>(pos)));
81*8a1ca2cdSRiver Riddle     }
82*8a1ca2cdSRiver Riddle     break;
83*8a1ca2cdSRiver Riddle   }
84*8a1ca2cdSRiver Riddle   case Predicates::OperationPos: {
85*8a1ca2cdSRiver Riddle     assert(val.getType().isa<pdl::OperationType>() && "expected operation");
86*8a1ca2cdSRiver Riddle     pdl::OperationOp op = cast<pdl::OperationOp>(val.getDefiningOp());
87*8a1ca2cdSRiver Riddle     OperationPosition *opPos = cast<OperationPosition>(pos);
88*8a1ca2cdSRiver Riddle 
89*8a1ca2cdSRiver Riddle     // Ensure getDefiningOp returns a non-null operation.
90*8a1ca2cdSRiver Riddle     if (!opPos->isRoot())
91*8a1ca2cdSRiver Riddle       predList.emplace_back(pos, builder.getIsNotNull());
92*8a1ca2cdSRiver Riddle 
93*8a1ca2cdSRiver Riddle     // Check that this is the correct root operation.
94*8a1ca2cdSRiver Riddle     if (Optional<StringRef> opName = op.name())
95*8a1ca2cdSRiver Riddle       predList.emplace_back(pos, builder.getOperationName(*opName));
96*8a1ca2cdSRiver Riddle 
97*8a1ca2cdSRiver Riddle     // Check that the operation has the proper number of operands and results.
98*8a1ca2cdSRiver Riddle     OperandRange operands = op.operands();
99*8a1ca2cdSRiver Riddle     ResultRange results = op.results();
100*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getOperandCount(operands.size()));
101*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getResultCount(results.size()));
102*8a1ca2cdSRiver Riddle 
103*8a1ca2cdSRiver Riddle     // Recurse into any attributes, operands, or results.
104*8a1ca2cdSRiver Riddle     for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
105*8a1ca2cdSRiver Riddle       getTreePredicates(
106*8a1ca2cdSRiver Riddle           predList, std::get<1>(it), builder, inputs,
107*8a1ca2cdSRiver Riddle           builder.getAttribute(opPos,
108*8a1ca2cdSRiver Riddle                                std::get<0>(it).cast<StringAttr>().getValue()));
109*8a1ca2cdSRiver Riddle     }
110*8a1ca2cdSRiver Riddle     for (auto operandIt : llvm::enumerate(operands))
111*8a1ca2cdSRiver Riddle       getTreePredicates(predList, operandIt.value(), builder, inputs,
112*8a1ca2cdSRiver Riddle                         builder.getOperand(opPos, operandIt.index()));
113*8a1ca2cdSRiver Riddle 
114*8a1ca2cdSRiver Riddle     // Only recurse into results that are not referenced in the source tree.
115*8a1ca2cdSRiver Riddle     for (auto resultIt : llvm::enumerate(results)) {
116*8a1ca2cdSRiver Riddle       getTreePredicates(predList, resultIt.value(), builder, inputs,
117*8a1ca2cdSRiver Riddle                         builder.getResult(opPos, resultIt.index()));
118*8a1ca2cdSRiver Riddle     }
119*8a1ca2cdSRiver Riddle     break;
120*8a1ca2cdSRiver Riddle   }
121*8a1ca2cdSRiver Riddle   case Predicates::ResultPos: {
122*8a1ca2cdSRiver Riddle     assert(val.getType().isa<pdl::ValueType>() && "expected value type");
123*8a1ca2cdSRiver Riddle     pdl::OperationOp parentOp = cast<pdl::OperationOp>(val.getDefiningOp());
124*8a1ca2cdSRiver Riddle 
125*8a1ca2cdSRiver Riddle     // Prevent traversing a null value.
126*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, builder.getIsNotNull());
127*8a1ca2cdSRiver Riddle 
128*8a1ca2cdSRiver Riddle     // Traverse the type constraint.
129*8a1ca2cdSRiver Riddle     unsigned resultNo = cast<ResultPosition>(pos)->getResultNumber();
130*8a1ca2cdSRiver Riddle     getTreePredicates(predList, parentOp.types()[resultNo], builder, inputs,
131*8a1ca2cdSRiver Riddle                       builder.getType(pos));
132*8a1ca2cdSRiver Riddle     break;
133*8a1ca2cdSRiver Riddle   }
134*8a1ca2cdSRiver Riddle   case Predicates::TypePos: {
135*8a1ca2cdSRiver Riddle     assert(val.getType().isa<pdl::TypeType>() && "expected value type");
136*8a1ca2cdSRiver Riddle     pdl::TypeOp typeOp = cast<pdl::TypeOp>(val.getDefiningOp());
137*8a1ca2cdSRiver Riddle 
138*8a1ca2cdSRiver Riddle     // Check for a constraint on a constant type.
139*8a1ca2cdSRiver Riddle     if (Optional<Type> type = typeOp.type())
140*8a1ca2cdSRiver Riddle       predList.emplace_back(pos, builder.getTypeConstraint(*type));
141*8a1ca2cdSRiver Riddle     break;
142*8a1ca2cdSRiver Riddle   }
143*8a1ca2cdSRiver Riddle   default:
144*8a1ca2cdSRiver Riddle     llvm_unreachable("unknown position kind");
145*8a1ca2cdSRiver Riddle   }
146*8a1ca2cdSRiver Riddle }
147*8a1ca2cdSRiver Riddle 
148*8a1ca2cdSRiver Riddle /// Collect all of the predicates related to constraints within the given
149*8a1ca2cdSRiver Riddle /// pattern operation.
150*8a1ca2cdSRiver Riddle static void collectConstraintPredicates(
151*8a1ca2cdSRiver Riddle     pdl::PatternOp pattern, std::vector<PositionalPredicate> &predList,
152*8a1ca2cdSRiver Riddle     PredicateBuilder &builder, DenseMap<Value, Position *> &inputs) {
153*8a1ca2cdSRiver Riddle   for (auto op : pattern.body().getOps<pdl::ApplyConstraintOp>()) {
154*8a1ca2cdSRiver Riddle     OperandRange arguments = op.args();
155*8a1ca2cdSRiver Riddle     ArrayAttr parameters = op.constParamsAttr();
156*8a1ca2cdSRiver Riddle 
157*8a1ca2cdSRiver Riddle     std::vector<Position *> allPositions;
158*8a1ca2cdSRiver Riddle     allPositions.reserve(arguments.size());
159*8a1ca2cdSRiver Riddle     for (Value arg : arguments)
160*8a1ca2cdSRiver Riddle       allPositions.push_back(inputs.lookup(arg));
161*8a1ca2cdSRiver Riddle 
162*8a1ca2cdSRiver Riddle     // Push the constraint to the furthest position.
163*8a1ca2cdSRiver Riddle     Position *pos = *std::max_element(allPositions.begin(), allPositions.end(),
164*8a1ca2cdSRiver Riddle                                       comparePosDepth);
165*8a1ca2cdSRiver Riddle     PredicateBuilder::Predicate pred =
166*8a1ca2cdSRiver Riddle         builder.getConstraint(op.name(), std::move(allPositions), parameters);
167*8a1ca2cdSRiver Riddle     predList.emplace_back(pos, pred);
168*8a1ca2cdSRiver Riddle   }
169*8a1ca2cdSRiver Riddle }
170*8a1ca2cdSRiver Riddle 
171*8a1ca2cdSRiver Riddle /// Given a pattern operation, build the set of matcher predicates necessary to
172*8a1ca2cdSRiver Riddle /// match this pattern.
173*8a1ca2cdSRiver Riddle static void buildPredicateList(pdl::PatternOp pattern,
174*8a1ca2cdSRiver Riddle                                PredicateBuilder &builder,
175*8a1ca2cdSRiver Riddle                                std::vector<PositionalPredicate> &predList,
176*8a1ca2cdSRiver Riddle                                DenseMap<Value, Position *> &valueToPosition) {
177*8a1ca2cdSRiver Riddle   getTreePredicates(predList, pattern.getRewriter().root(), builder,
178*8a1ca2cdSRiver Riddle                     valueToPosition, builder.getRoot());
179*8a1ca2cdSRiver Riddle   collectConstraintPredicates(pattern, predList, builder, valueToPosition);
180*8a1ca2cdSRiver Riddle }
181*8a1ca2cdSRiver Riddle 
182*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
183*8a1ca2cdSRiver Riddle // Pattern Predicate Tree Merging
184*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
185*8a1ca2cdSRiver Riddle 
186*8a1ca2cdSRiver Riddle namespace {
187*8a1ca2cdSRiver Riddle 
188*8a1ca2cdSRiver Riddle /// This class represents a specific predicate applied to a position, and
189*8a1ca2cdSRiver Riddle /// provides hashing and ordering operators. This class allows for computing a
190*8a1ca2cdSRiver Riddle /// frequence sum and ordering predicates based on a cost model.
191*8a1ca2cdSRiver Riddle struct OrderedPredicate {
192*8a1ca2cdSRiver Riddle   OrderedPredicate(const std::pair<Position *, Qualifier *> &ip)
193*8a1ca2cdSRiver Riddle       : position(ip.first), question(ip.second) {}
194*8a1ca2cdSRiver Riddle   OrderedPredicate(const PositionalPredicate &ip)
195*8a1ca2cdSRiver Riddle       : position(ip.position), question(ip.question) {}
196*8a1ca2cdSRiver Riddle 
197*8a1ca2cdSRiver Riddle   /// The position this predicate is applied to.
198*8a1ca2cdSRiver Riddle   Position *position;
199*8a1ca2cdSRiver Riddle 
200*8a1ca2cdSRiver Riddle   /// The question that is applied by this predicate onto the position.
201*8a1ca2cdSRiver Riddle   Qualifier *question;
202*8a1ca2cdSRiver Riddle 
203*8a1ca2cdSRiver Riddle   /// The first and second order benefit sums.
204*8a1ca2cdSRiver Riddle   /// The primary sum is the number of occurrences of this predicate among all
205*8a1ca2cdSRiver Riddle   /// of the patterns.
206*8a1ca2cdSRiver Riddle   unsigned primary = 0;
207*8a1ca2cdSRiver Riddle   /// The secondary sum is a squared summation of the primary sum of all of the
208*8a1ca2cdSRiver Riddle   /// predicates within each pattern that contains this predicate. This allows
209*8a1ca2cdSRiver Riddle   /// for favoring predicates that are more commonly shared within a pattern, as
210*8a1ca2cdSRiver Riddle   /// opposed to those shared across patterns.
211*8a1ca2cdSRiver Riddle   unsigned secondary = 0;
212*8a1ca2cdSRiver Riddle 
213*8a1ca2cdSRiver Riddle   /// A map between a pattern operation and the answer to the predicate question
214*8a1ca2cdSRiver Riddle   /// within that pattern.
215*8a1ca2cdSRiver Riddle   DenseMap<Operation *, Qualifier *> patternToAnswer;
216*8a1ca2cdSRiver Riddle 
217*8a1ca2cdSRiver Riddle   /// Returns true if this predicate is ordered before `other`, based on the
218*8a1ca2cdSRiver Riddle   /// cost model.
219*8a1ca2cdSRiver Riddle   bool operator<(const OrderedPredicate &other) const {
220*8a1ca2cdSRiver Riddle     // Sort by:
221*8a1ca2cdSRiver Riddle     // * first and secondary order sums
222*8a1ca2cdSRiver Riddle     // * lower depth
223*8a1ca2cdSRiver Riddle     // * position dependency
224*8a1ca2cdSRiver Riddle     // * predicate dependency.
225*8a1ca2cdSRiver Riddle     auto *otherPos = other.position;
226*8a1ca2cdSRiver Riddle     return std::make_tuple(other.primary, other.secondary,
227*8a1ca2cdSRiver Riddle                            otherPos->getIndex().size(), otherPos->getKind(),
228*8a1ca2cdSRiver Riddle                            other.question->getKind()) >
229*8a1ca2cdSRiver Riddle            std::make_tuple(primary, secondary, position->getIndex().size(),
230*8a1ca2cdSRiver Riddle                            position->getKind(), question->getKind());
231*8a1ca2cdSRiver Riddle   }
232*8a1ca2cdSRiver Riddle };
233*8a1ca2cdSRiver Riddle 
234*8a1ca2cdSRiver Riddle /// A DenseMapInfo for OrderedPredicate based solely on the position and
235*8a1ca2cdSRiver Riddle /// question.
236*8a1ca2cdSRiver Riddle struct OrderedPredicateDenseInfo {
237*8a1ca2cdSRiver Riddle   using Base = DenseMapInfo<std::pair<Position *, Qualifier *>>;
238*8a1ca2cdSRiver Riddle 
239*8a1ca2cdSRiver Riddle   static OrderedPredicate getEmptyKey() { return Base::getEmptyKey(); }
240*8a1ca2cdSRiver Riddle   static OrderedPredicate getTombstoneKey() { return Base::getTombstoneKey(); }
241*8a1ca2cdSRiver Riddle   static bool isEqual(const OrderedPredicate &lhs,
242*8a1ca2cdSRiver Riddle                       const OrderedPredicate &rhs) {
243*8a1ca2cdSRiver Riddle     return lhs.position == rhs.position && lhs.question == rhs.question;
244*8a1ca2cdSRiver Riddle   }
245*8a1ca2cdSRiver Riddle   static unsigned getHashValue(const OrderedPredicate &p) {
246*8a1ca2cdSRiver Riddle     return llvm::hash_combine(p.position, p.question);
247*8a1ca2cdSRiver Riddle   }
248*8a1ca2cdSRiver Riddle };
249*8a1ca2cdSRiver Riddle 
250*8a1ca2cdSRiver Riddle /// This class wraps a set of ordered predicates that are used within a specific
251*8a1ca2cdSRiver Riddle /// pattern operation.
252*8a1ca2cdSRiver Riddle struct OrderedPredicateList {
253*8a1ca2cdSRiver Riddle   OrderedPredicateList(pdl::PatternOp pattern) : pattern(pattern) {}
254*8a1ca2cdSRiver Riddle 
255*8a1ca2cdSRiver Riddle   pdl::PatternOp pattern;
256*8a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate *> predicates;
257*8a1ca2cdSRiver Riddle };
258*8a1ca2cdSRiver Riddle } // end anonymous namespace
259*8a1ca2cdSRiver Riddle 
260*8a1ca2cdSRiver Riddle /// Returns true if the given matcher refers to the same predicate as the given
261*8a1ca2cdSRiver Riddle /// ordered predicate. This means that the position and questions of the two
262*8a1ca2cdSRiver Riddle /// match.
263*8a1ca2cdSRiver Riddle static bool isSamePredicate(MatcherNode *node, OrderedPredicate *predicate) {
264*8a1ca2cdSRiver Riddle   return node->getPosition() == predicate->position &&
265*8a1ca2cdSRiver Riddle          node->getQuestion() == predicate->question;
266*8a1ca2cdSRiver Riddle }
267*8a1ca2cdSRiver Riddle 
268*8a1ca2cdSRiver Riddle /// Get or insert a child matcher for the given parent switch node, given a
269*8a1ca2cdSRiver Riddle /// predicate and parent pattern.
270*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode> &getOrCreateChild(SwitchNode *node,
271*8a1ca2cdSRiver Riddle                                                OrderedPredicate *predicate,
272*8a1ca2cdSRiver Riddle                                                pdl::PatternOp pattern) {
273*8a1ca2cdSRiver Riddle   assert(isSamePredicate(node, predicate) &&
274*8a1ca2cdSRiver Riddle          "expected matcher to equal the given predicate");
275*8a1ca2cdSRiver Riddle 
276*8a1ca2cdSRiver Riddle   auto it = predicate->patternToAnswer.find(pattern);
277*8a1ca2cdSRiver Riddle   assert(it != predicate->patternToAnswer.end() &&
278*8a1ca2cdSRiver Riddle          "expected pattern to exist in predicate");
279*8a1ca2cdSRiver Riddle   return node->getChildren().insert({it->second, nullptr}).first->second;
280*8a1ca2cdSRiver Riddle }
281*8a1ca2cdSRiver Riddle 
282*8a1ca2cdSRiver Riddle /// Build the matcher CFG by "pushing" patterns through by sorted predicate
283*8a1ca2cdSRiver Riddle /// order. A pattern will traverse as far as possible using common predicates
284*8a1ca2cdSRiver Riddle /// and then either diverge from the CFG or reach the end of a branch and start
285*8a1ca2cdSRiver Riddle /// creating new nodes.
286*8a1ca2cdSRiver Riddle static void propagatePattern(std::unique_ptr<MatcherNode> &node,
287*8a1ca2cdSRiver Riddle                              OrderedPredicateList &list,
288*8a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator current,
289*8a1ca2cdSRiver Riddle                              std::vector<OrderedPredicate *>::iterator end) {
290*8a1ca2cdSRiver Riddle   if (current == end) {
291*8a1ca2cdSRiver Riddle     // We've hit the end of a pattern, so create a successful result node.
292*8a1ca2cdSRiver Riddle     node = std::make_unique<SuccessNode>(list.pattern, std::move(node));
293*8a1ca2cdSRiver Riddle 
294*8a1ca2cdSRiver Riddle     // If the pattern doesn't contain this predicate, ignore it.
295*8a1ca2cdSRiver Riddle   } else if (list.predicates.find(*current) == list.predicates.end()) {
296*8a1ca2cdSRiver Riddle     propagatePattern(node, list, std::next(current), end);
297*8a1ca2cdSRiver Riddle 
298*8a1ca2cdSRiver Riddle     // If the current matcher node is invalid, create a new one for this
299*8a1ca2cdSRiver Riddle     // position and continue propagation.
300*8a1ca2cdSRiver Riddle   } else if (!node) {
301*8a1ca2cdSRiver Riddle     // Create a new node at this position and continue
302*8a1ca2cdSRiver Riddle     node = std::make_unique<SwitchNode>((*current)->position,
303*8a1ca2cdSRiver Riddle                                         (*current)->question);
304*8a1ca2cdSRiver Riddle     propagatePattern(
305*8a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
306*8a1ca2cdSRiver Riddle         list, std::next(current), end);
307*8a1ca2cdSRiver Riddle 
308*8a1ca2cdSRiver Riddle     // If the matcher has already been created, and it is for this predicate we
309*8a1ca2cdSRiver Riddle     // continue propagation to the child.
310*8a1ca2cdSRiver Riddle   } else if (isSamePredicate(node.get(), *current)) {
311*8a1ca2cdSRiver Riddle     propagatePattern(
312*8a1ca2cdSRiver Riddle         getOrCreateChild(cast<SwitchNode>(&*node), *current, list.pattern),
313*8a1ca2cdSRiver Riddle         list, std::next(current), end);
314*8a1ca2cdSRiver Riddle 
315*8a1ca2cdSRiver Riddle     // If the matcher doesn't match the current predicate, insert a branch as
316*8a1ca2cdSRiver Riddle     // the common set of matchers has diverged.
317*8a1ca2cdSRiver Riddle   } else {
318*8a1ca2cdSRiver Riddle     propagatePattern(node->getFailureNode(), list, current, end);
319*8a1ca2cdSRiver Riddle   }
320*8a1ca2cdSRiver Riddle }
321*8a1ca2cdSRiver Riddle 
322*8a1ca2cdSRiver Riddle /// Fold any switch nodes nested under `node` to boolean nodes when possible.
323*8a1ca2cdSRiver Riddle /// `node` is updated in-place if it is a switch.
324*8a1ca2cdSRiver Riddle static void foldSwitchToBool(std::unique_ptr<MatcherNode> &node) {
325*8a1ca2cdSRiver Riddle   if (!node)
326*8a1ca2cdSRiver Riddle     return;
327*8a1ca2cdSRiver Riddle 
328*8a1ca2cdSRiver Riddle   if (SwitchNode *switchNode = dyn_cast<SwitchNode>(&*node)) {
329*8a1ca2cdSRiver Riddle     SwitchNode::ChildMapT &children = switchNode->getChildren();
330*8a1ca2cdSRiver Riddle     for (auto &it : children)
331*8a1ca2cdSRiver Riddle       foldSwitchToBool(it.second);
332*8a1ca2cdSRiver Riddle 
333*8a1ca2cdSRiver Riddle     // If the node only contains one child, collapse it into a boolean predicate
334*8a1ca2cdSRiver Riddle     // node.
335*8a1ca2cdSRiver Riddle     if (children.size() == 1) {
336*8a1ca2cdSRiver Riddle       auto childIt = children.begin();
337*8a1ca2cdSRiver Riddle       node = std::make_unique<BoolNode>(
338*8a1ca2cdSRiver Riddle           node->getPosition(), node->getQuestion(), childIt->first,
339*8a1ca2cdSRiver Riddle           std::move(childIt->second), std::move(node->getFailureNode()));
340*8a1ca2cdSRiver Riddle     }
341*8a1ca2cdSRiver Riddle   } else if (BoolNode *boolNode = dyn_cast<BoolNode>(&*node)) {
342*8a1ca2cdSRiver Riddle     foldSwitchToBool(boolNode->getSuccessNode());
343*8a1ca2cdSRiver Riddle   }
344*8a1ca2cdSRiver Riddle 
345*8a1ca2cdSRiver Riddle   foldSwitchToBool(node->getFailureNode());
346*8a1ca2cdSRiver Riddle }
347*8a1ca2cdSRiver Riddle 
348*8a1ca2cdSRiver Riddle /// Insert an exit node at the end of the failure path of the `root`.
349*8a1ca2cdSRiver Riddle static void insertExitNode(std::unique_ptr<MatcherNode> *root) {
350*8a1ca2cdSRiver Riddle   while (*root)
351*8a1ca2cdSRiver Riddle     root = &(*root)->getFailureNode();
352*8a1ca2cdSRiver Riddle   *root = std::make_unique<ExitNode>();
353*8a1ca2cdSRiver Riddle }
354*8a1ca2cdSRiver Riddle 
355*8a1ca2cdSRiver Riddle /// Given a module containing PDL pattern operations, generate a matcher tree
356*8a1ca2cdSRiver Riddle /// using the patterns within the given module and return the root matcher node.
357*8a1ca2cdSRiver Riddle std::unique_ptr<MatcherNode>
358*8a1ca2cdSRiver Riddle MatcherNode::generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
359*8a1ca2cdSRiver Riddle                                  DenseMap<Value, Position *> &valueToPosition) {
360*8a1ca2cdSRiver Riddle   // Collect the set of predicates contained within the pattern operations of
361*8a1ca2cdSRiver Riddle   // the module.
362*8a1ca2cdSRiver Riddle   SmallVector<std::pair<pdl::PatternOp, std::vector<PositionalPredicate>>, 16>
363*8a1ca2cdSRiver Riddle       patternsAndPredicates;
364*8a1ca2cdSRiver Riddle   for (pdl::PatternOp pattern : module.getOps<pdl::PatternOp>()) {
365*8a1ca2cdSRiver Riddle     std::vector<PositionalPredicate> predicateList;
366*8a1ca2cdSRiver Riddle     buildPredicateList(pattern, builder, predicateList, valueToPosition);
367*8a1ca2cdSRiver Riddle     patternsAndPredicates.emplace_back(pattern, std::move(predicateList));
368*8a1ca2cdSRiver Riddle   }
369*8a1ca2cdSRiver Riddle 
370*8a1ca2cdSRiver Riddle   // Associate a pattern result with each unique predicate.
371*8a1ca2cdSRiver Riddle   DenseSet<OrderedPredicate, OrderedPredicateDenseInfo> uniqued;
372*8a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
373*8a1ca2cdSRiver Riddle     for (auto &predicate : patternAndPredList.second) {
374*8a1ca2cdSRiver Riddle       auto it = uniqued.insert(predicate);
375*8a1ca2cdSRiver Riddle       it.first->patternToAnswer.try_emplace(patternAndPredList.first,
376*8a1ca2cdSRiver Riddle                                             predicate.answer);
377*8a1ca2cdSRiver Riddle     }
378*8a1ca2cdSRiver Riddle   }
379*8a1ca2cdSRiver Riddle 
380*8a1ca2cdSRiver Riddle   // Associate each pattern to a set of its ordered predicates for later lookup.
381*8a1ca2cdSRiver Riddle   std::vector<OrderedPredicateList> lists;
382*8a1ca2cdSRiver Riddle   lists.reserve(patternsAndPredicates.size());
383*8a1ca2cdSRiver Riddle   for (auto &patternAndPredList : patternsAndPredicates) {
384*8a1ca2cdSRiver Riddle     OrderedPredicateList list(patternAndPredList.first);
385*8a1ca2cdSRiver Riddle     for (auto &predicate : patternAndPredList.second) {
386*8a1ca2cdSRiver Riddle       OrderedPredicate *orderedPredicate = &*uniqued.find(predicate);
387*8a1ca2cdSRiver Riddle       list.predicates.insert(orderedPredicate);
388*8a1ca2cdSRiver Riddle 
389*8a1ca2cdSRiver Riddle       // Increment the primary sum for each reference to a particular predicate.
390*8a1ca2cdSRiver Riddle       ++orderedPredicate->primary;
391*8a1ca2cdSRiver Riddle     }
392*8a1ca2cdSRiver Riddle     lists.push_back(std::move(list));
393*8a1ca2cdSRiver Riddle   }
394*8a1ca2cdSRiver Riddle 
395*8a1ca2cdSRiver Riddle   // For a particular pattern, get the total primary sum and add it to the
396*8a1ca2cdSRiver Riddle   // secondary sum of each predicate. Square the primary sums to emphasize
397*8a1ca2cdSRiver Riddle   // shared predicates within rather than across patterns.
398*8a1ca2cdSRiver Riddle   for (auto &list : lists) {
399*8a1ca2cdSRiver Riddle     unsigned total = 0;
400*8a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
401*8a1ca2cdSRiver Riddle       total += predicate->primary * predicate->primary;
402*8a1ca2cdSRiver Riddle     for (auto *predicate : list.predicates)
403*8a1ca2cdSRiver Riddle       predicate->secondary += total;
404*8a1ca2cdSRiver Riddle   }
405*8a1ca2cdSRiver Riddle 
406*8a1ca2cdSRiver Riddle   // Sort the set of predicates now that the cost primary and secondary sums
407*8a1ca2cdSRiver Riddle   // have been computed.
408*8a1ca2cdSRiver Riddle   std::vector<OrderedPredicate *> ordered;
409*8a1ca2cdSRiver Riddle   ordered.reserve(uniqued.size());
410*8a1ca2cdSRiver Riddle   for (auto &ip : uniqued)
411*8a1ca2cdSRiver Riddle     ordered.push_back(&ip);
412*8a1ca2cdSRiver Riddle   std::stable_sort(
413*8a1ca2cdSRiver Riddle       ordered.begin(), ordered.end(),
414*8a1ca2cdSRiver Riddle       [](OrderedPredicate *lhs, OrderedPredicate *rhs) { return *lhs < *rhs; });
415*8a1ca2cdSRiver Riddle 
416*8a1ca2cdSRiver Riddle   // Build the matchers for each of the pattern predicate lists.
417*8a1ca2cdSRiver Riddle   std::unique_ptr<MatcherNode> root;
418*8a1ca2cdSRiver Riddle   for (OrderedPredicateList &list : lists)
419*8a1ca2cdSRiver Riddle     propagatePattern(root, list, ordered.begin(), ordered.end());
420*8a1ca2cdSRiver Riddle 
421*8a1ca2cdSRiver Riddle   // Collapse the graph and insert the exit node.
422*8a1ca2cdSRiver Riddle   foldSwitchToBool(root);
423*8a1ca2cdSRiver Riddle   insertExitNode(&root);
424*8a1ca2cdSRiver Riddle   return root;
425*8a1ca2cdSRiver Riddle }
426*8a1ca2cdSRiver Riddle 
427*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
428*8a1ca2cdSRiver Riddle // MatcherNode
429*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
430*8a1ca2cdSRiver Riddle 
431*8a1ca2cdSRiver Riddle MatcherNode::MatcherNode(TypeID matcherTypeID, Position *p, Qualifier *q,
432*8a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
433*8a1ca2cdSRiver Riddle     : position(p), question(q), failureNode(std::move(failureNode)),
434*8a1ca2cdSRiver Riddle       matcherTypeID(matcherTypeID) {}
435*8a1ca2cdSRiver Riddle 
436*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
437*8a1ca2cdSRiver Riddle // BoolNode
438*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
439*8a1ca2cdSRiver Riddle 
440*8a1ca2cdSRiver Riddle BoolNode::BoolNode(Position *position, Qualifier *question, Qualifier *answer,
441*8a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> successNode,
442*8a1ca2cdSRiver Riddle                    std::unique_ptr<MatcherNode> failureNode)
443*8a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<BoolNode>(), position, question,
444*8a1ca2cdSRiver Riddle                   std::move(failureNode)),
445*8a1ca2cdSRiver Riddle       answer(answer), successNode(std::move(successNode)) {}
446*8a1ca2cdSRiver Riddle 
447*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
448*8a1ca2cdSRiver Riddle // SuccessNode
449*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
450*8a1ca2cdSRiver Riddle 
451*8a1ca2cdSRiver Riddle SuccessNode::SuccessNode(pdl::PatternOp pattern,
452*8a1ca2cdSRiver Riddle                          std::unique_ptr<MatcherNode> failureNode)
453*8a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SuccessNode>(), /*position=*/nullptr,
454*8a1ca2cdSRiver Riddle                   /*question=*/nullptr, std::move(failureNode)),
455*8a1ca2cdSRiver Riddle       pattern(pattern) {}
456*8a1ca2cdSRiver Riddle 
457*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
458*8a1ca2cdSRiver Riddle // SwitchNode
459*8a1ca2cdSRiver Riddle //===----------------------------------------------------------------------===//
460*8a1ca2cdSRiver Riddle 
461*8a1ca2cdSRiver Riddle SwitchNode::SwitchNode(Position *position, Qualifier *question)
462*8a1ca2cdSRiver Riddle     : MatcherNode(TypeID::get<SwitchNode>(), position, question) {}
463