1 //===- PDLToPDLInterp.cpp - Lower a PDL module to the interpreter ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Conversion/PDLToPDLInterp/PDLToPDLInterp.h"
10 #include "../PassDetail.h"
11 #include "PredicateTree.h"
12 #include "mlir/Dialect/PDL/IR/PDL.h"
13 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
14 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
15 #include "mlir/Pass/Pass.h"
16 #include "llvm/ADT/MapVector.h"
17 #include "llvm/ADT/ScopedHashTable.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22
23 using namespace mlir;
24 using namespace mlir::pdl_to_pdl_interp;
25
26 //===----------------------------------------------------------------------===//
27 // PatternLowering
28 //===----------------------------------------------------------------------===//
29
30 namespace {
31 /// This class generators operations within the PDL Interpreter dialect from a
32 /// given module containing PDL pattern operations.
33 struct PatternLowering {
34 public:
35 PatternLowering(pdl_interp::FuncOp matcherFunc, ModuleOp rewriterModule);
36
37 /// Generate code for matching and rewriting based on the pattern operations
38 /// within the module.
39 void lower(ModuleOp module);
40
41 private:
42 using ValueMap = llvm::ScopedHashTable<Position *, Value>;
43 using ValueMapScope = llvm::ScopedHashTableScope<Position *, Value>;
44
45 /// Generate interpreter operations for the tree rooted at the given matcher
46 /// node, in the specified region.
47 Block *generateMatcher(MatcherNode &node, Region ®ion);
48
49 /// Get or create an access to the provided positional value in the current
50 /// block. This operation may mutate the provided block pointer if nested
51 /// regions (i.e., pdl_interp.iterate) are required.
52 Value getValueAt(Block *¤tBlock, Position *pos);
53
54 /// Create the interpreter predicate operations. This operation may mutate the
55 /// provided current block pointer if nested regions (iterates) are required.
56 void generate(BoolNode *boolNode, Block *¤tBlock, Value val);
57
58 /// Create the interpreter switch / predicate operations, with several case
59 /// destinations. This operation never mutates the provided current block
60 /// pointer, because the switch operation does not need Values beyond `val`.
61 void generate(SwitchNode *switchNode, Block *currentBlock, Value val);
62
63 /// Create the interpreter operations to record a successful pattern match
64 /// using the contained root operation. This operation may mutate the current
65 /// block pointer if nested regions (i.e., pdl_interp.iterate) are required.
66 void generate(SuccessNode *successNode, Block *¤tBlock);
67
68 /// Generate a rewriter function for the given pattern operation, and returns
69 /// a reference to that function.
70 SymbolRefAttr generateRewriter(pdl::PatternOp pattern,
71 SmallVectorImpl<Position *> &usedMatchValues);
72
73 /// Generate the rewriter code for the given operation.
74 void generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,
75 DenseMap<Value, Value> &rewriteValues,
76 function_ref<Value(Value)> mapRewriteValue);
77 void generateRewriter(pdl::AttributeOp attrOp,
78 DenseMap<Value, Value> &rewriteValues,
79 function_ref<Value(Value)> mapRewriteValue);
80 void generateRewriter(pdl::EraseOp eraseOp,
81 DenseMap<Value, Value> &rewriteValues,
82 function_ref<Value(Value)> mapRewriteValue);
83 void generateRewriter(pdl::OperationOp operationOp,
84 DenseMap<Value, Value> &rewriteValues,
85 function_ref<Value(Value)> mapRewriteValue);
86 void generateRewriter(pdl::ReplaceOp replaceOp,
87 DenseMap<Value, Value> &rewriteValues,
88 function_ref<Value(Value)> mapRewriteValue);
89 void generateRewriter(pdl::ResultOp resultOp,
90 DenseMap<Value, Value> &rewriteValues,
91 function_ref<Value(Value)> mapRewriteValue);
92 void generateRewriter(pdl::ResultsOp resultOp,
93 DenseMap<Value, Value> &rewriteValues,
94 function_ref<Value(Value)> mapRewriteValue);
95 void generateRewriter(pdl::TypeOp typeOp,
96 DenseMap<Value, Value> &rewriteValues,
97 function_ref<Value(Value)> mapRewriteValue);
98 void generateRewriter(pdl::TypesOp typeOp,
99 DenseMap<Value, Value> &rewriteValues,
100 function_ref<Value(Value)> mapRewriteValue);
101
102 /// Generate the values used for resolving the result types of an operation
103 /// created within a dag rewriter region. If the result types of the operation
104 /// should be inferred, `hasInferredResultTypes` is set to true.
105 void generateOperationResultTypeRewriter(
106 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
107 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
108 bool &hasInferredResultTypes);
109
110 /// A builder to use when generating interpreter operations.
111 OpBuilder builder;
112
113 /// The matcher function used for all match related logic within PDL patterns.
114 pdl_interp::FuncOp matcherFunc;
115
116 /// The rewriter module containing the all rewrite related logic within PDL
117 /// patterns.
118 ModuleOp rewriterModule;
119
120 /// The symbol table of the rewriter module used for insertion.
121 SymbolTable rewriterSymbolTable;
122
123 /// A scoped map connecting a position with the corresponding interpreter
124 /// value.
125 ValueMap values;
126
127 /// A stack of blocks used as the failure destination for matcher nodes that
128 /// don't have an explicit failure path.
129 SmallVector<Block *, 8> failureBlockStack;
130
131 /// A mapping between values defined in a pattern match, and the corresponding
132 /// positional value.
133 DenseMap<Value, Position *> valueToPosition;
134
135 /// The set of operation values whose whose location will be used for newly
136 /// generated operations.
137 SetVector<Value> locOps;
138 };
139 } // namespace
140
PatternLowering(pdl_interp::FuncOp matcherFunc,ModuleOp rewriterModule)141 PatternLowering::PatternLowering(pdl_interp::FuncOp matcherFunc,
142 ModuleOp rewriterModule)
143 : builder(matcherFunc.getContext()), matcherFunc(matcherFunc),
144 rewriterModule(rewriterModule), rewriterSymbolTable(rewriterModule) {}
145
lower(ModuleOp module)146 void PatternLowering::lower(ModuleOp module) {
147 PredicateUniquer predicateUniquer;
148 PredicateBuilder predicateBuilder(predicateUniquer, module.getContext());
149
150 // Define top-level scope for the arguments to the matcher function.
151 ValueMapScope topLevelValueScope(values);
152
153 // Insert the root operation, i.e. argument to the matcher, at the root
154 // position.
155 Block *matcherEntryBlock = &matcherFunc.front();
156 values.insert(predicateBuilder.getRoot(), matcherEntryBlock->getArgument(0));
157
158 // Generate a root matcher node from the provided PDL module.
159 std::unique_ptr<MatcherNode> root = MatcherNode::generateMatcherTree(
160 module, predicateBuilder, valueToPosition);
161 Block *firstMatcherBlock = generateMatcher(*root, matcherFunc.getBody());
162 assert(failureBlockStack.empty() && "failed to empty the stack");
163
164 // After generation, merged the first matched block into the entry.
165 matcherEntryBlock->getOperations().splice(matcherEntryBlock->end(),
166 firstMatcherBlock->getOperations());
167 firstMatcherBlock->erase();
168 }
169
generateMatcher(MatcherNode & node,Region & region)170 Block *PatternLowering::generateMatcher(MatcherNode &node, Region ®ion) {
171 // Push a new scope for the values used by this matcher.
172 Block *block = ®ion.emplaceBlock();
173 ValueMapScope scope(values);
174
175 // If this is the return node, simply insert the corresponding interpreter
176 // finalize.
177 if (isa<ExitNode>(node)) {
178 builder.setInsertionPointToEnd(block);
179 builder.create<pdl_interp::FinalizeOp>(matcherFunc.getLoc());
180 return block;
181 }
182
183 // Get the next block in the match sequence.
184 // This is intentionally executed first, before we get the value for the
185 // position associated with the node, so that we preserve an "there exist"
186 // semantics: if getting a value requires an upward traversal (going from a
187 // value to its consumers), we want to perform the check on all the consumers
188 // before we pass control to the failure node.
189 std::unique_ptr<MatcherNode> &failureNode = node.getFailureNode();
190 Block *failureBlock;
191 if (failureNode) {
192 failureBlock = generateMatcher(*failureNode, region);
193 failureBlockStack.push_back(failureBlock);
194 } else {
195 assert(!failureBlockStack.empty() && "expected valid failure block");
196 failureBlock = failureBlockStack.back();
197 }
198
199 // If this node contains a position, get the corresponding value for this
200 // block.
201 Block *currentBlock = block;
202 Position *position = node.getPosition();
203 Value val = position ? getValueAt(currentBlock, position) : Value();
204
205 // If this value corresponds to an operation, record that we are going to use
206 // its location as part of a fused location.
207 bool isOperationValue = val && val.getType().isa<pdl::OperationType>();
208 if (isOperationValue)
209 locOps.insert(val);
210
211 // Dispatch to the correct method based on derived node type.
212 TypeSwitch<MatcherNode *>(&node)
213 .Case<BoolNode, SwitchNode>([&](auto *derivedNode) {
214 this->generate(derivedNode, currentBlock, val);
215 })
216 .Case([&](SuccessNode *successNode) {
217 generate(successNode, currentBlock);
218 });
219
220 // Pop all the failure blocks that were inserted due to nesting of
221 // pdl_interp.iterate.
222 while (failureBlockStack.back() != failureBlock) {
223 failureBlockStack.pop_back();
224 assert(!failureBlockStack.empty() && "unable to locate failure block");
225 }
226
227 // Pop the new failure block.
228 if (failureNode)
229 failureBlockStack.pop_back();
230
231 if (isOperationValue)
232 locOps.remove(val);
233
234 return block;
235 }
236
getValueAt(Block * & currentBlock,Position * pos)237 Value PatternLowering::getValueAt(Block *¤tBlock, Position *pos) {
238 if (Value val = values.lookup(pos))
239 return val;
240
241 // Get the value for the parent position.
242 Value parentVal;
243 if (Position *parent = pos->getParent())
244 parentVal = getValueAt(currentBlock, parent);
245
246 // TODO: Use a location from the position.
247 Location loc = parentVal ? parentVal.getLoc() : builder.getUnknownLoc();
248 builder.setInsertionPointToEnd(currentBlock);
249 Value value;
250 switch (pos->getKind()) {
251 case Predicates::OperationPos: {
252 auto *operationPos = cast<OperationPosition>(pos);
253 if (operationPos->isOperandDefiningOp())
254 // Standard (downward) traversal which directly follows the defining op.
255 value = builder.create<pdl_interp::GetDefiningOpOp>(
256 loc, builder.getType<pdl::OperationType>(), parentVal);
257 else
258 // A passthrough operation position.
259 value = parentVal;
260 break;
261 }
262 case Predicates::UsersPos: {
263 auto *usersPos = cast<UsersPosition>(pos);
264
265 // The first operation retrieves the representative value of a range.
266 // This applies only when the parent is a range of values and we were
267 // requested to use a representative value (e.g., upward traversal).
268 if (parentVal.getType().isa<pdl::RangeType>() &&
269 usersPos->useRepresentative())
270 value = builder.create<pdl_interp::ExtractOp>(loc, parentVal, 0);
271 else
272 value = parentVal;
273
274 // The second operation retrieves the users.
275 value = builder.create<pdl_interp::GetUsersOp>(loc, value);
276 break;
277 }
278 case Predicates::ForEachPos: {
279 assert(!failureBlockStack.empty() && "expected valid failure block");
280 auto foreach = builder.create<pdl_interp::ForEachOp>(
281 loc, parentVal, failureBlockStack.back(), /*initLoop=*/true);
282 value = foreach.getLoopVariable();
283
284 // Create the continuation block.
285 Block *continueBlock = builder.createBlock(&foreach.getRegion());
286 builder.create<pdl_interp::ContinueOp>(loc);
287 failureBlockStack.push_back(continueBlock);
288
289 currentBlock = &foreach.getRegion().front();
290 break;
291 }
292 case Predicates::OperandPos: {
293 auto *operandPos = cast<OperandPosition>(pos);
294 value = builder.create<pdl_interp::GetOperandOp>(
295 loc, builder.getType<pdl::ValueType>(), parentVal,
296 operandPos->getOperandNumber());
297 break;
298 }
299 case Predicates::OperandGroupPos: {
300 auto *operandPos = cast<OperandGroupPosition>(pos);
301 Type valueTy = builder.getType<pdl::ValueType>();
302 value = builder.create<pdl_interp::GetOperandsOp>(
303 loc, operandPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
304 parentVal, operandPos->getOperandGroupNumber());
305 break;
306 }
307 case Predicates::AttributePos: {
308 auto *attrPos = cast<AttributePosition>(pos);
309 value = builder.create<pdl_interp::GetAttributeOp>(
310 loc, builder.getType<pdl::AttributeType>(), parentVal,
311 attrPos->getName().strref());
312 break;
313 }
314 case Predicates::TypePos: {
315 if (parentVal.getType().isa<pdl::AttributeType>())
316 value = builder.create<pdl_interp::GetAttributeTypeOp>(loc, parentVal);
317 else
318 value = builder.create<pdl_interp::GetValueTypeOp>(loc, parentVal);
319 break;
320 }
321 case Predicates::ResultPos: {
322 auto *resPos = cast<ResultPosition>(pos);
323 value = builder.create<pdl_interp::GetResultOp>(
324 loc, builder.getType<pdl::ValueType>(), parentVal,
325 resPos->getResultNumber());
326 break;
327 }
328 case Predicates::ResultGroupPos: {
329 auto *resPos = cast<ResultGroupPosition>(pos);
330 Type valueTy = builder.getType<pdl::ValueType>();
331 value = builder.create<pdl_interp::GetResultsOp>(
332 loc, resPos->isVariadic() ? pdl::RangeType::get(valueTy) : valueTy,
333 parentVal, resPos->getResultGroupNumber());
334 break;
335 }
336 case Predicates::AttributeLiteralPos: {
337 auto *attrPos = cast<AttributeLiteralPosition>(pos);
338 value =
339 builder.create<pdl_interp::CreateAttributeOp>(loc, attrPos->getValue());
340 break;
341 }
342 case Predicates::TypeLiteralPos: {
343 auto *typePos = cast<TypeLiteralPosition>(pos);
344 Attribute rawTypeAttr = typePos->getValue();
345 if (TypeAttr typeAttr = rawTypeAttr.dyn_cast<TypeAttr>())
346 value = builder.create<pdl_interp::CreateTypeOp>(loc, typeAttr);
347 else
348 value = builder.create<pdl_interp::CreateTypesOp>(
349 loc, rawTypeAttr.cast<ArrayAttr>());
350 break;
351 }
352 default:
353 llvm_unreachable("Generating unknown Position getter");
354 break;
355 }
356
357 values.insert(pos, value);
358 return value;
359 }
360
generate(BoolNode * boolNode,Block * & currentBlock,Value val)361 void PatternLowering::generate(BoolNode *boolNode, Block *¤tBlock,
362 Value val) {
363 Location loc = val.getLoc();
364 Qualifier *question = boolNode->getQuestion();
365 Qualifier *answer = boolNode->getAnswer();
366 Region *region = currentBlock->getParent();
367
368 // Execute the getValue queries first, so that we create success
369 // matcher in the correct (possibly nested) region.
370 SmallVector<Value> args;
371 if (auto *equalToQuestion = dyn_cast<EqualToQuestion>(question)) {
372 args = {getValueAt(currentBlock, equalToQuestion->getValue())};
373 } else if (auto *cstQuestion = dyn_cast<ConstraintQuestion>(question)) {
374 for (Position *position : cstQuestion->getArgs())
375 args.push_back(getValueAt(currentBlock, position));
376 }
377
378 // Generate the matcher in the current (potentially nested) region
379 // and get the failure successor.
380 Block *success = generateMatcher(*boolNode->getSuccessNode(), *region);
381 Block *failure = failureBlockStack.back();
382
383 // Finally, create the predicate.
384 builder.setInsertionPointToEnd(currentBlock);
385 Predicates::Kind kind = question->getKind();
386 switch (kind) {
387 case Predicates::IsNotNullQuestion:
388 builder.create<pdl_interp::IsNotNullOp>(loc, val, success, failure);
389 break;
390 case Predicates::OperationNameQuestion: {
391 auto *opNameAnswer = cast<OperationNameAnswer>(answer);
392 builder.create<pdl_interp::CheckOperationNameOp>(
393 loc, val, opNameAnswer->getValue().getStringRef(), success, failure);
394 break;
395 }
396 case Predicates::TypeQuestion: {
397 auto *ans = cast<TypeAnswer>(answer);
398 if (val.getType().isa<pdl::RangeType>())
399 builder.create<pdl_interp::CheckTypesOp>(
400 loc, val, ans->getValue().cast<ArrayAttr>(), success, failure);
401 else
402 builder.create<pdl_interp::CheckTypeOp>(
403 loc, val, ans->getValue().cast<TypeAttr>(), success, failure);
404 break;
405 }
406 case Predicates::AttributeQuestion: {
407 auto *ans = cast<AttributeAnswer>(answer);
408 builder.create<pdl_interp::CheckAttributeOp>(loc, val, ans->getValue(),
409 success, failure);
410 break;
411 }
412 case Predicates::OperandCountAtLeastQuestion:
413 case Predicates::OperandCountQuestion:
414 builder.create<pdl_interp::CheckOperandCountOp>(
415 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
416 /*compareAtLeast=*/kind == Predicates::OperandCountAtLeastQuestion,
417 success, failure);
418 break;
419 case Predicates::ResultCountAtLeastQuestion:
420 case Predicates::ResultCountQuestion:
421 builder.create<pdl_interp::CheckResultCountOp>(
422 loc, val, cast<UnsignedAnswer>(answer)->getValue(),
423 /*compareAtLeast=*/kind == Predicates::ResultCountAtLeastQuestion,
424 success, failure);
425 break;
426 case Predicates::EqualToQuestion: {
427 bool trueAnswer = isa<TrueAnswer>(answer);
428 builder.create<pdl_interp::AreEqualOp>(loc, val, args.front(),
429 trueAnswer ? success : failure,
430 trueAnswer ? failure : success);
431 break;
432 }
433 case Predicates::ConstraintQuestion: {
434 auto *cstQuestion = cast<ConstraintQuestion>(question);
435 builder.create<pdl_interp::ApplyConstraintOp>(loc, cstQuestion->getName(),
436 args, success, failure);
437 break;
438 }
439 default:
440 llvm_unreachable("Generating unknown Predicate operation");
441 }
442 }
443
444 template <typename OpT, typename PredT, typename ValT = typename PredT::KeyTy>
createSwitchOp(Value val,Block * defaultDest,OpBuilder & builder,llvm::MapVector<Qualifier *,Block * > & dests)445 static void createSwitchOp(Value val, Block *defaultDest, OpBuilder &builder,
446 llvm::MapVector<Qualifier *, Block *> &dests) {
447 std::vector<ValT> values;
448 std::vector<Block *> blocks;
449 values.reserve(dests.size());
450 blocks.reserve(dests.size());
451 for (const auto &it : dests) {
452 blocks.push_back(it.second);
453 values.push_back(cast<PredT>(it.first)->getValue());
454 }
455 builder.create<OpT>(val.getLoc(), val, values, defaultDest, blocks);
456 }
457
generate(SwitchNode * switchNode,Block * currentBlock,Value val)458 void PatternLowering::generate(SwitchNode *switchNode, Block *currentBlock,
459 Value val) {
460 Qualifier *question = switchNode->getQuestion();
461 Region *region = currentBlock->getParent();
462 Block *defaultDest = failureBlockStack.back();
463
464 // If the switch question is not an exact answer, i.e. for the `at_least`
465 // cases, we generate a special block sequence.
466 Predicates::Kind kind = question->getKind();
467 if (kind == Predicates::OperandCountAtLeastQuestion ||
468 kind == Predicates::ResultCountAtLeastQuestion) {
469 // Order the children such that the cases are in reverse numerical order.
470 SmallVector<unsigned> sortedChildren = llvm::to_vector<16>(
471 llvm::seq<unsigned>(0, switchNode->getChildren().size()));
472 llvm::sort(sortedChildren, [&](unsigned lhs, unsigned rhs) {
473 return cast<UnsignedAnswer>(switchNode->getChild(lhs).first)->getValue() >
474 cast<UnsignedAnswer>(switchNode->getChild(rhs).first)->getValue();
475 });
476
477 // Build the destination for each child using the next highest child as a
478 // a failure destination. This essentially creates the following control
479 // flow:
480 //
481 // if (operand_count < 1)
482 // goto failure
483 // if (child1.match())
484 // ...
485 //
486 // if (operand_count < 2)
487 // goto failure
488 // if (child2.match())
489 // ...
490 //
491 // failure:
492 // ...
493 //
494 failureBlockStack.push_back(defaultDest);
495 Location loc = val.getLoc();
496 for (unsigned idx : sortedChildren) {
497 auto &child = switchNode->getChild(idx);
498 Block *childBlock = generateMatcher(*child.second, *region);
499 Block *predicateBlock = builder.createBlock(childBlock);
500 builder.setInsertionPointToEnd(predicateBlock);
501 unsigned ans = cast<UnsignedAnswer>(child.first)->getValue();
502 switch (kind) {
503 case Predicates::OperandCountAtLeastQuestion:
504 builder.create<pdl_interp::CheckOperandCountOp>(
505 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
506 break;
507 case Predicates::ResultCountAtLeastQuestion:
508 builder.create<pdl_interp::CheckResultCountOp>(
509 loc, val, ans, /*compareAtLeast=*/true, childBlock, defaultDest);
510 break;
511 default:
512 llvm_unreachable("Generating invalid AtLeast operation");
513 }
514 failureBlockStack.back() = predicateBlock;
515 }
516 Block *firstPredicateBlock = failureBlockStack.pop_back_val();
517 currentBlock->getOperations().splice(currentBlock->end(),
518 firstPredicateBlock->getOperations());
519 firstPredicateBlock->erase();
520 return;
521 }
522
523 // Otherwise, generate each of the children and generate an interpreter
524 // switch.
525 llvm::MapVector<Qualifier *, Block *> children;
526 for (auto &it : switchNode->getChildren())
527 children.insert({it.first, generateMatcher(*it.second, *region)});
528 builder.setInsertionPointToEnd(currentBlock);
529
530 switch (question->getKind()) {
531 case Predicates::OperandCountQuestion:
532 return createSwitchOp<pdl_interp::SwitchOperandCountOp, UnsignedAnswer,
533 int32_t>(val, defaultDest, builder, children);
534 case Predicates::ResultCountQuestion:
535 return createSwitchOp<pdl_interp::SwitchResultCountOp, UnsignedAnswer,
536 int32_t>(val, defaultDest, builder, children);
537 case Predicates::OperationNameQuestion:
538 return createSwitchOp<pdl_interp::SwitchOperationNameOp,
539 OperationNameAnswer>(val, defaultDest, builder,
540 children);
541 case Predicates::TypeQuestion:
542 if (val.getType().isa<pdl::RangeType>()) {
543 return createSwitchOp<pdl_interp::SwitchTypesOp, TypeAnswer>(
544 val, defaultDest, builder, children);
545 }
546 return createSwitchOp<pdl_interp::SwitchTypeOp, TypeAnswer>(
547 val, defaultDest, builder, children);
548 case Predicates::AttributeQuestion:
549 return createSwitchOp<pdl_interp::SwitchAttributeOp, AttributeAnswer>(
550 val, defaultDest, builder, children);
551 default:
552 llvm_unreachable("Generating unknown switch predicate.");
553 }
554 }
555
generate(SuccessNode * successNode,Block * & currentBlock)556 void PatternLowering::generate(SuccessNode *successNode, Block *¤tBlock) {
557 pdl::PatternOp pattern = successNode->getPattern();
558 Value root = successNode->getRoot();
559
560 // Generate a rewriter for the pattern this success node represents, and track
561 // any values used from the match region.
562 SmallVector<Position *, 8> usedMatchValues;
563 SymbolRefAttr rewriterFuncRef = generateRewriter(pattern, usedMatchValues);
564
565 // Process any values used in the rewrite that are defined in the match.
566 std::vector<Value> mappedMatchValues;
567 mappedMatchValues.reserve(usedMatchValues.size());
568 for (Position *position : usedMatchValues)
569 mappedMatchValues.push_back(getValueAt(currentBlock, position));
570
571 // Collect the set of operations generated by the rewriter.
572 SmallVector<StringRef, 4> generatedOps;
573 for (auto op : pattern.getRewriter().body().getOps<pdl::OperationOp>())
574 generatedOps.push_back(*op.name());
575 ArrayAttr generatedOpsAttr;
576 if (!generatedOps.empty())
577 generatedOpsAttr = builder.getStrArrayAttr(generatedOps);
578
579 // Grab the root kind if present.
580 StringAttr rootKindAttr;
581 if (pdl::OperationOp rootOp = root.getDefiningOp<pdl::OperationOp>())
582 if (Optional<StringRef> rootKind = rootOp.name())
583 rootKindAttr = builder.getStringAttr(*rootKind);
584
585 builder.setInsertionPointToEnd(currentBlock);
586 builder.create<pdl_interp::RecordMatchOp>(
587 pattern.getLoc(), mappedMatchValues, locOps.getArrayRef(),
588 rewriterFuncRef, rootKindAttr, generatedOpsAttr, pattern.benefitAttr(),
589 failureBlockStack.back());
590 }
591
generateRewriter(pdl::PatternOp pattern,SmallVectorImpl<Position * > & usedMatchValues)592 SymbolRefAttr PatternLowering::generateRewriter(
593 pdl::PatternOp pattern, SmallVectorImpl<Position *> &usedMatchValues) {
594 builder.setInsertionPointToEnd(rewriterModule.getBody());
595 auto rewriterFunc = builder.create<pdl_interp::FuncOp>(
596 pattern.getLoc(), "pdl_generated_rewriter",
597 builder.getFunctionType(llvm::None, llvm::None));
598 rewriterSymbolTable.insert(rewriterFunc);
599
600 // Generate the rewriter function body.
601 builder.setInsertionPointToEnd(&rewriterFunc.front());
602
603 // Map an input operand of the pattern to a generated interpreter value.
604 DenseMap<Value, Value> rewriteValues;
605 auto mapRewriteValue = [&](Value oldValue) {
606 Value &newValue = rewriteValues[oldValue];
607 if (newValue)
608 return newValue;
609
610 // Prefer materializing constants directly when possible.
611 Operation *oldOp = oldValue.getDefiningOp();
612 if (pdl::AttributeOp attrOp = dyn_cast<pdl::AttributeOp>(oldOp)) {
613 if (Attribute value = attrOp.valueAttr()) {
614 return newValue = builder.create<pdl_interp::CreateAttributeOp>(
615 attrOp.getLoc(), value);
616 }
617 } else if (pdl::TypeOp typeOp = dyn_cast<pdl::TypeOp>(oldOp)) {
618 if (TypeAttr type = typeOp.typeAttr()) {
619 return newValue = builder.create<pdl_interp::CreateTypeOp>(
620 typeOp.getLoc(), type);
621 }
622 } else if (pdl::TypesOp typeOp = dyn_cast<pdl::TypesOp>(oldOp)) {
623 if (ArrayAttr type = typeOp.typesAttr()) {
624 return newValue = builder.create<pdl_interp::CreateTypesOp>(
625 typeOp.getLoc(), typeOp.getType(), type);
626 }
627 }
628
629 // Otherwise, add this as an input to the rewriter.
630 Position *inputPos = valueToPosition.lookup(oldValue);
631 assert(inputPos && "expected value to be a pattern input");
632 usedMatchValues.push_back(inputPos);
633 return newValue = rewriterFunc.front().addArgument(oldValue.getType(),
634 oldValue.getLoc());
635 };
636
637 // If this is a custom rewriter, simply dispatch to the registered rewrite
638 // method.
639 pdl::RewriteOp rewriter = pattern.getRewriter();
640 if (StringAttr rewriteName = rewriter.nameAttr()) {
641 SmallVector<Value> args;
642 if (rewriter.root())
643 args.push_back(mapRewriteValue(rewriter.root()));
644 auto mappedArgs = llvm::map_range(rewriter.externalArgs(), mapRewriteValue);
645 args.append(mappedArgs.begin(), mappedArgs.end());
646 builder.create<pdl_interp::ApplyRewriteOp>(
647 rewriter.getLoc(), /*resultTypes=*/TypeRange(), rewriteName, args);
648 } else {
649 // Otherwise this is a dag rewriter defined using PDL operations.
650 for (Operation &rewriteOp : *rewriter.getBody()) {
651 llvm::TypeSwitch<Operation *>(&rewriteOp)
652 .Case<pdl::ApplyNativeRewriteOp, pdl::AttributeOp, pdl::EraseOp,
653 pdl::OperationOp, pdl::ReplaceOp, pdl::ResultOp, pdl::ResultsOp,
654 pdl::TypeOp, pdl::TypesOp>([&](auto op) {
655 this->generateRewriter(op, rewriteValues, mapRewriteValue);
656 });
657 }
658 }
659
660 // Update the signature of the rewrite function.
661 rewriterFunc.setType(builder.getFunctionType(
662 llvm::to_vector<8>(rewriterFunc.front().getArgumentTypes()),
663 /*results=*/llvm::None));
664
665 builder.create<pdl_interp::FinalizeOp>(rewriter.getLoc());
666 return SymbolRefAttr::get(
667 builder.getContext(),
668 pdl_interp::PDLInterpDialect::getRewriterModuleName(),
669 SymbolRefAttr::get(rewriterFunc));
670 }
671
generateRewriter(pdl::ApplyNativeRewriteOp rewriteOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)672 void PatternLowering::generateRewriter(
673 pdl::ApplyNativeRewriteOp rewriteOp, DenseMap<Value, Value> &rewriteValues,
674 function_ref<Value(Value)> mapRewriteValue) {
675 SmallVector<Value, 2> arguments;
676 for (Value argument : rewriteOp.args())
677 arguments.push_back(mapRewriteValue(argument));
678 auto interpOp = builder.create<pdl_interp::ApplyRewriteOp>(
679 rewriteOp.getLoc(), rewriteOp.getResultTypes(), rewriteOp.nameAttr(),
680 arguments);
681 for (auto it : llvm::zip(rewriteOp.getResults(), interpOp.getResults()))
682 rewriteValues[std::get<0>(it)] = std::get<1>(it);
683 }
684
generateRewriter(pdl::AttributeOp attrOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)685 void PatternLowering::generateRewriter(
686 pdl::AttributeOp attrOp, DenseMap<Value, Value> &rewriteValues,
687 function_ref<Value(Value)> mapRewriteValue) {
688 Value newAttr = builder.create<pdl_interp::CreateAttributeOp>(
689 attrOp.getLoc(), attrOp.valueAttr());
690 rewriteValues[attrOp] = newAttr;
691 }
692
generateRewriter(pdl::EraseOp eraseOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)693 void PatternLowering::generateRewriter(
694 pdl::EraseOp eraseOp, DenseMap<Value, Value> &rewriteValues,
695 function_ref<Value(Value)> mapRewriteValue) {
696 builder.create<pdl_interp::EraseOp>(eraseOp.getLoc(),
697 mapRewriteValue(eraseOp.operation()));
698 }
699
generateRewriter(pdl::OperationOp operationOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)700 void PatternLowering::generateRewriter(
701 pdl::OperationOp operationOp, DenseMap<Value, Value> &rewriteValues,
702 function_ref<Value(Value)> mapRewriteValue) {
703 SmallVector<Value, 4> operands;
704 for (Value operand : operationOp.operands())
705 operands.push_back(mapRewriteValue(operand));
706
707 SmallVector<Value, 4> attributes;
708 for (Value attr : operationOp.attributes())
709 attributes.push_back(mapRewriteValue(attr));
710
711 bool hasInferredResultTypes = false;
712 SmallVector<Value, 2> types;
713 generateOperationResultTypeRewriter(operationOp, mapRewriteValue, types,
714 rewriteValues, hasInferredResultTypes);
715
716 // Create the new operation.
717 Location loc = operationOp.getLoc();
718 Value createdOp = builder.create<pdl_interp::CreateOperationOp>(
719 loc, *operationOp.name(), types, hasInferredResultTypes, operands,
720 attributes, operationOp.attributeNames());
721 rewriteValues[operationOp.op()] = createdOp;
722
723 // Generate accesses for any results that have their types constrained.
724 // Handle the case where there is a single range representing all of the
725 // result types.
726 OperandRange resultTys = operationOp.types();
727 if (resultTys.size() == 1 && resultTys[0].getType().isa<pdl::RangeType>()) {
728 Value &type = rewriteValues[resultTys[0]];
729 if (!type) {
730 auto results = builder.create<pdl_interp::GetResultsOp>(loc, createdOp);
731 type = builder.create<pdl_interp::GetValueTypeOp>(loc, results);
732 }
733 return;
734 }
735
736 // Otherwise, populate the individual results.
737 bool seenVariableLength = false;
738 Type valueTy = builder.getType<pdl::ValueType>();
739 Type valueRangeTy = pdl::RangeType::get(valueTy);
740 for (const auto &it : llvm::enumerate(resultTys)) {
741 Value &type = rewriteValues[it.value()];
742 if (type)
743 continue;
744 bool isVariadic = it.value().getType().isa<pdl::RangeType>();
745 seenVariableLength |= isVariadic;
746
747 // After a variable length result has been seen, we need to use result
748 // groups because the exact index of the result is not statically known.
749 Value resultVal;
750 if (seenVariableLength)
751 resultVal = builder.create<pdl_interp::GetResultsOp>(
752 loc, isVariadic ? valueRangeTy : valueTy, createdOp, it.index());
753 else
754 resultVal = builder.create<pdl_interp::GetResultOp>(
755 loc, valueTy, createdOp, it.index());
756 type = builder.create<pdl_interp::GetValueTypeOp>(loc, resultVal);
757 }
758 }
759
generateRewriter(pdl::ReplaceOp replaceOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)760 void PatternLowering::generateRewriter(
761 pdl::ReplaceOp replaceOp, DenseMap<Value, Value> &rewriteValues,
762 function_ref<Value(Value)> mapRewriteValue) {
763 SmallVector<Value, 4> replOperands;
764
765 // If the replacement was another operation, get its results. `pdl` allows
766 // for using an operation for simplicitly, but the interpreter isn't as
767 // user facing.
768 if (Value replOp = replaceOp.replOperation()) {
769 // Don't use replace if we know the replaced operation has no results.
770 auto opOp = replaceOp.operation().getDefiningOp<pdl::OperationOp>();
771 if (!opOp || !opOp.types().empty()) {
772 replOperands.push_back(builder.create<pdl_interp::GetResultsOp>(
773 replOp.getLoc(), mapRewriteValue(replOp)));
774 }
775 } else {
776 for (Value operand : replaceOp.replValues())
777 replOperands.push_back(mapRewriteValue(operand));
778 }
779
780 // If there are no replacement values, just create an erase instead.
781 if (replOperands.empty()) {
782 builder.create<pdl_interp::EraseOp>(replaceOp.getLoc(),
783 mapRewriteValue(replaceOp.operation()));
784 return;
785 }
786
787 builder.create<pdl_interp::ReplaceOp>(
788 replaceOp.getLoc(), mapRewriteValue(replaceOp.operation()), replOperands);
789 }
790
generateRewriter(pdl::ResultOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)791 void PatternLowering::generateRewriter(
792 pdl::ResultOp resultOp, DenseMap<Value, Value> &rewriteValues,
793 function_ref<Value(Value)> mapRewriteValue) {
794 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultOp>(
795 resultOp.getLoc(), builder.getType<pdl::ValueType>(),
796 mapRewriteValue(resultOp.parent()), resultOp.index());
797 }
798
generateRewriter(pdl::ResultsOp resultOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)799 void PatternLowering::generateRewriter(
800 pdl::ResultsOp resultOp, DenseMap<Value, Value> &rewriteValues,
801 function_ref<Value(Value)> mapRewriteValue) {
802 rewriteValues[resultOp] = builder.create<pdl_interp::GetResultsOp>(
803 resultOp.getLoc(), resultOp.getType(), mapRewriteValue(resultOp.parent()),
804 resultOp.index());
805 }
806
generateRewriter(pdl::TypeOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)807 void PatternLowering::generateRewriter(
808 pdl::TypeOp typeOp, DenseMap<Value, Value> &rewriteValues,
809 function_ref<Value(Value)> mapRewriteValue) {
810 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
811 // type.
812 if (TypeAttr typeAttr = typeOp.typeAttr()) {
813 rewriteValues[typeOp] =
814 builder.create<pdl_interp::CreateTypeOp>(typeOp.getLoc(), typeAttr);
815 }
816 }
817
generateRewriter(pdl::TypesOp typeOp,DenseMap<Value,Value> & rewriteValues,function_ref<Value (Value)> mapRewriteValue)818 void PatternLowering::generateRewriter(
819 pdl::TypesOp typeOp, DenseMap<Value, Value> &rewriteValues,
820 function_ref<Value(Value)> mapRewriteValue) {
821 // If the type isn't constant, the users (e.g. OperationOp) will resolve this
822 // type.
823 if (ArrayAttr typeAttr = typeOp.typesAttr()) {
824 rewriteValues[typeOp] = builder.create<pdl_interp::CreateTypesOp>(
825 typeOp.getLoc(), typeOp.getType(), typeAttr);
826 }
827 }
828
generateOperationResultTypeRewriter(pdl::OperationOp op,function_ref<Value (Value)> mapRewriteValue,SmallVectorImpl<Value> & types,DenseMap<Value,Value> & rewriteValues,bool & hasInferredResultTypes)829 void PatternLowering::generateOperationResultTypeRewriter(
830 pdl::OperationOp op, function_ref<Value(Value)> mapRewriteValue,
831 SmallVectorImpl<Value> &types, DenseMap<Value, Value> &rewriteValues,
832 bool &hasInferredResultTypes) {
833 // Look for an operation that was replaced by `op`. The result types will be
834 // inferred from the results that were replaced.
835 Block *rewriterBlock = op->getBlock();
836 for (OpOperand &use : op.op().getUses()) {
837 // Check that the use corresponds to a ReplaceOp and that it is the
838 // replacement value, not the operation being replaced.
839 pdl::ReplaceOp replOpUser = dyn_cast<pdl::ReplaceOp>(use.getOwner());
840 if (!replOpUser || use.getOperandNumber() == 0)
841 continue;
842 // Make sure the replaced operation was defined before this one.
843 Value replOpVal = replOpUser.operation();
844 Operation *replacedOp = replOpVal.getDefiningOp();
845 if (replacedOp->getBlock() == rewriterBlock &&
846 !replacedOp->isBeforeInBlock(op))
847 continue;
848
849 Value replacedOpResults = builder.create<pdl_interp::GetResultsOp>(
850 replacedOp->getLoc(), mapRewriteValue(replOpVal));
851 types.push_back(builder.create<pdl_interp::GetValueTypeOp>(
852 replacedOp->getLoc(), replacedOpResults));
853 return;
854 }
855
856 // Try to handle resolution for each of the result types individually. This is
857 // preferred over type inferrence because it will allow for us to use existing
858 // types directly, as opposed to trying to rebuild the type list.
859 OperandRange resultTypeValues = op.types();
860 auto tryResolveResultTypes = [&] {
861 types.reserve(resultTypeValues.size());
862 for (const auto &it : llvm::enumerate(resultTypeValues)) {
863 Value resultType = it.value();
864
865 // Check for an already translated value.
866 if (Value existingRewriteValue = rewriteValues.lookup(resultType)) {
867 types.push_back(existingRewriteValue);
868 continue;
869 }
870
871 // Check for an input from the matcher.
872 if (resultType.getDefiningOp()->getBlock() != rewriterBlock) {
873 types.push_back(mapRewriteValue(resultType));
874 continue;
875 }
876
877 // Otherwise, we couldn't infer the result types. Bail out here to see if
878 // we can infer the types for this operation from another way.
879 types.clear();
880 return failure();
881 }
882 return success();
883 };
884 if (!resultTypeValues.empty() && succeeded(tryResolveResultTypes()))
885 return;
886
887 // Otherwise, check if the operation has type inference support itself.
888 if (op.hasTypeInference()) {
889 hasInferredResultTypes = true;
890 return;
891 }
892
893 // If the types could not be inferred from any context and there weren't any
894 // explicit result types, assume the user actually meant for the operation to
895 // have no results.
896 if (resultTypeValues.empty())
897 return;
898
899 // The verifier asserts that the result types of each pdl.operation can be
900 // inferred. If we reach here, there is a bug either in the logic above or
901 // in the verifier for pdl.operation.
902 op->emitOpError() << "unable to infer result type for operation";
903 llvm_unreachable("unable to infer result type for operation");
904 }
905
906 //===----------------------------------------------------------------------===//
907 // Conversion Pass
908 //===----------------------------------------------------------------------===//
909
910 namespace {
911 struct PDLToPDLInterpPass
912 : public ConvertPDLToPDLInterpBase<PDLToPDLInterpPass> {
913 void runOnOperation() final;
914 };
915 } // namespace
916
917 /// Convert the given module containing PDL pattern operations into a PDL
918 /// Interpreter operations.
runOnOperation()919 void PDLToPDLInterpPass::runOnOperation() {
920 ModuleOp module = getOperation();
921
922 // Create the main matcher function This function contains all of the match
923 // related functionality from patterns in the module.
924 OpBuilder builder = OpBuilder::atBlockBegin(module.getBody());
925 auto matcherFunc = builder.create<pdl_interp::FuncOp>(
926 module.getLoc(), pdl_interp::PDLInterpDialect::getMatcherFunctionName(),
927 builder.getFunctionType(builder.getType<pdl::OperationType>(),
928 /*results=*/llvm::None),
929 /*attrs=*/llvm::None);
930
931 // Create a nested module to hold the functions invoked for rewriting the IR
932 // after a successful match.
933 ModuleOp rewriterModule = builder.create<ModuleOp>(
934 module.getLoc(), pdl_interp::PDLInterpDialect::getRewriterModuleName());
935
936 // Generate the code for the patterns within the module.
937 PatternLowering generator(matcherFunc, rewriterModule);
938 generator.lower(module);
939
940 // After generation, delete all of the pattern operations.
941 for (pdl::PatternOp pattern :
942 llvm::make_early_inc_range(module.getOps<pdl::PatternOp>()))
943 pattern.erase();
944 }
945
createPDLToPDLInterpPass()946 std::unique_ptr<OperationPass<ModuleOp>> mlir::createPDLToPDLInterpPass() {
947 return std::make_unique<PDLToPDLInterpPass>();
948 }
949